1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding should_train helper and should_train in graph_manager

This commit is contained in:
Ajay Deshpande
2018-10-05 14:22:15 -07:00
committed by zach dwiel
parent a2e57a44f1
commit a7f5442015
7 changed files with 126 additions and 20 deletions

View File

@@ -523,6 +523,8 @@ class Agent(AgentInterface):
Determine if online weights should be copied to the target.
:return: boolean: True if the online weights should be copied to the target.
"""
if hasattr(self.ap.memory, 'memory_backend_params'):
self.total_steps_counter = self.call_memory('num_transitions')
# update the target network of every network that has a target network
step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
if step_method.__class__ == TrainingSteps:
@@ -544,22 +546,35 @@ class Agent(AgentInterface):
:return: boolean: True if we should start a training phase
"""
should_update = self._should_train_helper(wait_for_full_episode)
step_method = self.ap.algorithm.num_consecutive_playing_steps
if should_update:
if step_method.__class__ == EnvironmentEpisodes:
self.last_training_phase_step = self.current_episode
if step_method.__class__ == EnvironmentSteps:
self.last_training_phase_step = self.total_steps_counter
return should_update
def _should_train_helper(self, wait_for_full_episode=False):
if hasattr(self.ap.memory, 'memory_backend_params'):
self.total_steps_counter = self.call_memory('num_transitions')
step_method = self.ap.algorithm.num_consecutive_playing_steps
if step_method.__class__ == EnvironmentEpisodes:
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
if should_update:
self.last_training_phase_step = self.current_episode
elif step_method.__class__ == EnvironmentSteps:
should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
if wait_for_full_episode:
should_update = should_update and self.current_episode_buffer.is_complete
if should_update:
self.last_training_phase_step = self.total_steps_counter
else:
raise ValueError("The num_consecutive_playing_steps parameter should be either "
"EnvironmentSteps or Episodes. Instead it is {}".format(step_method.__class__))
return should_update
def train(self):

View File

@@ -5,6 +5,7 @@ from minio.error import ResponseError
from configparser import ConfigParser, Error
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from minio.error import ResponseError
import os
import time
import io
@@ -63,7 +64,7 @@ class S3DataStore(DataStore):
for filename in files:
if filename == 'checkpoint':
checkpoint_file = (root, filename)
pass
continue
abs_name = os.path.abspath(os.path.join(root, filename))
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
@@ -79,17 +80,21 @@ class S3DataStore(DataStore):
def load_from_store(self):
try:
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
while True:
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
time.sleep(10)
if next(objects, None) is None:
try:
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
except ResponseError as e:
continue
break
time.sleep(10)
print("loading from s3")
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
ckpt = CheckpointState()
if os.path.exists(filename):
contents = open(filename, 'r').read()

View File

@@ -591,3 +591,6 @@ class GraphManager(object):
result += "{}: \n{}\n".format(key, params)
return result
def should_train(self) -> bool:
return any([manager.should_train() for manager in self.level_managers])

View File

@@ -260,3 +260,6 @@ class LevelManager(EnvironmentInterface):
:return:
"""
[agent.sync() for agent in self.agents.values()]
def should_train(self) -> bool:
return any([agent._should_train_helper() for agent in self.agents.values()])

View File

@@ -20,12 +20,12 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
if data_store == "s3":
ds_params = DataStoreParameters("s3", "", "")
ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name,
checkpoint_dir="/checkpoint")
checkpoint_dir="/checkpoint")
elif data_store == "nfs":
ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"})
ds_params_instance = NFSDataStoreParameters(ds_params)
worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker")
worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker", num_replicas=num_workers)
trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer")
orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
@@ -53,7 +53,7 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
orchestrator.trainer_logs()
except KeyboardInterrupt:
pass
# orchestrator.undeploy()
orchestrator.undeploy()
if __name__ == '__main__':
@@ -90,6 +90,11 @@ if __name__ == '__main__':
help="(string) S3 bucket name to use when S3 data store is used.",
type=str,
required=True)
parser.add_argument('--num-workers',
help="(string) Number of rollout workers",
type=int,
required=False,
default=1)
# parser.add_argument('--checkpoint_dir',
# help='(string) Path to a folder containing a checkpoint to write the model to.',
@@ -99,4 +104,4 @@ if __name__ == '__main__':
main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
s3_bucket_name=args.s3_bucket_name)
s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers)

View File

@@ -0,0 +1,75 @@
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.schedules import LinearSchedule
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(10000000)
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2048)
schedule_params.evaluation_steps = EnvironmentEpisodes(5)
schedule_params.heatup_steps = EnvironmentSteps(0)
#########
# Agent #
#########
agent_params = ClippedPPOAgentParameters()
agent_params.network_wrappers['main'].learning_rate = 0.0003
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].activation_function = 'tanh'
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense([64])]
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([64])]
agent_params.network_wrappers['main'].middleware_parameters.activation_function = 'tanh'
agent_params.network_wrappers['main'].batch_size = 64
agent_params.network_wrappers['main'].optimizer_epsilon = 1e-5
agent_params.network_wrappers['main'].adam_optimizer_beta2 = 0.999
agent_params.algorithm.clip_likelihood_ratio_using_epsilon = 0.2
agent_params.algorithm.clipping_decay_schedule = LinearSchedule(1.0, 0, 1000000)
agent_params.algorithm.beta_entropy = 0
agent_params.algorithm.gae_lambda = 0.95
agent_params.algorithm.discount = 0.99
agent_params.algorithm.optimization_epochs = 10
agent_params.algorithm.estimate_state_value_using_gae = True
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(2048)
# agent_params.input_filter = MujocoInputFilter()
agent_params.exploration = EGreedyParameters()
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
# agent_params.pre_network_filter = MujocoInputFilter()
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
ObservationNormalizationFilter(name='normalize_observation'))
###############
# Environment #
###############
env_params = Mujoco()
env_params.level = 'CartPole-v0'
vis_params = VisualizationParameters()
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = False
########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 250
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=vis_params,
preset_validation_params=preset_validation_params)

View File

@@ -32,13 +32,13 @@ def training_worker(graph_manager, checkpoint_dir):
# training loop
while True:
graph_manager.phase = core_types.RunPhase.TRAIN
graph_manager.train(core_types.TrainingSteps(1))
graph_manager.phase = core_types.RunPhase.UNDEFINED
graph_manager.evaluate(graph_manager.evaluation_steps)
graph_manager.save_checkpoint()
if graph_manager.should_train():
graph_manager.phase = core_types.RunPhase.TRAIN
graph_manager.train(core_types.TrainingSteps(1))
graph_manager.phase = core_types.RunPhase.UNDEFINED
graph_manager.evaluate(graph_manager.evaluation_steps)
graph_manager.save_checkpoint()
time.sleep(10)
def main():