Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
Adding should_train helper and should_train in graph_manager
Committed by zach dwiel
Parent: a2e57a44f1
Commit: a7f5442015
@@ -523,6 +523,8 @@ class Agent(AgentInterface):
        Determine if online weights should be copied to the target.
        :return: boolean: True if the online weights should be copied to the target.
        """
        if hasattr(self.ap.memory, 'memory_backend_params'):
            self.total_steps_counter = self.call_memory('num_transitions')
        # update the target network of every network that has a target network
        step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
        if step_method.__class__ == TrainingSteps:
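This hunk appears to make the target-network copy check read the global transition count from the distributed memory backend (when one is configured) rather than the locally tracked step counter. A minimal sketch of that idea, using made-up stand-in names (`MemoryBackend`, `effective_step_count`) rather than the real Coach classes:

    # Illustrative sketch only: prefer the shared memory backend's transition count
    # over the locally tracked step counter when a backend is attached.
    class MemoryBackend:
        """Hypothetical stand-in for a distributed replay memory (e.g. Redis-backed)."""
        def __init__(self):
            self.transitions = []

        def num_transitions(self):
            return len(self.transitions)

    def effective_step_count(local_steps, memory_backend=None):
        # With a shared backend, other rollout workers may have pushed transitions
        # this process never saw locally, so trust the backend's global count.
        if memory_backend is not None:
            return memory_backend.num_transitions()
        return local_steps

    backend = MemoryBackend()
    backend.transitions.extend(range(5000))
    print(effective_step_count(local_steps=1200, memory_backend=backend))  # 5000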
@@ -544,22 +546,35 @@ class Agent(AgentInterface):
        :return: boolean: True if we should start a training phase
        """

        should_update = self._should_train_helper(wait_for_full_episode)

        step_method = self.ap.algorithm.num_consecutive_playing_steps

        if should_update:
            if step_method.__class__ == EnvironmentEpisodes:
                self.last_training_phase_step = self.current_episode
            if step_method.__class__ == EnvironmentSteps:
                self.last_training_phase_step = self.total_steps_counter

        return should_update

    def _should_train_helper(self, wait_for_full_episode=False):

        if hasattr(self.ap.memory, 'memory_backend_params'):
            self.total_steps_counter = self.call_memory('num_transitions')

        step_method = self.ap.algorithm.num_consecutive_playing_steps
        if step_method.__class__ == EnvironmentEpisodes:
            should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
            if should_update:
                self.last_training_phase_step = self.current_episode

        elif step_method.__class__ == EnvironmentSteps:
            should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
            if wait_for_full_episode:
                should_update = should_update and self.current_episode_buffer.is_complete
            if should_update:
                self.last_training_phase_step = self.total_steps_counter
        else:
            raise ValueError("The num_consecutive_playing_steps parameter should be either "
                             "EnvironmentSteps or Episodes. Instead it is {}".format(step_method.__class__))

        return should_update

    def train(self):
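To make the episode/step bookkeeping concrete, here is a small self-contained sketch of the same gating rule: training is due once `num_steps` episodes (or environment steps) have elapsed since the last training phase, optionally waiting for the current episode to finish. The classes below are stand-ins for the rl_coach core types, not the real ones.

    # Stand-ins for rl_coach.core_types step classes, for illustration only.
    class EnvironmentEpisodes:
        def __init__(self, num_steps):
            self.num_steps = num_steps

    class EnvironmentSteps:
        def __init__(self, num_steps):
            self.num_steps = num_steps

    def should_train(step_method, counter_now, last_training_phase_step,
                     episode_complete=True, wait_for_full_episode=False):
        """Mirror of the gating rule: enough episodes/steps elapsed since the last training phase."""
        due = (counter_now - last_training_phase_step) >= step_method.num_steps
        if isinstance(step_method, EnvironmentSteps) and wait_for_full_episode:
            due = due and episode_complete
        return due

    # e.g. with num_consecutive_playing_steps = EnvironmentSteps(2048):
    print(should_train(EnvironmentSteps(2048), counter_now=4096, last_training_phase_step=2048))  # True
    print(should_train(EnvironmentEpisodes(8), counter_now=5, last_training_phase_step=0))        # False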
@@ -5,6 +5,7 @@ from minio.error import ResponseError
from configparser import ConfigParser, Error
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from minio.error import ResponseError
import os
import time
import io
@@ -63,7 +64,7 @@ class S3DataStore(DataStore):
            for filename in files:
                if filename == 'checkpoint':
                    checkpoint_file = (root, filename)
                    pass
                    continue
                abs_name = os.path.abspath(os.path.join(root, filename))
                rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
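The `continue` (replacing the bare `pass`) makes the directory walk skip the TensorFlow `checkpoint` state file; the tuple stashed in `checkpoint_file` suggests it is uploaded separately afterwards, though that part lies outside this hunk. A rough sketch of the same skip-then-upload-last pattern, with a hypothetical `put_object` callable standing in for minio's `fput_object`:

    import os

    def upload_checkpoint_dir(checkpoint_dir, put_object):
        """Upload every file under checkpoint_dir, deferring the 'checkpoint' state file.

        Illustrative sketch, not the Coach code: put_object(key, path) is a
        hypothetical upload callable.
        """
        checkpoint_file = None
        for root, _, files in os.walk(checkpoint_dir):
            for filename in files:
                if filename == 'checkpoint':
                    checkpoint_file = (root, filename)
                    continue  # skip for now, upload it last
                abs_name = os.path.abspath(os.path.join(root, filename))
                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                put_object(rel_name, abs_name)
        # Uploading the state file last means a reader that sees it can assume the
        # checkpoint data it points to has already been uploaded.
        if checkpoint_file is not None:
            root, filename = checkpoint_file
            put_object(filename, os.path.join(root, filename))

    # usage with a dummy uploader that just prints the keys
    upload_checkpoint_dir('.', lambda key, path: print('upload', key))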
@@ -79,17 +80,21 @@ class S3DataStore(DataStore):
    def load_from_store(self):
        try:

            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))

            while True:
                objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
                time.sleep(10)

                if next(objects, None) is None:
                    try:
                        self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
                    except ResponseError as e:
                        continue
                    break
                time.sleep(10)

            print("loading from s3")

            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
            self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)

            ckpt = CheckpointState()
            if os.path.exists(filename):
                contents = open(filename, 'r').read()
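The new loop polls the bucket for a lock object and only fetches the `checkpoint` state file once the lock is absent, retrying when the fetch fails. A generic sketch of that wait-for-lock-then-fetch pattern (the `object_exists` and `fetch` callables are hypothetical; the real code uses minio's `list_objects_v2` and `fget_object`):

    import time

    def fetch_when_unlocked(object_exists, fetch, lock_key, target_key,
                            poll_seconds=10, max_attempts=30):
        """Wait until lock_key disappears, then fetch target_key; retry on failure.

        Illustrative sketch: object_exists(key) -> bool and fetch(key) stand in
        for the data store client.
        """
        for _ in range(max_attempts):
            if not object_exists(lock_key):
                try:
                    fetch(target_key)
                    return True
                except Exception:
                    # The writer may still be mid-upload (ResponseError in the
                    # minio-backed code); fall through and poll again.
                    pass
            time.sleep(poll_seconds)
        return False

    # usage with in-memory stand-ins
    store = {"checkpoint": b"model_checkpoint_path: ..."}
    print(fetch_when_unlocked(lambda k: k in store, store.__getitem__,
                              lock_key=".lock", target_key="checkpoint",
                              poll_seconds=0))  # True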
@@ -591,3 +591,6 @@ class GraphManager(object):
            result += "{}: \n{}\n".format(key, params)

        return result

    def should_train(self) -> bool:
        return any([manager.should_train() for manager in self.level_managers])
@@ -260,3 +260,6 @@ class LevelManager(EnvironmentInterface):
        :return:
        """
        [agent.sync() for agent in self.agents.values()]

    def should_train(self) -> bool:
        return any([agent._should_train_helper() for agent in self.agents.values()])
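Taken together with the GraphManager hunk above, `should_train` now bubbles up through the hierarchy: the graph manager asks each level manager, which asks each of its agents via `_should_train_helper`, and `any()` returns True as soon as one agent has accumulated enough experience. A toy sketch of that delegation chain (stand-in classes, not the Coach ones):

    # Toy delegation chain mirroring GraphManager -> LevelManager -> Agent.
    class ToyAgent:
        def __init__(self, new_transitions, needed):
            self.new_transitions = new_transitions
            self.needed = needed

        def _should_train_helper(self):
            return self.new_transitions >= self.needed

    class ToyLevelManager:
        def __init__(self, agents):
            self.agents = agents

        def should_train(self):
            return any(agent._should_train_helper() for agent in self.agents)

    class ToyGraphManager:
        def __init__(self, level_managers):
            self.level_managers = level_managers

        def should_train(self):
            return any(manager.should_train() for manager in self.level_managers)

    gm = ToyGraphManager([ToyLevelManager([ToyAgent(100, 2048), ToyAgent(4096, 2048)])])
    print(gm.should_train())  # True: the second agent has enough new transitions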
@@ -20,12 +20,12 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
    if data_store == "s3":
        ds_params = DataStoreParameters("s3", "", "")
        ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name,
                                                   checkpoint_dir="/checkpoint")
                                                   checkpoint_dir="/checkpoint")
    elif data_store == "nfs":
        ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"})
        ds_params_instance = NFSDataStoreParameters(ds_params)

    worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker")
    worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker", num_replicas=num_workers)
    trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer")

    orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
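Passing `num_replicas=num_workers` to the worker `RunTypeParameters` is what lets the orchestrator scale out rollouts; presumably it ends up as the replica count of the worker Deployment. A hedged sketch of that mapping (the manifest is built by hand here, the command is illustrative, and none of this is the actual KubernetesOrchestrator code):

    def worker_deployment_manifest(image, command, num_replicas=1, name="rollout-worker"):
        """Build a minimal Kubernetes Deployment dict with num_replicas rollout pods.

        Illustrative only: shows how a --num-workers flag could translate into
        spec.replicas; not the orchestrator implementation from this repo.
        """
        return {
            "apiVersion": "apps/v1",
            "kind": "Deployment",
            "metadata": {"name": name},
            "spec": {
                "replicas": num_replicas,
                "selector": {"matchLabels": {"app": name}},
                "template": {
                    "metadata": {"labels": {"app": name}},
                    "spec": {"containers": [{"name": name, "image": image, "command": command}]},
                },
            },
        }

    manifest = worker_deployment_manifest("ajaysudh/testing:coach",
                                          ["python", "rollout_worker.py"],  # hypothetical command
                                          num_replicas=4)
    print(manifest["spec"]["replicas"])  # 4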
@@ -53,7 +53,7 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
        orchestrator.trainer_logs()
    except KeyboardInterrupt:
        pass
    # orchestrator.undeploy()
    orchestrator.undeploy()


if __name__ == '__main__':
@@ -90,6 +90,11 @@ if __name__ == '__main__':
                        help="(string) S3 bucket name to use when S3 data store is used.",
                        type=str,
                        required=True)
    parser.add_argument('--num-workers',
                        help="(string) Number of rollout workers",
                        type=int,
                        required=False,
                        default=1)

    # parser.add_argument('--checkpoint_dir',
    #                     help='(string) Path to a folder containing a checkpoint to write the model to.',
@@ -99,4 +104,4 @@ if __name__ == '__main__':

    main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
         memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
         s3_bucket_name=args.s3_bucket_name)
         s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers)
rl_coach/presets/CartPole_PPO.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.schedules import LinearSchedule

####################
# Graph Scheduling #
####################

schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(10000000)
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2048)
schedule_params.evaluation_steps = EnvironmentEpisodes(5)
schedule_params.heatup_steps = EnvironmentSteps(0)

#########
# Agent #
#########
agent_params = ClippedPPOAgentParameters()


agent_params.network_wrappers['main'].learning_rate = 0.0003
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].activation_function = 'tanh'
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense([64])]
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([64])]
agent_params.network_wrappers['main'].middleware_parameters.activation_function = 'tanh'
agent_params.network_wrappers['main'].batch_size = 64
agent_params.network_wrappers['main'].optimizer_epsilon = 1e-5
agent_params.network_wrappers['main'].adam_optimizer_beta2 = 0.999

agent_params.algorithm.clip_likelihood_ratio_using_epsilon = 0.2
agent_params.algorithm.clipping_decay_schedule = LinearSchedule(1.0, 0, 1000000)
agent_params.algorithm.beta_entropy = 0
agent_params.algorithm.gae_lambda = 0.95
agent_params.algorithm.discount = 0.99
agent_params.algorithm.optimization_epochs = 10
agent_params.algorithm.estimate_state_value_using_gae = True
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(2048)

# agent_params.input_filter = MujocoInputFilter()
agent_params.exploration = EGreedyParameters()
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
# agent_params.pre_network_filter = MujocoInputFilter()
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
                                                       ObservationNormalizationFilter(name='normalize_observation'))

###############
# Environment #
###############
env_params = Mujoco()
env_params.level = 'CartPole-v0'

vis_params = VisualizationParameters()
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = False

########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 250

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params, vis_params=vis_params,
                                    preset_validation_params=preset_validation_params)
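For context, a preset such as this is just a module that builds a `graph_manager`. A minimal sketch of how it might be exercised once the package is importable; the `create_graph()`/`improve()` calls follow the usual Coach entry points, but treat the exact invocation (and the CLI line in the comment) as an assumption rather than something shown in this commit:

    # Hypothetical smoke-test of the new preset; assumes rl_coach and gym are installed.
    # From the command line the equivalent would typically be: coach -p CartPole_PPO
    from rl_coach.base_parameters import TaskParameters
    from rl_coach.presets.CartPole_PPO import graph_manager  # the module added by this commit

    task_parameters = TaskParameters()
    graph_manager.create_graph(task_parameters)
    graph_manager.improve()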
@@ -32,13 +32,13 @@ def training_worker(graph_manager, checkpoint_dir):

    # training loop
    while True:
        graph_manager.phase = core_types.RunPhase.TRAIN
        graph_manager.train(core_types.TrainingSteps(1))
        graph_manager.phase = core_types.RunPhase.UNDEFINED

        graph_manager.evaluate(graph_manager.evaluation_steps)

        graph_manager.save_checkpoint()
        if graph_manager.should_train():
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.train(core_types.TrainingSteps(1))
            graph_manager.phase = core_types.RunPhase.UNDEFINED
            graph_manager.evaluate(graph_manager.evaluation_steps)
            graph_manager.save_checkpoint()
        time.sleep(10)


def main():
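The effect of the new loop is that the trainer no longer trains unconditionally: it polls `graph_manager.should_train()`, does one training step plus evaluation and checkpointing only when some agent has enough fresh rollout data, and sleeps between iterations. A stripped-down sketch of that polling trainer (stand-in objects, not the Coach classes; the sleep mirrors the patched loop):

    import time

    def training_loop(graph_manager, poll_seconds=10, max_iterations=None):
        """Poll-and-train loop in the spirit of the patched training_worker.

        Sketch only: graph_manager just needs should_train/train/evaluate/save_checkpoint.
        """
        iterations = 0
        while max_iterations is None or iterations < max_iterations:
            if graph_manager.should_train():
                graph_manager.train()
                graph_manager.evaluate()
                graph_manager.save_checkpoint()
            time.sleep(poll_seconds)
            iterations += 1

    # usage with a tiny fake graph manager
    class FakeGraphManager:
        def __init__(self): self.trained = 0
        def should_train(self): return True
        def train(self): self.trained += 1
        def evaluate(self): pass
        def save_checkpoint(self): pass

    fake = FakeGraphManager()
    training_loop(fake, poll_seconds=0, max_iterations=3)
    print(fake.trained)  # 3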