mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

rename AgentInterface.emulate_observe_on_trainer to observe_transition and call it from AgentInterface.observe

zach dwiel
2019-04-05 12:11:21 -04:00
committed by Zach Dwiel
parent f8741522e4
commit fd2c210915
3 changed files with 22 additions and 67 deletions
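
The net effect: the reward bookkeeping and memory-insertion logic that observe() and emulate_observe_on_trainer() previously duplicated now lives in a single observe_transition() method, and observe() delegates to it once the transition has been built. A condensed sketch of the resulting flow, pulled together from the diff below (unchanged setup code and the signal/visualization bookkeeping are elided with "..."):

    def observe(self, env_response):
        ...
        # raw per-step reward bookkeeping stays in observe()
        self.total_reward_in_current_episode += env_response.reward
        self.reward.add_sample(env_response.reward)
        ...
        # everything transition-specific is delegated to the shared method
        return self.observe_transition(transition)

    def observe_transition(self, transition):
        # shaped-reward bookkeeping and memory insertion, also usable by the trainer worker
        self.total_shaped_reward_in_current_episode += transition.reward
        self.shaped_reward.add_sample(transition.reward)
        if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
            self.current_episode_buffer.insert(transition)
        return transition.game_over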


@@ -900,31 +900,35 @@ class Agent(AgentInterface):
         # make agent specific changes to the transition if needed
         transition = self.update_transition_before_adding_to_replay_buffer(transition)
 
-        # sum up the total shaped reward
-        self.total_shaped_reward_in_current_episode += transition.reward
         self.total_reward_in_current_episode += env_response.reward
-        self.shaped_reward.add_sample(transition.reward)
         self.reward.add_sample(env_response.reward)
 
         # add action info to transition
         if type(self.parent).__name__ == 'CompositeAgent':
             transition.add_info(self.parent.last_action_info.__dict__)
         else:
             transition.add_info(self.last_action_info.__dict__)
 
+        return self.observe_transition(transition)
+
+    def observe_transition(self, transition):
+        # sum up the total shaped reward
+        self.total_shaped_reward_in_current_episode += transition.reward
+        self.shaped_reward.add_sample(transition.reward)
+
         # create and store the transition
         if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
             # for episodic memories we keep the transitions in a local buffer until the episode is ended.
             # for regular memories we insert the transitions directly to the memory
             self.current_episode_buffer.insert(transition)
             if not isinstance(self.memory, EpisodicExperienceReplay) \
                     and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
                 self.call_memory('store', transition)
 
         if self.ap.visualization.dump_in_episode_signals:
             self.update_step_in_episode_log()
 
         return transition.game_over
 
     def post_training_commands(self) -> None:
         """
@@ -1009,38 +1013,6 @@ class Agent(AgentInterface):
         for network in self.networks.values():
             network.sync()
 
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_observe_on_trainer(self, transition: Transition) -> bool:
-        """
-        This emulates the observe using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        Given a response from the environment, distill the observation from it and store it for later use.
-        The response should be a dictionary containing the performed action, the new observation and measurements,
-        the reward, a game over flag and any additional information necessary.
-
-        :return:
-        """
-        # sum up the total shaped reward
-        self.total_shaped_reward_in_current_episode += transition.reward
-        self.total_reward_in_current_episode += transition.reward
-        self.shaped_reward.add_sample(transition.reward)
-        self.reward.add_sample(transition.reward)
-
-        # create and store the transition
-        if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
-            # for episodic memories we keep the transitions in a local buffer until the episode is ended.
-            # for regular memories we insert the transitions directly to the memory
-            self.current_episode_buffer.insert(transition)
-            if not isinstance(self.memory, EpisodicExperienceReplay) \
-                    and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
-                self.call_memory('store', transition)
-
-        if self.ap.visualization.dump_in_episode_signals:
-            self.update_step_in_episode_log()
-
-        return transition.game_over
-
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
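
With emulate_observe_on_trainer() removed, the distributed-training flow its docstring described would presumably go through the renamed method instead. A hypothetical sketch of that trainer-side usage; fetch_transitions_from_rollout_workers is an invented stand-in for whatever transport actually delivers Transition objects and is not part of this commit:

    # hypothetical trainer-side loop, not part of this commit
    for transition in fetch_transitions_from_rollout_workers():  # invented helper, stands in for the real transport
        # same bookkeeping/storage path the old emulate_observe_on_trainer() covered
        episode_ended = agent.observe_transition(transition)
        if episode_ended:
            agent.handle_episode_ended()  # assumption: the agent's usual end-of-episode hook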