mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
integration test changes to reach the train part (#254)
* Integration test changes: override heatup to 1000 steps and run each preset for 30 seconds (to make sure we reach the train part)
* Fixes to failing presets uncovered by this change, plus changes in the golden testing to properly test BatchRL
* Fix for Rainbow DQN
* Fix to gym_environment (due to a change in Gym 0.12.1), plus a fix for Rainbow DQN and a bug fix in utils.squeeze_list
* Fix for the NEC agent
This commit is contained in:
@@ -29,7 +29,7 @@ schedule_params.heatup_steps = EnvironmentSteps(DATASET_SIZE)
|
||||
#########
|
||||
# TODO add a preset which uses a dataset to train a BatchRL graph. e.g. save a cartpole dataset in a csv format.
|
||||
agent_params = DDQNAgentParameters()
|
||||
agent_params.network_wrappers['main'].batch_size = 1024
|
||||
agent_params.network_wrappers['main'].batch_size = 128
|
||||
|
||||
# DQN params
|
||||
# agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(100)
|
||||
@@ -81,11 +81,11 @@ env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.test = True
|
||||
preset_validation_params.min_reward_threshold = 150
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 2000
|
||||
|
||||
graph_manager = BatchRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params,
|
||||
vis_params=VisualizationParameters(dump_signals_to_csv_every_x_episodes=1),
|
||||
preset_validation_params=preset_validation_params,
|
||||
reward_model_num_epochs=50,
|
||||
reward_model_num_epochs=30,
|
||||
train_to_eval_ratio=0.8)
|
||||
|
||||
Reference in New Issue
Block a user