1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-17 14:45:50 +01:00

more clear names for methods of Space (#181)

* rename Space.val_matches_space_definition -> contains; Space.is_point_in_space_shape -> valid_index
* rename valid_index -> is_valid_index
This commit is contained in:
Zach Dwiel
2019-01-14 15:02:53 -05:00
committed by GitHub
parent 0ccc333d77
commit cd812b0d25
19 changed files with 77 additions and 62 deletions

View File

@@ -32,7 +32,7 @@ def test_filter():
result = filter.filter(action)
assert np.all(result == np.array([[41.5, 0], [83., 41.5]]))
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)
# force int bins
filter = AttentionDiscretization(2, force_int_bins=True)

View File

@@ -26,7 +26,7 @@ def test_filter():
result = filter.filter(action)
assert result == [7.5]
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)
# 2 dimensional box
filter = BoxDiscretization(3)
@@ -42,4 +42,4 @@ def test_filter():
result = filter.filter(action)
assert result == [5., 15.]
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)

View File

@@ -23,5 +23,5 @@ def test_filter():
action = np.array([2])
result = filter.filter(action)
assert result == np.array([12])
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)

View File

@@ -25,5 +25,5 @@ def test_filter():
action = np.array([12])
result = filter.filter(action)
assert result == np.array([11])
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)

View File

@@ -132,18 +132,18 @@ def test_agent_selection():
def test_observation_space():
observation_space = ObservationSpace(np.array([1, 10]), -10, 10)
# testing that val_matches_space_definition works
assert observation_space.val_matches_space_definition(np.ones([1, 10]))
assert not observation_space.val_matches_space_definition(np.ones([2, 10]))
assert not observation_space.val_matches_space_definition(np.ones([1, 10]) * 100)
assert not observation_space.val_matches_space_definition(np.ones([1, 1, 10]))
# testing that contains works
assert observation_space.contains(np.ones([1, 10]))
assert not observation_space.contains(np.ones([2, 10]))
assert not observation_space.contains(np.ones([1, 10]) * 100)
assert not observation_space.contains(np.ones([1, 1, 10]))
# is_point_in_space_shape
assert observation_space.is_point_in_space_shape(np.array([0, 9]))
assert observation_space.is_point_in_space_shape(np.array([0, 0]))
assert not observation_space.is_point_in_space_shape(np.array([1, 8]))
assert not observation_space.is_point_in_space_shape(np.array([0, 10]))
assert not observation_space.is_point_in_space_shape(np.array([-1, 6]))
# is_valid_index
assert observation_space.is_valid_index(np.array([0, 9]))
assert observation_space.is_valid_index(np.array([0, 0]))
assert not observation_space.is_valid_index(np.array([1, 8]))
assert not observation_space.is_valid_index(np.array([0, 10]))
assert not observation_space.is_valid_index(np.array([-1, 6]))
@pytest.mark.unit_test