diff --git a/rl_coach/agents/agent_interface.py b/rl_coach/agents/agent_interface.py index f3b7903..e8aba49 100644 --- a/rl_coach/agents/agent_interface.py +++ b/rl_coach/agents/agent_interface.py @@ -162,4 +162,14 @@ class AgentInterface(object): (could be name of level manager or composite agent) :return: collection of all agent savers """ - raise NotImplementedError + raise NotImplementedError("") + + def handle_episode_ended(self) -> None: + """ + Make any changes needed when each episode is ended. + This includes incrementing counters, updating full episode dependent values, updating logs, etc. + This function is called right after each episode is ended. + + :return: None + """ + raise NotImplementedError("") diff --git a/rl_coach/agents/composite_agent.py b/rl_coach/agents/composite_agent.py index 33b437b..a3cf747 100644 --- a/rl_coach/agents/composite_agent.py +++ b/rl_coach/agents/composite_agent.py @@ -305,9 +305,12 @@ class CompositeAgent(AgentInterface): for agent in self.agents.values(): agent.phase = val - def end_episode(self) -> None: + def handle_episode_ended(self) -> None: """ - End an episode + Make any changes needed when each episode is ended. + This includes incrementing counters, updating full episode dependent values, updating logs, etc. + This function is called right after each episode is ended. + :return: None """ self.current_episode += 1