mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Adding should_train helper and should_train in graph_manager
This commit is contained in:
committed by
zach dwiel
parent
a2e57a44f1
commit
a7f5442015
@@ -523,6 +523,8 @@ class Agent(AgentInterface):
|
|||||||
Determine if online weights should be copied to the target.
|
Determine if online weights should be copied to the target.
|
||||||
:return: boolean: True if the online weights should be copied to the target.
|
:return: boolean: True if the online weights should be copied to the target.
|
||||||
"""
|
"""
|
||||||
|
if hasattr(self.ap.memory, 'memory_backend_params'):
|
||||||
|
self.total_steps_counter = self.call_memory('num_transitions')
|
||||||
# update the target network of every network that has a target network
|
# update the target network of every network that has a target network
|
||||||
step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
|
step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
|
||||||
if step_method.__class__ == TrainingSteps:
|
if step_method.__class__ == TrainingSteps:
|
||||||
@@ -544,22 +546,35 @@ class Agent(AgentInterface):
|
|||||||
:return: boolean: True if we should start a training phase
|
:return: boolean: True if we should start a training phase
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
should_update = self._should_train_helper(wait_for_full_episode)
|
||||||
|
|
||||||
|
step_method = self.ap.algorithm.num_consecutive_playing_steps
|
||||||
|
|
||||||
|
if should_update:
|
||||||
|
if step_method.__class__ == EnvironmentEpisodes:
|
||||||
|
self.last_training_phase_step = self.current_episode
|
||||||
|
if step_method.__class__ == EnvironmentSteps:
|
||||||
|
self.last_training_phase_step = self.total_steps_counter
|
||||||
|
|
||||||
|
return should_update
|
||||||
|
|
||||||
|
def _should_train_helper(self, wait_for_full_episode=False):
|
||||||
|
|
||||||
if hasattr(self.ap.memory, 'memory_backend_params'):
|
if hasattr(self.ap.memory, 'memory_backend_params'):
|
||||||
self.total_steps_counter = self.call_memory('num_transitions')
|
self.total_steps_counter = self.call_memory('num_transitions')
|
||||||
|
|
||||||
step_method = self.ap.algorithm.num_consecutive_playing_steps
|
step_method = self.ap.algorithm.num_consecutive_playing_steps
|
||||||
if step_method.__class__ == EnvironmentEpisodes:
|
if step_method.__class__ == EnvironmentEpisodes:
|
||||||
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
|
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
|
||||||
if should_update:
|
|
||||||
self.last_training_phase_step = self.current_episode
|
|
||||||
elif step_method.__class__ == EnvironmentSteps:
|
elif step_method.__class__ == EnvironmentSteps:
|
||||||
should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
|
should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
|
||||||
if wait_for_full_episode:
|
if wait_for_full_episode:
|
||||||
should_update = should_update and self.current_episode_buffer.is_complete
|
should_update = should_update and self.current_episode_buffer.is_complete
|
||||||
if should_update:
|
|
||||||
self.last_training_phase_step = self.total_steps_counter
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("The num_consecutive_playing_steps parameter should be either "
|
raise ValueError("The num_consecutive_playing_steps parameter should be either "
|
||||||
"EnvironmentSteps or Episodes. Instead it is {}".format(step_method.__class__))
|
"EnvironmentSteps or Episodes. Instead it is {}".format(step_method.__class__))
|
||||||
|
|
||||||
return should_update
|
return should_update
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from minio.error import ResponseError
|
|||||||
from configparser import ConfigParser, Error
|
from configparser import ConfigParser, Error
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||||
|
from minio.error import ResponseError
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import io
|
import io
|
||||||
@@ -63,7 +64,7 @@ class S3DataStore(DataStore):
|
|||||||
for filename in files:
|
for filename in files:
|
||||||
if filename == 'checkpoint':
|
if filename == 'checkpoint':
|
||||||
checkpoint_file = (root, filename)
|
checkpoint_file = (root, filename)
|
||||||
pass
|
continue
|
||||||
abs_name = os.path.abspath(os.path.join(root, filename))
|
abs_name = os.path.abspath(os.path.join(root, filename))
|
||||||
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
||||||
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
||||||
@@ -79,17 +80,21 @@ class S3DataStore(DataStore):
|
|||||||
def load_from_store(self):
|
def load_from_store(self):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
|
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
|
||||||
time.sleep(10)
|
|
||||||
if next(objects, None) is None:
|
if next(objects, None) is None:
|
||||||
|
try:
|
||||||
|
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
|
||||||
|
except ResponseError as e:
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
print("loading from s3")
|
print("loading from s3")
|
||||||
|
|
||||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
|
||||||
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
|
|
||||||
|
|
||||||
ckpt = CheckpointState()
|
ckpt = CheckpointState()
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
contents = open(filename, 'r').read()
|
contents = open(filename, 'r').read()
|
||||||
|
|||||||
@@ -591,3 +591,6 @@ class GraphManager(object):
|
|||||||
result += "{}: \n{}\n".format(key, params)
|
result += "{}: \n{}\n".format(key, params)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def should_train(self) -> bool:
|
||||||
|
return any([manager.should_train() for manager in self.level_managers])
|
||||||
|
|||||||
@@ -260,3 +260,6 @@ class LevelManager(EnvironmentInterface):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
[agent.sync() for agent in self.agents.values()]
|
[agent.sync() for agent in self.agents.values()]
|
||||||
|
|
||||||
|
def should_train(self) -> bool:
|
||||||
|
return any([agent._should_train_helper() for agent in self.agents.values()])
|
||||||
|
|||||||
@@ -20,12 +20,12 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
|
|||||||
if data_store == "s3":
|
if data_store == "s3":
|
||||||
ds_params = DataStoreParameters("s3", "", "")
|
ds_params = DataStoreParameters("s3", "", "")
|
||||||
ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name,
|
ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name,
|
||||||
checkpoint_dir="/checkpoint")
|
checkpoint_dir="/checkpoint")
|
||||||
elif data_store == "nfs":
|
elif data_store == "nfs":
|
||||||
ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"})
|
ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"})
|
||||||
ds_params_instance = NFSDataStoreParameters(ds_params)
|
ds_params_instance = NFSDataStoreParameters(ds_params)
|
||||||
|
|
||||||
worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker")
|
worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker", num_replicas=num_workers)
|
||||||
trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer")
|
trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer")
|
||||||
|
|
||||||
orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
|
orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
|
||||||
@@ -53,7 +53,7 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
|
|||||||
orchestrator.trainer_logs()
|
orchestrator.trainer_logs()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
# orchestrator.undeploy()
|
orchestrator.undeploy()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -90,6 +90,11 @@ if __name__ == '__main__':
|
|||||||
help="(string) S3 bucket name to use when S3 data store is used.",
|
help="(string) S3 bucket name to use when S3 data store is used.",
|
||||||
type=str,
|
type=str,
|
||||||
required=True)
|
required=True)
|
||||||
|
parser.add_argument('--num-workers',
|
||||||
|
help="(string) Number of rollout workers",
|
||||||
|
type=int,
|
||||||
|
required=False,
|
||||||
|
default=1)
|
||||||
|
|
||||||
# parser.add_argument('--checkpoint_dir',
|
# parser.add_argument('--checkpoint_dir',
|
||||||
# help='(string) Path to a folder containing a checkpoint to write the model to.',
|
# help='(string) Path to a folder containing a checkpoint to write the model to.',
|
||||||
@@ -99,4 +104,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
|
main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
|
||||||
memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
|
memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
|
||||||
s3_bucket_name=args.s3_bucket_name)
|
s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers)
|
||||||
|
|||||||
75
rl_coach/presets/CartPole_PPO.py
Normal file
75
rl_coach/presets/CartPole_PPO.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
|
||||||
|
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||||
|
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||||
|
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||||
|
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||||
|
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
|
||||||
|
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||||
|
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||||
|
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||||
|
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||||
|
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||||
|
from rl_coach.schedules import LinearSchedule
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Graph Scheduling #
|
||||||
|
####################
|
||||||
|
|
||||||
|
schedule_params = ScheduleParameters()
|
||||||
|
schedule_params.improve_steps = TrainingSteps(10000000)
|
||||||
|
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2048)
|
||||||
|
schedule_params.evaluation_steps = EnvironmentEpisodes(5)
|
||||||
|
schedule_params.heatup_steps = EnvironmentSteps(0)
|
||||||
|
|
||||||
|
#########
|
||||||
|
# Agent #
|
||||||
|
#########
|
||||||
|
agent_params = ClippedPPOAgentParameters()
|
||||||
|
|
||||||
|
|
||||||
|
agent_params.network_wrappers['main'].learning_rate = 0.0003
|
||||||
|
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].activation_function = 'tanh'
|
||||||
|
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense([64])]
|
||||||
|
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([64])]
|
||||||
|
agent_params.network_wrappers['main'].middleware_parameters.activation_function = 'tanh'
|
||||||
|
agent_params.network_wrappers['main'].batch_size = 64
|
||||||
|
agent_params.network_wrappers['main'].optimizer_epsilon = 1e-5
|
||||||
|
agent_params.network_wrappers['main'].adam_optimizer_beta2 = 0.999
|
||||||
|
|
||||||
|
agent_params.algorithm.clip_likelihood_ratio_using_epsilon = 0.2
|
||||||
|
agent_params.algorithm.clipping_decay_schedule = LinearSchedule(1.0, 0, 1000000)
|
||||||
|
agent_params.algorithm.beta_entropy = 0
|
||||||
|
agent_params.algorithm.gae_lambda = 0.95
|
||||||
|
agent_params.algorithm.discount = 0.99
|
||||||
|
agent_params.algorithm.optimization_epochs = 10
|
||||||
|
agent_params.algorithm.estimate_state_value_using_gae = True
|
||||||
|
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(2048)
|
||||||
|
|
||||||
|
# agent_params.input_filter = MujocoInputFilter()
|
||||||
|
agent_params.exploration = EGreedyParameters()
|
||||||
|
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||||
|
# agent_params.pre_network_filter = MujocoInputFilter()
|
||||||
|
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
|
||||||
|
ObservationNormalizationFilter(name='normalize_observation'))
|
||||||
|
|
||||||
|
###############
|
||||||
|
# Environment #
|
||||||
|
###############
|
||||||
|
env_params = Mujoco()
|
||||||
|
env_params.level = 'CartPole-v0'
|
||||||
|
|
||||||
|
vis_params = VisualizationParameters()
|
||||||
|
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||||
|
vis_params.dump_mp4 = False
|
||||||
|
|
||||||
|
########
|
||||||
|
# Test #
|
||||||
|
########
|
||||||
|
preset_validation_params = PresetValidationParameters()
|
||||||
|
preset_validation_params.test = True
|
||||||
|
preset_validation_params.min_reward_threshold = 150
|
||||||
|
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||||
|
|
||||||
|
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||||
|
schedule_params=schedule_params, vis_params=vis_params,
|
||||||
|
preset_validation_params=preset_validation_params)
|
||||||
@@ -32,13 +32,13 @@ def training_worker(graph_manager, checkpoint_dir):
|
|||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
while True:
|
while True:
|
||||||
graph_manager.phase = core_types.RunPhase.TRAIN
|
if graph_manager.should_train():
|
||||||
graph_manager.train(core_types.TrainingSteps(1))
|
graph_manager.phase = core_types.RunPhase.TRAIN
|
||||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
graph_manager.train(core_types.TrainingSteps(1))
|
||||||
|
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||||
graph_manager.evaluate(graph_manager.evaluation_steps)
|
graph_manager.evaluate(graph_manager.evaluation_steps)
|
||||||
|
graph_manager.save_checkpoint()
|
||||||
graph_manager.save_checkpoint()
|
time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
Reference in New Issue
Block a user