mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation.
  - MXNet checkpoint for each network is saved in a separate file.
* Adding checkpoint restore for mxnet to graph-manager
* Add unit-test for get_checkpoint_state()
* Added match.group() to fix unit-test failing on CI
* Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
@@ -30,6 +30,7 @@ from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, A
|
||||
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
|
||||
from rl_coach.logger import screen, Logger, EpisodeLogger
|
||||
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
|
||||
from rl_coach.saver import SaverCollection
|
||||
from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
|
||||
from rl_coach.utils import Signal, force_list
|
||||
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
|
||||
@@ -996,3 +997,16 @@ class Agent(AgentInterface):
|
||||
|
||||
def get_success_rate(self) -> float:
    """
    Compute the fraction of completed evaluation episodes that ended in success
    :return: num_successes_across_evaluation_episodes divided by num_evaluation_episodes_completed,
             or 0.0 when no evaluation episode has been completed yet
    """
    completed_episodes = self.num_evaluation_episodes_completed
    if not completed_episodes:
        # nothing has been evaluated yet - report 0.0 instead of raising ZeroDivisionError
        return 0.0
    return self.num_successes_across_evaluation_episodes / completed_episodes
|
||||
|
||||
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Collect all of agent's network savers
    :param parent_path_suffix: path suffix of the parent of the agent
        (could be name of level manager or composite agent)
    :return: collection of all agent savers
    """
    # every network receives the parent's suffix extended with this agent's name
    agent_path_suffix = "{}.{}".format(parent_path_suffix, self.name)
    savers = SaverCollection()
    for network in self.networks.values():
        savers.update(network.collect_savers(agent_path_suffix))
    return savers
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Union, List, Dict
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition
|
||||
from rl_coach.saver import SaverCollection
|
||||
|
||||
|
||||
class AgentInterface(object):
|
||||
@@ -153,3 +154,12 @@ class AgentInterface(object):
|
||||
:return: A tuple containing the actual action and additional info on the action
|
||||
"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Collect all of agent savers. This is the interface declaration only: it always raises,
    and concrete agents provide the actual implementation.
    :param parent_path_suffix: path suffix of the parent of the agent
        (could be name of level manager or composite agent)
    :return: collection of all agent savers
    :raises NotImplementedError: always, since the interface itself has no savers to collect
    """
    raise NotImplementedError
|
||||
|
||||
@@ -25,6 +25,7 @@ from rl_coach.agents.agent_interface import AgentInterface
|
||||
from rl_coach.base_parameters import AgentParameters, VisualizationParameters
|
||||
from rl_coach.core_types import ActionInfo, EnvResponse, ActionType, RunPhase
|
||||
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
||||
from rl_coach.saver import SaverCollection
|
||||
from rl_coach.spaces import ActionSpace
|
||||
from rl_coach.spaces import AgentSelection, AttentionActionSpace, SpacesDefinition
|
||||
from rl_coach.utils import short_dynamic_import
|
||||
@@ -412,3 +413,16 @@ class CompositeAgent(AgentInterface):
|
||||
:return:
|
||||
"""
|
||||
[agent.sync() for agent in self.agents.values()]
|
||||
|
||||
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Collect all of agent's network savers
    :param parent_path_suffix: path suffix of the parent of the agent
        (could be name of level manager or composite agent)
    :return: collection of all agent savers
    """
    # the suffix handed to every sub-agent is the same, so build it once outside the loop
    composite_path_suffix = "{}.{}".format(parent_path_suffix, self.name)
    savers = SaverCollection()
    for agent in self.agents.values():
        savers.update(agent.collect_savers(parent_path_suffix=composite_path_suffix))
    return savers
|
||||
|
||||
Reference in New Issue
Block a user