1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding checkpointing framework (#74)

* Adding checkpointing framework as well as mxnet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding checkpoint restore for mxnet to graph-manager

* Add unit-test for get_checkpoint_state()

* Added match.group() to fix unit-test failing on CI

* Added ONNX export support for MXNet
This commit is contained in:
Sina Afrooze
2018-11-19 09:45:49 -08:00
committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions

View File

@@ -30,6 +30,7 @@ from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, A
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
from rl_coach.logger import screen, Logger, EpisodeLogger
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
from rl_coach.utils import Signal, force_list
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
@@ -996,3 +997,16 @@ class Agent(AgentInterface):
def get_success_rate(self) -> float:
    """
    Compute the ratio of successes to completed evaluation episodes
    :return: the number of successes across evaluation episodes divided by the number of
             evaluation episodes completed
    """
    successes = self.num_successes_across_evaluation_episodes
    completed = self.num_evaluation_episodes_completed
    # NOTE(review): the denominator is not guarded against zero here - confirm callers only
    # invoke this after at least one evaluation episode has completed
    return successes / completed
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Gather the savers of every network owned by this agent into a single collection
    :param parent_path_suffix: path suffix of the parent of the agent
        (could be name of level manager or composite agent)
    :return: a SaverCollection updated with the savers collected from each of the agent's networks
    """
    # extend the parent's suffix with this agent's name, separated by a dot
    agent_path_suffix = "{}.{}".format(parent_path_suffix, self.name)
    collected = SaverCollection()
    for net in self.networks.values():
        net_savers = net.collect_savers(agent_path_suffix)
        collected.update(net_savers)
    return collected