mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
@@ -20,6 +20,7 @@ from rl_coach.agents.composite_agent import CompositeAgent
|
||||
from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps, Transition
|
||||
from rl_coach.environments.environment import Environment
|
||||
from rl_coach.environments.environment_interface import EnvironmentInterface
|
||||
from rl_coach.saver import SaverCollection
|
||||
from rl_coach.spaces import ActionSpace, SpacesDefinition
|
||||
|
||||
|
||||
@@ -292,3 +293,13 @@ class LevelManager(EnvironmentInterface):
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])
|
||||
|
||||
def collect_savers(self) -> SaverCollection:
|
||||
"""
|
||||
Calls collect_savers() on all agents and combines the results to a single collection
|
||||
:return: saver collection of all agent savers
|
||||
"""
|
||||
savers = SaverCollection()
|
||||
for agent in self.agents.values():
|
||||
savers.update(agent.collect_savers(parent_path_suffix=self.name))
|
||||
return savers
|
||||
|
||||
Reference in New Issue
Block a user