diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index 2dd1e9a..b249749 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -523,6 +523,8 @@ class Agent(AgentInterface):
         Determine if online weights should be copied to the target.
         :return: boolean: True if the online weights should be copied to the target.
         """
+        if hasattr(self.ap.memory, 'memory_backend_params'):
+            self.total_steps_counter = self.call_memory('num_transitions')
         # update the target network of every network that has a target network
         step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
         if step_method.__class__ == TrainingSteps:
@@ -544,22 +546,35 @@ class Agent(AgentInterface):
         :return: boolean: True if we should start a training phase
         """
+        should_update = self._should_train_helper(wait_for_full_episode)
+
+        step_method = self.ap.algorithm.num_consecutive_playing_steps
+
+        if should_update:
+            if step_method.__class__ == EnvironmentEpisodes:
+                self.last_training_phase_step = self.current_episode
+            if step_method.__class__ == EnvironmentSteps:
+                self.last_training_phase_step = self.total_steps_counter
+
+        return should_update
+
+    def _should_train_helper(self, wait_for_full_episode=False):
+
         if hasattr(self.ap.memory, 'memory_backend_params'):
             self.total_steps_counter = self.call_memory('num_transitions')
+
         step_method = self.ap.algorithm.num_consecutive_playing_steps
         if step_method.__class__ == EnvironmentEpisodes:
             should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
-            if should_update:
-                self.last_training_phase_step = self.current_episode
+
         elif step_method.__class__ == EnvironmentSteps:
             should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
             if wait_for_full_episode:
                 should_update = should_update and self.current_episode_buffer.is_complete
-            if should_update:
-                self.last_training_phase_step = self.total_steps_counter
         else:
             raise ValueError("The num_consecutive_playing_steps parameter should be either "
                              "EnvironmentSteps or Episodes. Instead it is {}".format(step_method.__class__))
+
         return should_update
 
     def train(self):
diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py
index 2cb838b..bed6398 100644
--- a/rl_coach/data_stores/s3_data_store.py
+++ b/rl_coach/data_stores/s3_data_store.py
@@ -5,6 +5,7 @@ from minio.error import ResponseError
 from configparser import ConfigParser, Error
 from google.protobuf import text_format
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from minio.error import ResponseError
 import os
 import time
 import io
@@ -63,7 +64,7 @@ class S3DataStore(DataStore):
             for filename in files:
                 if filename == 'checkpoint':
                     checkpoint_file = (root, filename)
-                    pass
+                    continue
                 abs_name = os.path.abspath(os.path.join(root, filename))
                 rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                 self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
@@ -79,17 +80,21 @@ class S3DataStore(DataStore):
 
     def load_from_store(self):
         try:
+            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
+
             while True:
                 objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
 
-                time.sleep(10)
+                if next(objects, None) is None:
+                    try:
+                        self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
+                    except ResponseError as e:
+                        continue
                 break
+                time.sleep(10)
 
             print("loading from s3")
-            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
-            self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
-
             ckpt = CheckpointState()
             if os.path.exists(filename):
                 contents = open(filename, 'r').read()
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 5ee24ed..576ac8d 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -591,3 +591,6 @@ class GraphManager(object):
             result += "{}: \n{}\n".format(key, params)
 
         return result
+
+    def should_train(self) -> bool:
+        return any([manager.should_train() for manager in self.level_managers])
diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py
index 19a2dc0..7697bf5 100644
--- a/rl_coach/level_manager.py
+++ b/rl_coach/level_manager.py
@@ -260,3 +260,6 @@ class LevelManager(EnvironmentInterface):
         :return:
         """
         [agent.sync() for agent in self.agents.values()]
+
+    def should_train(self) -> bool:
+        return any([agent._should_train_helper() for agent in self.agents.values()])
diff --git a/rl_coach/orchestrators/start_training.py b/rl_coach/orchestrators/start_training.py
index 05af288..dbf9df0 100644
--- a/rl_coach/orchestrators/start_training.py
+++ b/rl_coach/orchestrators/start_training.py
@@ -20,12 +20,12 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
     if data_store == "s3":
         ds_params = DataStoreParameters("s3", "", "")
         ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name,
-                                                    checkpoint_dir="/checkpoint")
+                                                   checkpoint_dir="/checkpoint")
     elif data_store == "nfs":
         ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"})
         ds_params_instance = NFSDataStoreParameters(ds_params)
 
-    worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker")
+    worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker", num_replicas=num_workers)
     trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer")
 
     orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
@@ -53,7 +53,7 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
         orchestrator.trainer_logs()
     except KeyboardInterrupt:
         pass
-    # orchestrator.undeploy()
+    orchestrator.undeploy()
 
 
 if __name__ == '__main__':
@@ -90,6 +90,11 @@ if __name__ == '__main__':
                         help="(string) S3 bucket name to use when S3 data store is used.",
                         type=str,
                         required=True)
+    parser.add_argument('--num-workers',
+                        help="(string) Number of rollout workers",
+                        type=int,
+                        required=False,
+                        default=1)
 
     # parser.add_argument('--checkpoint_dir',
     #                     help='(string) Path to a folder containing a checkpoint to write the model to.',
@@ -99,4 +104,4 @@ if __name__ == '__main__':
 
     main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
          memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
-         s3_bucket_name=args.s3_bucket_name)
+         s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers)
diff --git a/rl_coach/presets/CartPole_PPO.py b/rl_coach/presets/CartPole_PPO.py
new file mode 100644
index 0000000..c163efd
--- /dev/null
+++ b/rl_coach/presets/CartPole_PPO.py
@@ -0,0 +1,75 @@
+from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
+from rl_coach.architectures.tensorflow_components.architecture import Dense
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
+from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
+from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
+from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
+from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
+from rl_coach.exploration_policies.e_greedy import EGreedyParameters
+from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
+from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
+from rl_coach.graph_managers.graph_manager import ScheduleParameters
+from rl_coach.schedules import LinearSchedule
+
+####################
+# Graph Scheduling #
+####################
+
+schedule_params = ScheduleParameters()
+schedule_params.improve_steps = TrainingSteps(10000000)
+schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2048)
+schedule_params.evaluation_steps = EnvironmentEpisodes(5)
+schedule_params.heatup_steps = EnvironmentSteps(0)
+
+#########
+# Agent #
+#########
+agent_params = ClippedPPOAgentParameters()
+
+
+agent_params.network_wrappers['main'].learning_rate = 0.0003
+agent_params.network_wrappers['main'].input_embedders_parameters['observation'].activation_function = 'tanh'
+agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense([64])]
+agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([64])]
+agent_params.network_wrappers['main'].middleware_parameters.activation_function = 'tanh'
+agent_params.network_wrappers['main'].batch_size = 64
+agent_params.network_wrappers['main'].optimizer_epsilon = 1e-5
+agent_params.network_wrappers['main'].adam_optimizer_beta2 = 0.999
+
+agent_params.algorithm.clip_likelihood_ratio_using_epsilon = 0.2
+agent_params.algorithm.clipping_decay_schedule = LinearSchedule(1.0, 0, 1000000)
+agent_params.algorithm.beta_entropy = 0
+agent_params.algorithm.gae_lambda = 0.95
+agent_params.algorithm.discount = 0.99
+agent_params.algorithm.optimization_epochs = 10
+agent_params.algorithm.estimate_state_value_using_gae = True
+agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(2048)
+
+# agent_params.input_filter = MujocoInputFilter()
+agent_params.exploration = EGreedyParameters()
+agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
+# agent_params.pre_network_filter = MujocoInputFilter()
+agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
+                                                        ObservationNormalizationFilter(name='normalize_observation'))
+
+###############
+# Environment #
+###############
+env_params = Mujoco()
+env_params.level = 'CartPole-v0'
+
+vis_params = VisualizationParameters()
+vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
+vis_params.dump_mp4 = False
+
+########
+# Test #
+########
+preset_validation_params = PresetValidationParameters()
+preset_validation_params.test = True
+preset_validation_params.min_reward_threshold = 150
+preset_validation_params.max_episodes_to_achieve_reward = 250
+
+graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
+                                    schedule_params=schedule_params, vis_params=vis_params,
+                                    preset_validation_params=preset_validation_params)
diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py
index 4da27ed..c5664a8 100644
--- a/rl_coach/training_worker.py
+++ b/rl_coach/training_worker.py
@@ -32,13 +32,13 @@ def training_worker(graph_manager, checkpoint_dir):
 
     # training loop
     while True:
-        graph_manager.phase = core_types.RunPhase.TRAIN
-        graph_manager.train(core_types.TrainingSteps(1))
-        graph_manager.phase = core_types.RunPhase.UNDEFINED
-
-        graph_manager.evaluate(graph_manager.evaluation_steps)
-
-        graph_manager.save_checkpoint()
+        if graph_manager.should_train():
+            graph_manager.phase = core_types.RunPhase.TRAIN
+            graph_manager.train(core_types.TrainingSteps(1))
+            graph_manager.phase = core_types.RunPhase.UNDEFINED
+            graph_manager.evaluate(graph_manager.evaluation_steps)
+            graph_manager.save_checkpoint()
+        time.sleep(10)
 
 
 def main():