mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
Batch RL (#238)
@@ -40,9 +40,10 @@ class LevelManager(EnvironmentInterface):
                  name: str,
                  agents: Union['Agent', CompositeAgent, Dict[str, Union['Agent', CompositeAgent]]],
                  environment: Union['LevelManager', Environment],
-                 real_environment: Environment=None,
-                 steps_limit: EnvironmentSteps=EnvironmentSteps(1),
-                 should_reset_agent_state_after_time_limit_passes: bool=False
+                 real_environment: Environment = None,
+                 steps_limit: EnvironmentSteps = EnvironmentSteps(1),
+                 should_reset_agent_state_after_time_limit_passes: bool = False,
+                 spaces_definition: SpacesDefinition = None
                  ):
         """
         A level manager controls a single or multiple composite agents and a single environment.
@@ -56,6 +57,7 @@ class LevelManager(EnvironmentInterface):
         :param steps_limit: the number of time steps to run when stepping the internal components
         :param should_reset_agent_state_after_time_limit_passes: reset the agent after stepping for steps_limit
         :param name: the level's name
+        :param spaces_definition: external definition of spaces for when we don't have an environment (e.g. batch-rl)
         """
         super().__init__()
 
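The two hunks above add a spaces_definition keyword so a caller can describe the state, action, goal, and reward spaces by hand when no live environment exists. Below is a minimal sketch of how a batch-rl caller might supply one. Only SpacesDefinition and the new keyword argument appear in this commit; the space classes come from rl_coach.spaces, but their exact constructor arguments here, the 'agent' object, and passing environment=None are illustrative assumptions, not taken from this diff.

    # Hypothetical wiring for batch-rl: describe the offline dataset's spaces
    # by hand instead of reading them off a live environment.
    from rl_coach.spaces import (SpacesDefinition, StateSpace,
                                 VectorObservationSpace, DiscreteActionSpace,
                                 RewardSpace)

    spaces = SpacesDefinition(
        state=StateSpace({'observation': VectorObservationSpace(shape=4)}),
        goal=None,  # assumed: no goal space outside of HRL setups
        action=DiscreteActionSpace(num_actions=2),
        reward=RewardSpace(1))

    # 'agent' would be an Agent/CompositeAgent built elsewhere; with no
    # simulator attached, real_environment stays None and the None-guards
    # added in the hunks below take effect.
    level_manager = LevelManager(name='main_level',
                                 agents=agent,
                                 environment=None,
                                 spaces_definition=spaces)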
@@ -85,9 +87,11 @@ class LevelManager(EnvironmentInterface):
 
         if not isinstance(self.steps_limit, EnvironmentSteps):
             raise ValueError("The num consecutive steps for acting must be defined in terms of environment steps")
-        self.build()
+        self.build(spaces_definition)
 
-        self.last_env_response = self.real_environment.last_env_response
+        # there are cases where we don't have an environment. e.g. in batch-rl or in imitation learning.
+        self.last_env_response = self.real_environment.last_env_response if self.real_environment else None
+
         self.parent_graph_manager = None
 
     def handle_episode_ended(self) -> None:
@@ -100,13 +104,13 @@ class LevelManager(EnvironmentInterface):
     def reset_internal_state(self, force_environment_reset: bool = False) -> EnvResponse:
         """
         Reset the environment episode parameters
-        :param force_enviro nment_reset: in some cases, resetting the environment can be suppressed by the environment
+        :param force_environment_reset: in some cases, resetting the environment can be suppressed by the environment
                                          itself. This flag allows force the reset.
         :return: the environment response as returned in get_last_env_response
         """
         [agent.reset_internal_state() for agent in self.agents.values()]
         self.reset_required = False
-        if self.real_environment.current_episode_steps_counter == 0:
+        if self.real_environment and self.real_environment.current_episode_steps_counter == 0:
             self.last_env_response = self.real_environment.last_env_response
         return self.last_env_response
 
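Both hunks above apply the same fix: they guard attribute access on self.real_environment, which is None in the batch-rl case, so that neither the constructor nor reset_internal_state raises AttributeError before any learning starts. The pattern in isolation (variable names here are illustrative, not from the diff):

    real_environment = None  # batch-rl: no simulator attached

    # unguarded access would raise AttributeError on NoneType:
    # last_env_response = real_environment.last_env_response

    # the guarded form short-circuits and falls back to None:
    last_env_response = real_environment.last_env_response if real_environment else None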
@@ -136,19 +140,27 @@ class LevelManager(EnvironmentInterface):
         """
         return {k: ActionInfo(v) for k, v in self.get_random_action().items()}
 
-    def build(self) -> None:
+    def build(self, spaces_definition: SpacesDefinition = None) -> None:
         """
         Build all the internal components of the level manager (composite agents and environment).
+        :param spaces_definition: external definition of spaces for when we don't have an environment (e.g. batch-rl)
         :return: None
         """
         # TODO: move the spaces definition class to the environment?
-        action_space = self.environment.action_space
-        if isinstance(action_space, dict):  # TODO: shouldn't be a dict
-            action_space = list(action_space.values())[0]
-        spaces = SpacesDefinition(state=self.real_environment.state_space,
-                                  goal=self.real_environment.goal_space,  # in HRL the agent might want to override this
-                                  action=action_space,
-                                  reward=self.real_environment.reward_space)
+        if spaces_definition is None:
+            # normally the spaces are defined by the environment, and we only gather these here
+            action_space = self.environment.action_space
+
+            if isinstance(action_space, dict):  # TODO: shouldn't be a dict
+                action_space = list(action_space.values())[0]
+
+            spaces = SpacesDefinition(state=self.real_environment.state_space,
+                                      goal=self.real_environment.goal_space,
+                                      # in HRL the agent might want to override this
+                                      action=action_space,
+                                      reward=self.real_environment.reward_space)
+        else:
+            spaces = spaces_definition
+
         [agent.set_environment_parameters(spaces) for agent in self.agents.values()]
 
     def setup_logger(self) -> None:
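The reworked build() now dispatches on whether the spaces were supplied externally. A distilled, free-standing sketch of that branch follows; resolve_spaces is a hypothetical helper invented for illustration, and the early return is equivalent to the hunk's if/else:

    from rl_coach.spaces import SpacesDefinition

    def resolve_spaces(environment, real_environment, spaces_definition=None):
        # batch-rl / imitation learning: trust the caller-supplied definition
        if spaces_definition is not None:
            return spaces_definition
        # online RL: gather the spaces from the attached environment
        action_space = environment.action_space
        if isinstance(action_space, dict):  # multi-agent dict, per the TODO above
            action_space = list(action_space.values())[0]
        return SpacesDefinition(state=real_environment.state_space,
                                goal=real_environment.goal_space,
                                action=action_space,
                                reward=real_environment.reward_space)

Note that the hunk's final line still uses a list comprehension purely for its side effect; a plain for loop over self.agents.values() would be the more idiomatic way to broadcast the spaces to each agent.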