
integration test changes to reach the train part (#254)

* integration test changes to override heatup to 1000 steps + run each preset for 30 seconds (to make sure we reach the train part)

* fixes to failing presets uncovered by this change + changes in the golden tests to properly test BatchRL

* fix for Rainbow DQN

* fix to gym_environment (due to a change in Gym 0.12.1) + fix for Rainbow DQN + a bug fix in utils.squeeze_list

* fix for NEC agent
Gal Leibovich authored on 2019-03-27 21:14:19 +02:00, committed by GitHub
parent 6e08c55ad5
commit 310d31c227
8 changed files with 28 additions and 17 deletions
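
To make the first bullet concrete, here is a minimal sketch of the kind of harness change it describes. It is illustrative only: it assumes Coach's CLI takes a preset via -p and parameter overrides via -cp/--custom_parameter, and the override string and function name are not taken from this commit.

    # Illustrative sketch, not the actual test code from this commit.
    import subprocess

    HEATUP_OVERRIDE = "heatup_steps=EnvironmentSteps(1000)"  # force a short heatup (assumed -cp syntax)
    PRESET_TIMEOUT_SEC = 30  # run each preset for ~30 seconds

    def run_preset_briefly(preset_name):
        """Run a preset just long enough to get past heatup and into training."""
        cmd = ["coach", "-p", preset_name, "-cp", HEATUP_OVERRIDE]
        try:
            subprocess.run(cmd, timeout=PRESET_TIMEOUT_SEC)
        except subprocess.TimeoutExpired:
            pass  # expected: the point is reaching the train loop, not finishing

Killing each run after 30 seconds is enough to surface train-time failures (like the Rainbow DQN and NEC fixes above) that a heatup-only smoke test would never reach.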

@@ -29,7 +29,7 @@ schedule_params.heatup_steps = EnvironmentSteps(DATASET_SIZE)
 #########
 # TODO add a preset which uses a dataset to train a BatchRL graph. e.g. save a cartpole dataset in a csv format.
 agent_params = DDQNAgentParameters()
-agent_params.network_wrappers['main'].batch_size = 1024
+agent_params.network_wrappers['main'].batch_size = 128
 # DQN params
 # agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(100)
@@ -81,11 +81,11 @@ env_params = GymVectorEnvironment(level='CartPole-v0')
 preset_validation_params = PresetValidationParameters()
 preset_validation_params.test = True
 preset_validation_params.min_reward_threshold = 150
-preset_validation_params.max_episodes_to_achieve_reward = 250
+preset_validation_params.max_episodes_to_achieve_reward = 2000
 graph_manager = BatchRLGraphManager(agent_params=agent_params, env_params=env_params,
                                     schedule_params=schedule_params,
                                     vis_params=VisualizationParameters(dump_signals_to_csv_every_x_episodes=1),
                                     preset_validation_params=preset_validation_params,
-                                    reward_model_num_epochs=50,
+                                    reward_model_num_epochs=30,
                                     train_to_eval_ratio=0.8)
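
The relaxed budget (from 250 to 2000 episodes) is easier to read with a rough model of the pass/fail rule these two validation parameters encode. The function below is a simplified guess, not Coach's exact golden-test logic:

    # Simplified guess at the validation rule encoded by the two parameters above.
    def preset_reaches_reward(episode_rewards,
                              min_reward_threshold=150,
                              max_episodes_to_achieve_reward=2000):
        """True if the reward threshold is hit within the episode budget."""
        for episode, reward in enumerate(episode_rewards, start=1):
            if reward >= min_reward_threshold:
                return True
            if episode >= max_episodes_to_achieve_reward:
                return False
        return False

Raising the episode budget while keeping the 150-reward threshold presumably gives the BatchRL preset room to converge with the smaller 128 batch size and the shorter 30-epoch reward model.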