diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index 145be5d..f4f0fc8 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -900,31 +900,35 @@ class Agent(AgentInterface):
         # make agent specific changes to the transition if needed
         transition = self.update_transition_before_adding_to_replay_buffer(transition)
 
-        # sum up the total shaped reward
-        self.total_shaped_reward_in_current_episode += transition.reward
-        self.total_reward_in_current_episode += env_response.reward
-        self.shaped_reward.add_sample(transition.reward)
-        self.reward.add_sample(env_response.reward)
-
         # add action info to transition
         if type(self.parent).__name__ == 'CompositeAgent':
             transition.add_info(self.parent.last_action_info.__dict__)
         else:
             transition.add_info(self.last_action_info.__dict__)
 
-        # create and store the transition
-        if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
-            # for episodic memories we keep the transitions in a local buffer until the episode is ended.
-            # for regular memories we insert the transitions directly to the memory
-            self.current_episode_buffer.insert(transition)
-            if not isinstance(self.memory, EpisodicExperienceReplay) \
-                    and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
-                self.call_memory('store', transition)
+        self.total_reward_in_current_episode += env_response.reward
+        self.reward.add_sample(env_response.reward)
 
-        if self.ap.visualization.dump_in_episode_signals:
-            self.update_step_in_episode_log()
+        return self.observe_transition(transition)
 
-        return transition.game_over
+    def observe_transition(self, transition):
+        # sum up the total shaped reward
+        self.total_shaped_reward_in_current_episode += transition.reward
+        self.shaped_reward.add_sample(transition.reward)
+
+        # create and store the transition
+        if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
+            # for episodic memories we keep the transitions in a local buffer until the episode is ended.
+            # for regular memories we insert the transitions directly to the memory
+            self.current_episode_buffer.insert(transition)
+            if not isinstance(self.memory, EpisodicExperienceReplay) \
+                    and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
+                self.call_memory('store', transition)
+
+        if self.ap.visualization.dump_in_episode_signals:
+            self.update_step_in_episode_log()
+
+        return transition.game_over
 
     def post_training_commands(self) -> None:
         """
@@ -1009,38 +1013,6 @@ class Agent(AgentInterface):
         for network in self.networks.values():
             network.sync()
 
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_observe_on_trainer(self, transition: Transition) -> bool:
-        """
-        This emulates the observe using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        Given a response from the environment, distill the observation from it and store it for later use.
-        The response should be a dictionary containing the performed action, the new observation and measurements,
-        the reward, a game over flag and any additional information necessary.
-        :return:
-        """
-
-        # sum up the total shaped reward
-        self.total_shaped_reward_in_current_episode += transition.reward
-        self.total_reward_in_current_episode += transition.reward
-        self.shaped_reward.add_sample(transition.reward)
-        self.reward.add_sample(transition.reward)
-
-        # create and store the transition
-        if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
-            # for episodic memories we keep the transitions in a local buffer until the episode is ended.
-            # for regular memories we insert the transitions directly to the memory
-            self.current_episode_buffer.insert(transition)
-            if not isinstance(self.memory, EpisodicExperienceReplay) \
-                    and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
-                self.call_memory('store', transition)
-
-        if self.ap.visualization.dump_in_episode_signals:
-            self.update_step_in_episode_log()
-
-        return transition.game_over
-
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
 
diff --git a/rl_coach/agents/agent_interface.py b/rl_coach/agents/agent_interface.py
index 7bd550a..87a9232 100644
--- a/rl_coach/agents/agent_interface.py
+++ b/rl_coach/agents/agent_interface.py
@@ -125,23 +125,6 @@ class AgentInterface(object):
         """
         raise NotImplementedError("")
 
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_observe_on_trainer(self, transition: Transition) -> bool:
-        """
-        This emulates the act using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        Gets a response from the environment.
-        Processes this information for later use. For example, create a transition and store it in memory.
-        The action info (a class containing any info the agent wants to store regarding its action decision process) is
-        stored by the agent itself when deciding on the action.
-        :param env_response: a EnvResponse containing the response from the environment
-        :return: a done signal which is based on the agent knowledge. This can be different from the done signal from
-        the environment. For example, an agent can decide to finish the episode each time it gets some
-        intrinsic reward
-        """
-        raise NotImplementedError("")
-
     def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
         """
         Collect all of agent savers
diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py
index ed60c1e..7494edd 100644
--- a/rl_coach/level_manager.py
+++ b/rl_coach/level_manager.py
@@ -312,7 +312,7 @@ class LevelManager(EnvironmentInterface):
         # for i in range(self.steps_limit.num_steps):
 
         # let the agent observe the result and decide if it wants to terminate the episode
-        done = acting_agent.emulate_observe_on_trainer(transition)
+        done = acting_agent.observe_transition(transition)
 
         acting_agent.act(transition.action)
         if done:
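
Usage note (not part of the patch): with emulate_observe_on_trainer() removed, a distributed trainer worker can push transitions received from rollout workers through the same code path as the regular flow by calling the new Agent.observe_transition() directly, as level_manager.py now does. A minimal sketch of such a trainer-side loop, assuming a hypothetical iterable restored_transitions of rl_coach Transition objects and an already initialized agent (neither name is defined by this patch):

    # sketch only: `agent` and `restored_transitions` are assumptions, not part of this patch
    for transition in restored_transitions:
        # stores the transition exactly as observe() would, and returns the
        # agent's own done signal for the episode
        done = agent.observe_transition(transition)
        if done:
            break  # episode-end handling then proceeds through the existing flow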