mirror of
https://github.com/gryf/coach.git
synced 2026-02-13 04:15:45 +01:00
Create a dataset using an agent (#306)
Generate a dataset using an agent (allowing to select between this and a random dataset)
This commit is contained in:
@@ -16,10 +16,6 @@ def test_init():
|
||||
action_space = DiscreteActionSpace(3)
|
||||
noise_schedule = LinearSchedule(1.0, 1.0, 1000)
|
||||
|
||||
# additive noise doesn't work for discrete controls
|
||||
with pytest.raises(ValueError):
|
||||
policy = AdditiveNoise(action_space, noise_schedule, 0)
|
||||
|
||||
# additive noise requires a bounded range for the actions
|
||||
action_space = BoxActionSpace(np.array([10]))
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@@ -21,14 +21,14 @@ def test_get_action():
|
||||
# verify that test phase gives greedy actions (evaluation_epsilon = 0)
|
||||
policy.change_phase(RunPhase.TEST)
|
||||
for i in range(100):
|
||||
best_action = policy.get_action(np.array([10, 20, 30]))
|
||||
best_action, _ = policy.get_action(np.array([10, 20, 30]))
|
||||
assert best_action == 2
|
||||
|
||||
# verify that train phase gives uniform actions (exploration = 1)
|
||||
policy.change_phase(RunPhase.TRAIN)
|
||||
counters = np.array([0, 0, 0])
|
||||
for i in range(30000):
|
||||
best_action = policy.get_action(np.array([10, 20, 30]))
|
||||
best_action, _ = policy.get_action(np.array([10, 20, 30]))
|
||||
counters[best_action] += 1
|
||||
assert np.all(counters > 9500) # this is noisy so we allow 5% error
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_get_action():
|
||||
action_space = DiscreteActionSpace(3)
|
||||
policy = Greedy(action_space)
|
||||
|
||||
best_action = policy.get_action(np.array([10, 20, 30]))
|
||||
best_action, _ = policy.get_action(np.array([10, 20, 30]))
|
||||
assert best_action == 2
|
||||
|
||||
# continuous control
|
||||
|
||||
@@ -16,10 +16,6 @@ def test_init():
|
||||
# discrete control
|
||||
action_space = DiscreteActionSpace(3)
|
||||
|
||||
# OU process doesn't work for discrete controls
|
||||
with pytest.raises(ValueError):
|
||||
policy = OUProcess(action_space, mu=0, theta=0.1, sigma=0.2, dt=0.01)
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_get_action():
|
||||
|
||||
Reference in New Issue
Block a user