mirror of
https://github.com/gryf/coach.git
synced 2026-02-17 14:45:50 +01:00
more clear names for methods of Space (#181)
* rename Space.val_matches_space_definition -> contains; Space.is_point_in_space_shape -> valid_index * rename valid_index -> is_valid_index
This commit is contained in:
@@ -32,7 +32,7 @@ def test_filter():
|
||||
|
||||
result = filter.filter(action)
|
||||
assert np.all(result == np.array([[41.5, 0], [83., 41.5]]))
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
# force int bins
|
||||
filter = AttentionDiscretization(2, force_int_bins=True)
|
||||
|
||||
@@ -26,7 +26,7 @@ def test_filter():
|
||||
|
||||
result = filter.filter(action)
|
||||
assert result == [7.5]
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
# 2 dimensional box
|
||||
filter = BoxDiscretization(3)
|
||||
@@ -42,4 +42,4 @@ def test_filter():
|
||||
|
||||
result = filter.filter(action)
|
||||
assert result == [5., 15.]
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
@@ -23,5 +23,5 @@ def test_filter():
|
||||
action = np.array([2])
|
||||
result = filter.filter(action)
|
||||
assert result == np.array([12])
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
|
||||
@@ -25,5 +25,5 @@ def test_filter():
|
||||
action = np.array([12])
|
||||
result = filter.filter(action)
|
||||
assert result == np.array([11])
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
|
||||
@@ -132,18 +132,18 @@ def test_agent_selection():
|
||||
def test_observation_space():
|
||||
observation_space = ObservationSpace(np.array([1, 10]), -10, 10)
|
||||
|
||||
# testing that val_matches_space_definition works
|
||||
assert observation_space.val_matches_space_definition(np.ones([1, 10]))
|
||||
assert not observation_space.val_matches_space_definition(np.ones([2, 10]))
|
||||
assert not observation_space.val_matches_space_definition(np.ones([1, 10]) * 100)
|
||||
assert not observation_space.val_matches_space_definition(np.ones([1, 1, 10]))
|
||||
# testing that contains works
|
||||
assert observation_space.contains(np.ones([1, 10]))
|
||||
assert not observation_space.contains(np.ones([2, 10]))
|
||||
assert not observation_space.contains(np.ones([1, 10]) * 100)
|
||||
assert not observation_space.contains(np.ones([1, 1, 10]))
|
||||
|
||||
# is_point_in_space_shape
|
||||
assert observation_space.is_point_in_space_shape(np.array([0, 9]))
|
||||
assert observation_space.is_point_in_space_shape(np.array([0, 0]))
|
||||
assert not observation_space.is_point_in_space_shape(np.array([1, 8]))
|
||||
assert not observation_space.is_point_in_space_shape(np.array([0, 10]))
|
||||
assert not observation_space.is_point_in_space_shape(np.array([-1, 6]))
|
||||
# is_valid_index
|
||||
assert observation_space.is_valid_index(np.array([0, 9]))
|
||||
assert observation_space.is_valid_index(np.array([0, 0]))
|
||||
assert not observation_space.is_valid_index(np.array([1, 8]))
|
||||
assert not observation_space.is_valid_index(np.array([0, 10]))
|
||||
assert not observation_space.is_valid_index(np.array([-1, 6]))
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
|
||||
Reference in New Issue
Block a user