Mirror of https://github.com/gryf/coach.git
Simulating the act on the trainer. (#65)
* Remove the use of daemon threads for Redis subscribe.
* Emulate act and observe on trainer side to update internal vars.
committed by Balaji Subramaniam
parent fe6857eabd
commit fde73ced13
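For context, a rough sketch of the trainer-side flow this change enables: rather than subscribing to the Redis memory backend on a daemon thread, the trainer replays incoming transitions through the new emulate_act_on_trainer() and emulate_observe_on_trainer() methods so that the agent's step counters, reward signals and memory stay in sync before training. The fetch_transitions() helper, the call order and the loop structure below are assumptions for illustration only, not code from this commit.

# Illustrative sketch only. fetch_transitions() is a hypothetical helper;
# only emulate_act_on_trainer() / emulate_observe_on_trainer() come from
# this commit, and the call order shown here is an assumption.
def run_trainer_iteration(agent, memory_backend):
    for transition in memory_backend.fetch_transitions():
        # update step counters and last_action_info, as act() would
        agent.emulate_act_on_trainer(transition)
        # update reward signals and store the transition, as observe() would
        episode_ended = agent.emulate_observe_on_trainer(transition)
        if episode_ended:
            # let the regular flow close the episode
            agent.handle_episode_ended()
    # train once the usual _should_train() condition is met
    agent.train()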
@@ -80,9 +80,7 @@ class Agent(AgentInterface):
         if hasattr(self.ap.memory, 'memory_backend_params'):
             self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
 
-            if self.ap.memory.memory_backend_params.run_type == 'trainer':
-                self.memory_backend.subscribe(self)
-            else:
+            if self.ap.memory.memory_backend_params.run_type != 'trainer':
                 self.memory.set_memory_backend(self.memory_backend)
 
         if agent_parameters.memory.load_memory_from_file_path:
@@ -583,14 +581,14 @@ class Agent(AgentInterface):
                              "EnvironmentSteps or TrainingSteps. Instead it is {}".format(step_method.__class__))
         return should_update
 
-    def _should_train(self, wait_for_full_episode=False) -> bool:
+    def _should_train(self):
         """
         Determine if we should start a training phase according to the number of steps passed since the last training
 
         :return: boolean: True if we should start a training phase
         """
 
-        should_update = self._should_train_helper(wait_for_full_episode=wait_for_full_episode)
+        should_update = self._should_train_helper()
 
         step_method = self.ap.algorithm.num_consecutive_playing_steps
 
@@ -602,8 +600,8 @@ class Agent(AgentInterface):
 
         return should_update
 
-    def _should_train_helper(self, wait_for_full_episode=False):
-
+    def _should_train_helper(self):
+        wait_for_full_episode = self.ap.algorithm.act_for_full_episodes
         step_method = self.ap.algorithm.num_consecutive_playing_steps
 
         if step_method.__class__ == EnvironmentEpisodes:
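Taken together, the two hunks above change the calling convention: wait_for_full_episode is no longer threaded through _should_train() by each agent, it is derived from the algorithm parameters via act_for_full_episodes. A simplified, self-contained sketch of the resulting decision path follows; the counters and the single step-count condition are assumptions for illustration, and the real helper also distinguishes EnvironmentEpisodes, EnvironmentSteps and TrainingSteps.

# Simplified sketch, not the actual helper from agent.py.
def should_train(algorithm_params, steps_since_last_train, episode_ended):
    # the flag is now read from the algorithm parameters instead of being
    # passed explicitly by each agent's train() method
    wait_for_full_episode = algorithm_params.act_for_full_episodes

    enough_steps = steps_since_last_train >= algorithm_params.num_consecutive_playing_steps.num_steps
    if wait_for_full_episode:
        # e.g. PPO / Clipped PPO, which set act_for_full_episodes = True
        return enough_steps and episode_ended
    return enough_steps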
@@ -922,5 +920,66 @@ class Agent(AgentInterface):
         for network in self.networks.values():
             network.sync()
 
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_observe_on_trainer(self, transition: Transition) -> bool:
+        """
+        This emulates the observe using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Given a response from the environment, distill the observation from it and store it for later use.
+        The response should be a dictionary containing the performed action, the new observation and measurements,
+        the reward, a game over flag and any additional information necessary.
+        :return:
+        """
+
+        # if we are in the first step in the episode, then we don't have a a next state and a reward and thus no
+        # transition yet, and therefore we don't need to store anything in the memory.
+        # also we did not reach the goal yet.
+        if self.current_episode_steps_counter == 0:
+            # initialize the current state
+            return transition.game_over
+        else:
+            # sum up the total shaped reward
+            self.total_shaped_reward_in_current_episode += transition.reward
+            self.total_reward_in_current_episode += transition.reward
+            self.shaped_reward.add_sample(transition.reward)
+            self.reward.add_sample(transition.reward)
+
+            # create and store the transition
+            if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
+                # for episodic memories we keep the transitions in a local buffer until the episode is ended.
+                # for regular memories we insert the transitions directly to the memory
+                self.current_episode_buffer.insert(transition)
+                if not isinstance(self.memory, EpisodicExperienceReplay) \
+                        and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
+                    self.call_memory('store', transition)
+
+            if self.ap.visualization.dump_in_episode_signals:
+                self.update_step_in_episode_log()
+
+        return transition.game_over
+
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Given the agents current knowledge, decide on the next action to apply to the environment
+        :return: an action and a dictionary containing any additional info from the action decision process
+        """
+        if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
+            # This agent never plays while training (e.g. behavioral cloning)
+            return None
+
+        # count steps (only when training or if we are in the evaluation worker)
+        if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
+            self.total_steps_counter += 1
+        self.current_episode_steps_counter += 1
+
+        self.last_action_info = transition.action
+
+        return self.last_action_info
+
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
@@ -18,7 +18,7 @@ from typing import Union, List, Dict
 
 import numpy as np
 
-from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType
+from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition
 
 
 class AgentInterface(object):
@@ -123,3 +123,33 @@ class AgentInterface(object):
         :return: None
         """
         raise NotImplementedError("")
+
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_observe_on_trainer(self, transition: Transition) -> bool:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Gets a response from the environment.
+        Processes this information for later use. For example, create a transition and store it in memory.
+        The action info (a class containing any info the agent wants to store regarding its action decision process) is
+        stored by the agent itself when deciding on the action.
+        :param env_response: a EnvResponse containing the response from the environment
+        :return: a done signal which is based on the agent knowledge. This can be different from the done signal from
+                 the environment. For example, an agent can decide to finish the episode each time it gets some
+                 intrinsic reward
+        """
+        raise NotImplementedError("")
+
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Get a decision of the next action to take.
+        The action is dependent on the current state which the agent holds from resetting the environment or from
+        the observe function.
+        :return: A tuple containing the actual action and additional info on the action
+        """
+        raise NotImplementedError("")
@@ -112,6 +112,7 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
         self.optimization_epochs = 10
         self.normalization_stats = None
         self.clipping_decay_schedule = ConstantSchedule(1)
+        self.act_for_full_episodes = True
 
 
 class ClippedPPOAgentParameters(AgentParameters):
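With act_for_full_episodes exposed as an algorithm parameter, the Clipped PPO and PPO overrides of _should_train_helper() removed further down become unnecessary; a preset can toggle the flag directly. A hypothetical usage sketch follows; the module path and the existence of a base-class default for this flag are assumptions, since the base AlgorithmParameters hunk is not shown in this excerpt.

# Hypothetical preset-style usage; module path and defaults are assumptions.
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters

agent_params = ClippedPPOAgentParameters()

# Clipped PPO now opts into training on full episodes via a parameter...
assert agent_params.algorithm.act_for_full_episodes is True

# ...so a preset can flip the behaviour without overriding _should_train_helper()
agent_params.algorithm.act_for_full_episodes = False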
@@ -294,11 +295,8 @@ class ClippedPPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')
 
-    def _should_train_helper(self, wait_for_full_episode=True):
-        return super()._should_train_helper(True)
-
     def train(self):
-        if self._should_train(wait_for_full_episode=True):
+        if self._should_train():
             for network in self.networks.values():
                 network.set_is_training(True)
 
@@ -334,3 +332,4 @@ class ClippedPPOAgent(ActorCriticAgent):
     def choose_action(self, curr_state):
         self.ap.algorithm.clipping_decay_schedule.step()
         return super().choose_action(curr_state)
+
@@ -121,6 +121,7 @@ class PPOAlgorithmParameters(AlgorithmParameters):
         self.use_kl_regularization = True
         self.beta_entropy = 0.01
         self.num_consecutive_playing_steps = EnvironmentSteps(5000)
+        self.act_for_full_episodes = True
 
 
 class PPOAgentParameters(AgentParameters):
@@ -354,12 +355,9 @@ class PPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')
 
-    def _should_train_helper(self, wait_for_full_episode=True):
-        return super()._should_train_helper(True)
-
     def train(self):
         loss = 0
-        if self._should_train(wait_for_full_episode=True):
+        if self._should_train():
             for network in self.networks.values():
                 network.set_is_training(True)
 
@@ -391,3 +389,4 @@ class PPOAgent(ActorCriticAgent):
     def get_prediction(self, states):
         tf_input_state = self.prepare_batch_for_inference(states, "actor")
         return self.networks['actor'].online_network.predict(tf_input_state)
+