Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 11:10:20 +01:00
update of api docstrings across coach and tutorials [WIP] (#91)
* updating the documentation website
* adding the built docs
* update of api docstrings across coach and tutorials 0-2
* added some missing api documentation
* New Sphinx based documentation
@@ -39,7 +39,7 @@ from rl_coach.memories.backend.memory_impl import get_memory_backend
 class Agent(AgentInterface):
     def __init__(self, agent_parameters: AgentParameters, parent: Union['LevelManager', 'CompositeAgent']=None):
         """
-        :param agent_parameters: A Preset class instance with all the running paramaters
+        :param agent_parameters: An AgentParameters class instance with all the agent parameters
         """
         super().__init__()
         self.ap = agent_parameters
@@ -175,18 +175,20 @@ class Agent(AgentInterface):
         np.random.seed()

     @property
-    def parent(self):
+    def parent(self) -> 'LevelManager':
         """
         Get the parent class of the agent

         :return: the parent of the agent
         """
         return self._parent

     @parent.setter
-    def parent(self, val):
+    def parent(self, val) -> None:
         """
         Change the parent class of the agent.
         Additionally, updates the full name of the agent

         :param val: the new parent
         :return: None
         """
@@ -196,7 +198,12 @@ class Agent(AgentInterface):
             raise ValueError("The parent of an agent must have a name")
         self.full_name_id = self.ap.full_name_id = "{}/{}".format(self._parent.name, self.name)

-    def setup_logger(self):
+    def setup_logger(self) -> None:
+        """
+        Setup the logger for the agent
+
+        :return: None
+        """
         # dump documentation
         logger_prefix = "{graph_name}.{level_name}.{agent_full_id}".\
             format(graph_name=self.parent_level_manager.parent_graph_manager.name,
@@ -212,6 +219,7 @@ class Agent(AgentInterface):
     def set_session(self, sess) -> None:
         """
         Set the deep learning framework session for all the agents in the composite agent
+
         :return: None
         """
         self.input_filter.set_session(sess)
@@ -223,6 +231,7 @@ class Agent(AgentInterface):
                         dump_one_value_per_step: bool=False) -> Signal:
         """
         Register a signal such that its statistics will be dumped and be viewable through dashboard
+
         :param signal_name: the name of the signal as it will appear in dashboard
         :param dump_one_value_per_episode: should the signal value be written for each episode?
         :param dump_one_value_per_step: should the signal value be written for each step?
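As a usage sketch (a hypothetical agent subclass; the signal name and the q_value variable are illustrative), a signal registered this way is typically created once and then fed samples that the dashboard picks up:

    # in an Agent subclass's __init__
    self.q_value_signal = self.register_signal("Q Values", dump_one_value_per_step=True)

    # later, while acting or training, record samples for the dashboard
    self.q_value_signal.add_sample(q_value)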
@@ -239,6 +248,7 @@ class Agent(AgentInterface):
         """
         Sets the parameters that are environment dependent. As a side effect, initializes all the components that are
         dependent on those values, by calling init_environment_dependent_modules
+
         :param spaces: the environment spaces definition
         :return: None
         """
@@ -274,6 +284,7 @@ class Agent(AgentInterface):
         Create all the networks of the agent.
         The network creation will be done after setting the environment parameters for the agent, since they are needed
         for creating the network.
+
         :return: A list containing all the networks
         """
         networks = {}
@@ -295,6 +306,7 @@ class Agent(AgentInterface):
         """
         Initialize any modules that depend on knowing information about the environment such as the action space or
         the observation space
+
         :return: None
         """
         # initialize exploration policy
@@ -314,13 +326,19 @@ class Agent(AgentInterface):

     @property
     def phase(self) -> RunPhase:
+        """
+        The current running phase of the agent
+
+        :return: RunPhase
+        """
         return self._phase

     @phase.setter
     def phase(self, val: RunPhase) -> None:
         """
         Change the phase of the run for the agent and all the sub components
-        :param phase: the new run phase (TRAIN, TEST, etc.)
+
+        :param val: the new run phase (TRAIN, TEST, etc.)
         :return: None
         """
         self.reset_evaluation_state(val)
@@ -328,6 +346,14 @@ class Agent(AgentInterface):
         self.exploration_policy.change_phase(val)

     def reset_evaluation_state(self, val: RunPhase) -> None:
+        """
+        Perform accumulator initialization when entering an evaluation phase, and signal dumping when exiting an
+        evaluation phase. Entering or exiting the evaluation phase is determined according to the new phase given
+        by val, and by the current phase set in self.phase.
+
+        :param val: The new phase to change to
+        :return: None
+        """
         starting_evaluation = (val == RunPhase.TEST)
         ending_evaluation = (self.phase == RunPhase.TEST)
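The boundary detection reduces to comparing the incoming phase with the current one; a minimal standalone sketch of the idea (RunPhase comes from rl_coach.core_types; the function name is illustrative):

    from rl_coach.core_types import RunPhase

    def evaluation_transitions(current_phase: RunPhase, new_phase: RunPhase):
        # entering evaluation: the phase we are switching to is TEST, so accumulators get reset
        starting_evaluation = (new_phase == RunPhase.TEST)
        # exiting evaluation: the phase we are leaving was TEST, so accumulated signals get dumped
        ending_evaluation = (current_phase == RunPhase.TEST)
        return starting_evaluation, ending_evaluation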
@@ -363,6 +389,7 @@ class Agent(AgentInterface):
         This function is a wrapper to allow having the same calls for shared or unshared memories.
         It should be used instead of calling the memory directly in order to allow different algorithms to work
         both with a shared and a local memory.
+
         :param func: the name of the memory function to call
         :param args: the arguments to supply to the function
         :return: the return value of the function
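Since the wrapper ultimately dispatches via getattr(self.memory, func)(*args) (as shown in the next hunk), agent code invokes memory operations by name; a hedged illustration (transition and batch_size are placeholders, args passed explicitly as tuples):

    # instead of self.memory.store(transition), which assumes a local memory:
    self.call_memory('store', (transition,))
    # and instead of self.memory.sample(batch_size):
    batch = self.call_memory('sample', (batch_size,))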
@@ -375,7 +402,12 @@ class Agent(AgentInterface):
         result = getattr(self.memory, func)(*args)
         return result

-    def log_to_screen(self):
+    def log_to_screen(self) -> None:
+        """
+        Write an episode summary line to the terminal
+
+        :return: None
+        """
         # log to screen
         log = OrderedDict()
         log["Name"] = self.full_name_id
@@ -388,9 +420,10 @@ class Agent(AgentInterface):
         log["Training iteration"] = self.training_iteration
         screen.log_dict(log, prefix=self.phase.value)

-    def update_step_in_episode_log(self):
+    def update_step_in_episode_log(self) -> None:
         """
-        Writes logging messages to screen and updates the log file with all the signal values.
+        Updates the in-episode log file with all the signal values from the most recent step.
+
         :return: None
         """
         # log all the signals to file
@@ -411,9 +444,12 @@ class Agent(AgentInterface):
         # dump
         self.agent_episode_logger.dump_output_csv()

-    def update_log(self):
+    def update_log(self) -> None:
         """
-        Writes logging messages to screen and updates the log file with all the signal values.
+        Updates the episodic log file with all the signal values from the most recent episode.
+        Additional signals for logging can be set by creating a new signal using self.register_signal,
+        and then updating it with some internal agent values.
+
         :return: None
         """
         # log all the signals to file
@@ -438,7 +474,6 @@ class Agent(AgentInterface):
         self.agent_logger.create_signal_value('Shaped Evaluation Reward', np.nan, overwrite=False)
         self.agent_logger.create_signal_value('Success Rate', np.nan, overwrite=False)

-
         for signal in self.episode_signals:
             self.agent_logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
             self.agent_logger.create_signal_value("{}/Stdev".format(signal.name), signal.get_stdev())
@@ -452,7 +487,10 @@ class Agent(AgentInterface):

     def handle_episode_ended(self) -> None:
         """
-        End an episode
+        Make any changes needed when each episode is ended.
+        This includes incrementing counters, updating full episode dependent values, updating logs, etc.
+        This function is called right after each episode is ended.
+
         :return: None
         """
         self.current_episode_buffer.is_complete = True
@@ -486,9 +524,10 @@ class Agent(AgentInterface):
         if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
             self.log_to_screen()

-    def reset_internal_state(self):
+    def reset_internal_state(self) -> None:
         """
-        Reset all the episodic parameters
+        Reset all the episodic parameters. This function is called right before each episode starts.
+
         :return: None
         """
         for signal in self.episode_signals:
@@ -516,6 +555,7 @@ class Agent(AgentInterface):
     def learn_from_batch(self, batch) -> Tuple[float, List, List]:
         """
         Given a batch of transitions, calculates their target values and updates the network.
+
         :param batch: A list of transitions
         :return: The total loss of the training, the loss per head and the unclipped gradients
         """
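To make the return contract concrete, here is a hedged skeleton of an override (a hypothetical subclass; compute_targets and train_network are placeholder helpers, not rl_coach API):

    from typing import List, Tuple

    class MyAgent(Agent):
        def learn_from_batch(self, batch) -> Tuple[float, List, List]:
            # compute algorithm-specific target values for the sampled transitions
            targets = compute_targets(batch)                       # placeholder helper
            # train the online network and collect the diagnostics the caller expects
            total_loss, losses_per_head, unclipped_grads = \
                train_network(batch, targets)                      # placeholder training call
            return total_loss, losses_per_head, unclipped_grads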
@@ -524,6 +564,7 @@ class Agent(AgentInterface):
     def _should_update_online_weights_to_target(self):
         """
         Determine if online weights should be copied to the target.
+
         :return: boolean: True if the online weights should be copied to the target.
         """

@@ -542,9 +583,10 @@ class Agent(AgentInterface):
                                  "EnvironmentSteps or TrainingSteps. Instead it is {}".format(step_method.__class__))
         return should_update

-    def _should_train(self, wait_for_full_episode=False):
+    def _should_train(self, wait_for_full_episode=False) -> bool:
         """
         Determine if we should start a training phase according to the number of steps passed since the last training
+
         :return: boolean: True if we should start a training phase
         """

@@ -580,11 +622,12 @@ class Agent(AgentInterface):

         return should_update

-    def train(self):
+    def train(self) -> float:
         """
         Check if a training phase should be done as configured by num_consecutive_playing_steps.
         If it should, then do several training steps as configured by num_consecutive_training_steps.
         A single training iteration: Sample a batch, train on it and update target networks.
+
         :return: The total training loss during the training iterations.
         """
         loss = 0
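Reduced to its control flow, the contract described in the docstring looks roughly like this (a hedged sketch inside a hypothetical subclass, not the library's exact implementation; treat the exact method calls as illustrative):

    def train(self) -> float:
        loss = 0
        if self._should_train():
            # several consecutive training steps, as configured
            for _ in range(self.ap.algorithm.num_consecutive_training_steps):
                batch = self.call_memory('sample', (self.ap.network_wrappers['main'].batch_size,))
                total_loss, _, _ = self.learn_from_batch(batch)
                loss += total_loss
                # periodically copy online weights to the target network
                if self._should_update_online_weights_to_target():
                    self.networks['main'].update_target_network()
                self.training_iteration += 1
        return loss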
@@ -641,14 +684,12 @@ class Agent(AgentInterface):
         # run additional commands after the training is done
         self.post_training_commands()

-
-
         return loss

     def choose_action(self, curr_state):
         """
-        choose an action to act with in the current episode being played. Different behavior might be exhibited when training
-        or testing.
+        choose an action to act with in the current episode being played. Different behavior might be exhibited when
+        training or testing.

         :param curr_state: the current state to act upon.
         :return: chosen action, some action value describing the action (q-value, probability, etc)
@@ -656,10 +697,16 @@ class Agent(AgentInterface):
         pass

     def prepare_batch_for_inference(self, states: Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]],
-                                    network_name: str):
+                                    network_name: str) -> Dict[str, np.array]:
         """
-        convert curr_state into input tensors tensorflow is expecting. i.e. if we have several inputs states, stack all
+        Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several input states, stack all
         observations together, measurements together, etc.
+
+        :param states: A list of environment states, where each one is a dict mapping from an observation name to its
+                       corresponding observation
+        :param network_name: The agent network name to prepare the batch for. this is needed in order to extract only
+                             the observation relevant for the network from the states.
+        :return: A dictionary containing a list of values from all the given states for each of the observations
         """
         # convert to batch so we can run it through the network
         states = force_list(states)
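The stacking described above amounts to grouping per-observation arrays across states; a minimal standalone sketch of the idea with numpy (the observation names are examples, and filtering by network_name is omitted):

    import numpy as np

    def stack_states_for_inference(states):
        # states: a non-empty list of dicts, each mapping observation name -> np.ndarray
        batch = {}
        for name in states[0].keys():    # e.g. 'observation', 'measurements'
            # one stacked array per observation name, batched over the given states
            batch[name] = np.stack([state[name] for state in states])
        return batch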
@@ -676,7 +723,8 @@ class Agent(AgentInterface):
     def act(self) -> ActionInfo:
         """
         Given the agent's current knowledge, decide on the next action to apply to the environment
-        :return: an action and a dictionary containing any additional info from the action decision process
+
+        :return: An ActionInfo object, which contains the action and any additional info from the action decision process
         """
         if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
             # This agent never plays while training (e.g. behavioral cloning)
@@ -705,13 +753,20 @@ class Agent(AgentInterface):

         return filtered_action_info

-    def run_pre_network_filter_for_inference(self, state: StateType):
+    def run_pre_network_filter_for_inference(self, state: StateType) -> StateType:
+        """
+        Run filters which were defined for being applied right before using the state for inference.
+
+        :param state: The state to run the filters on
+        :return: The filtered state
+        """
         dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
         return self.pre_network_filter.filter(dummy_env_response)[0].next_state

     def get_state_embedding(self, state: dict) -> np.ndarray:
         """
         Given a state, get the corresponding state embedding from the main network
+
         :param state: a state dict
         :return: a numpy embedding vector
         """
@@ -726,6 +781,7 @@ class Agent(AgentInterface):
         """
         Allows agents to update the transition just before adding it to the replay buffer.
         Can be useful for agents that want to tweak the reward, termination signal, etc.
+
         :param transition: the transition to update
         :return: the updated transition
         """
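As an illustration of the kind of override this hook enables (a hypothetical subclass; reward clipping is just an example tweak):

    # hypothetical Agent subclass method
    def update_transition_before_adding_to_replay_buffer(self, transition):
        # clip the reward to [-1, 1] before the transition is stored
        transition.reward = max(-1.0, min(1.0, transition.reward))
        return transition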
@@ -736,8 +792,10 @@ class Agent(AgentInterface):
         Given a response from the environment, distill the observation from it and store it for later use.
         The response should be a dictionary containing the performed action, the new observation and measurements,
         the reward, a game over flag and any additional information necessary.
+
         :param env_response: result of call from environment.step(action)
-        :return:
+        :return: a boolean value which determines if the agent has decided to terminate the episode after seeing the
+                 given observation
         """

         # filter the env_response
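Pieced together from the docstrings above, a single acting step in the surrounding control loop looks roughly like this (a hedged sketch; env stands for any environment object whose step returns an EnvResponse):

    action_info = agent.act()
    env_response = env.step(action_info.action)
    episode_ended = agent.observe(env_response)
    if episode_ended:
        agent.handle_episode_ended()       # called right after each episode is ended
        agent.reset_internal_state()       # called right before the next episode starts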
@@ -801,7 +859,12 @@ class Agent(AgentInterface):

         return transition.game_over

-    def post_training_commands(self):
+    def post_training_commands(self) -> None:
+        """
+        A function which allows adding any functionality that is required to run right after the training phase ends.
+
+        :return: None
+        """
         pass

     def get_predictions(self, states: List[Dict[str, np.ndarray]], prediction_type: PredictionType):
@@ -809,9 +872,10 @@ class Agent(AgentInterface):
         Get a prediction from the agent with regard to the requested prediction_type.
         If the agent cannot predict this type of prediction_type, or if there is more than one possible way to do so,
         raise a ValueException.
-        :param states:
-        :param prediction_type:
-        :return:
+
+        :param states: The states to get a prediction for
+        :param prediction_type: The type of prediction to get for the states. For example, the state-value prediction.
+        :return: the predicted values
         """

         predictions = self.networks['main'].online_network.predict_with_prediction_type(
@@ -824,6 +888,15 @@ class Agent(AgentInterface):
         return list(predictions.values())[0]

     def set_incoming_directive(self, action: ActionType) -> None:
+        """
+        Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent
+        has another master agent that is controlling it. In such cases, the master agent can define the goals for the
+        slave agent, define its observation, possible actions, etc. The directive type is defined by the agent
+        in-action-space.
+
+        :param action: The action that should be set as the directive
+        :return: None
+        """
         if isinstance(self.in_action_space, GoalsSpace):
             self.current_hrl_goal = action
         elif isinstance(self.in_action_space, AttentionActionSpace):
@@ -834,6 +907,7 @@ class Agent(AgentInterface):
     def save_checkpoint(self, checkpoint_id: int) -> None:
         """
         Allows agents to store additional information when saving checkpoints.
+
         :param checkpoint_id: the id of the checkpoint
         :return: None
         """
@@ -842,6 +916,7 @@ class Agent(AgentInterface):
     def sync(self) -> None:
         """
         Sync the global network parameters to local networks
+
         :return: None
         """
         for network in self.networks.values():