1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

update of api docstrings across coach and tutorials [WIP] (#91)

* updating the documentation website
* adding the built docs
* update of api docstrings across coach and tutorials 0-2
* added some missing api documentation
* New Sphinx based documentation
This commit is contained in:
Itai Caspi
2018-11-15 15:00:13 +02:00
committed by Gal Novik
parent 524f8436a2
commit 6d40ad1650
517 changed files with 71034 additions and 12834 deletions

View File

@@ -39,7 +39,7 @@ from rl_coach.memories.backend.memory_impl import get_memory_backend
class Agent(AgentInterface):
def __init__(self, agent_parameters: AgentParameters, parent: Union['LevelManager', 'CompositeAgent']=None):
"""
:param agent_parameters: A Preset class instance with all the running paramaters
:param agent_parameters: A AgentParameters class instance with all the agent parameters
"""
super().__init__()
self.ap = agent_parameters
@@ -175,18 +175,20 @@ class Agent(AgentInterface):
np.random.seed()
@property
def parent(self):
def parent(self) -> 'LevelManager':
"""
Get the parent class of the agent
:return: the current phase
"""
return self._parent
@parent.setter
def parent(self, val):
def parent(self, val) -> None:
"""
Change the parent class of the agent.
Additionally, updates the full name of the agent
:param val: the new parent
:return: None
"""
@@ -196,7 +198,12 @@ class Agent(AgentInterface):
raise ValueError("The parent of an agent must have a name")
self.full_name_id = self.ap.full_name_id = "{}/{}".format(self._parent.name, self.name)
def setup_logger(self):
def setup_logger(self) -> None:
"""
Setup the logger for the agent
:return: None
"""
# dump documentation
logger_prefix = "{graph_name}.{level_name}.{agent_full_id}".\
format(graph_name=self.parent_level_manager.parent_graph_manager.name,
@@ -212,6 +219,7 @@ class Agent(AgentInterface):
def set_session(self, sess) -> None:
"""
Set the deep learning framework session for all the agents in the composite agent
:return: None
"""
self.input_filter.set_session(sess)
@@ -223,6 +231,7 @@ class Agent(AgentInterface):
dump_one_value_per_step: bool=False) -> Signal:
"""
Register a signal such that its statistics will be dumped and be viewable through dashboard
:param signal_name: the name of the signal as it will appear in dashboard
:param dump_one_value_per_episode: should the signal value be written for each episode?
:param dump_one_value_per_step: should the signal value be written for each step?
@@ -239,6 +248,7 @@ class Agent(AgentInterface):
"""
Sets the parameters that are environment dependent. As a side effect, initializes all the components that are
dependent on those values, by calling init_environment_dependent_modules
:param spaces: the environment spaces definition
:return: None
"""
@@ -274,6 +284,7 @@ class Agent(AgentInterface):
Create all the networks of the agent.
The network creation will be done after setting the environment parameters for the agent, since they are needed
for creating the network.
:return: A list containing all the networks
"""
networks = {}
@@ -295,6 +306,7 @@ class Agent(AgentInterface):
"""
Initialize any modules that depend on knowing information about the environment such as the action space or
the observation space
:return: None
"""
# initialize exploration policy
@@ -314,13 +326,19 @@ class Agent(AgentInterface):
@property
def phase(self) -> RunPhase:
"""
The current running phase of the agent
:return: RunPhase
"""
return self._phase
@phase.setter
def phase(self, val: RunPhase) -> None:
"""
Change the phase of the run for the agent and all the sub components
:param phase: the new run phase (TRAIN, TEST, etc.)
:param val: the new run phase (TRAIN, TEST, etc.)
:return: None
"""
self.reset_evaluation_state(val)
@@ -328,6 +346,14 @@ class Agent(AgentInterface):
self.exploration_policy.change_phase(val)
def reset_evaluation_state(self, val: RunPhase) -> None:
"""
Perform accumulators initialization when entering an evaluation phase, and signal dumping when exiting an
evaluation phase. Entering or exiting the evaluation phase is determined according to the new phase given
by val, and by the current phase set in self.phase.
:param val: The new phase to change to
:return: None
"""
starting_evaluation = (val == RunPhase.TEST)
ending_evaluation = (self.phase == RunPhase.TEST)
@@ -363,6 +389,7 @@ class Agent(AgentInterface):
This function is a wrapper to allow having the same calls for shared or unshared memories.
It should be used instead of calling the memory directly in order to allow different algorithms to work
both with a shared and a local memory.
:param func: the name of the memory function to call
:param args: the arguments to supply to the function
:return: the return value of the function
@@ -375,7 +402,12 @@ class Agent(AgentInterface):
result = getattr(self.memory, func)(*args)
return result
def log_to_screen(self):
def log_to_screen(self) -> None:
"""
Write an episode summary line to the terminal
:return: None
"""
# log to screen
log = OrderedDict()
log["Name"] = self.full_name_id
@@ -388,9 +420,10 @@ class Agent(AgentInterface):
log["Training iteration"] = self.training_iteration
screen.log_dict(log, prefix=self.phase.value)
def update_step_in_episode_log(self):
def update_step_in_episode_log(self) -> None:
"""
Writes logging messages to screen and updates the log file with all the signal values.
Updates the in-episode log file with all the signal values from the most recent step.
:return: None
"""
# log all the signals to file
@@ -411,9 +444,12 @@ class Agent(AgentInterface):
# dump
self.agent_episode_logger.dump_output_csv()
def update_log(self):
def update_log(self) -> None:
"""
Writes logging messages to screen and updates the log file with all the signal values.
Updates the episodic log file with all the signal values from the most recent episode.
Additional signals for logging can be set by the creating a new signal using self.register_signal,
and then updating it with some internal agent values.
:return: None
"""
# log all the signals to file
@@ -438,7 +474,6 @@ class Agent(AgentInterface):
self.agent_logger.create_signal_value('Shaped Evaluation Reward', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Success Rate', np.nan, overwrite=False)
for signal in self.episode_signals:
self.agent_logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
self.agent_logger.create_signal_value("{}/Stdev".format(signal.name), signal.get_stdev())
@@ -452,7 +487,10 @@ class Agent(AgentInterface):
def handle_episode_ended(self) -> None:
"""
End an episode
Make any changes needed when each episode is ended.
This includes incrementing counters, updating full episode dependent values, updating logs, etc.
This function is called right after each episode is ended.
:return: None
"""
self.current_episode_buffer.is_complete = True
@@ -486,9 +524,10 @@ class Agent(AgentInterface):
if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
self.log_to_screen()
def reset_internal_state(self):
def reset_internal_state(self) -> None:
"""
Reset all the episodic parameters
Reset all the episodic parameters. This function is called right before each episode starts.
:return: None
"""
for signal in self.episode_signals:
@@ -516,6 +555,7 @@ class Agent(AgentInterface):
def learn_from_batch(self, batch) -> Tuple[float, List, List]:
"""
Given a batch of transitions, calculates their target values and updates the network.
:param batch: A list of transitions
:return: The total loss of the training, the loss per head and the unclipped gradients
"""
@@ -524,6 +564,7 @@ class Agent(AgentInterface):
def _should_update_online_weights_to_target(self):
"""
Determine if online weights should be copied to the target.
:return: boolean: True if the online weights should be copied to the target.
"""
@@ -542,9 +583,10 @@ class Agent(AgentInterface):
"EnvironmentSteps or TrainingSteps. Instead it is {}".format(step_method.__class__))
return should_update
def _should_train(self, wait_for_full_episode=False):
def _should_train(self, wait_for_full_episode=False) -> bool:
"""
Determine if we should start a training phase according to the number of steps passed since the last training
:return: boolean: True if we should start a training phase
"""
@@ -580,11 +622,12 @@ class Agent(AgentInterface):
return should_update
def train(self):
def train(self) -> float:
"""
Check if a training phase should be done as configured by num_consecutive_playing_steps.
If it should, then do several training steps as configured by num_consecutive_training_steps.
A single training iteration: Sample a batch, train on it and update target networks.
:return: The total training loss during the training iterations.
"""
loss = 0
@@ -641,14 +684,12 @@ class Agent(AgentInterface):
# run additional commands after the training is done
self.post_training_commands()
return loss
def choose_action(self, curr_state):
"""
choose an action to act with in the current episode being played. Different behavior might be exhibited when training
or testing.
choose an action to act with in the current episode being played. Different behavior might be exhibited when
training or testing.
:param curr_state: the current state to act upon.
:return: chosen action, some action value describing the action (q-value, probability, etc)
@@ -656,10 +697,16 @@ class Agent(AgentInterface):
pass
def prepare_batch_for_inference(self, states: Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]],
network_name: str):
network_name: str) -> Dict[str, np.array]:
"""
convert curr_state into input tensors tensorflow is expecting. i.e. if we have several inputs states, stack all
Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several inputs states, stack all
observations together, measurements together, etc.
:param states: A list of environment states, where each one is a dict mapping from an observation name to its
corresponding observation
:param network_name: The agent network name to prepare the batch for. this is needed in order to extract only
the observation relevant for the network from the states.
:return: A dictionary containing a list of values from all the given states for each of the observations
"""
# convert to batch so we can run it through the network
states = force_list(states)
@@ -676,7 +723,8 @@ class Agent(AgentInterface):
def act(self) -> ActionInfo:
"""
Given the agents current knowledge, decide on the next action to apply to the environment
:return: an action and a dictionary containing any additional info from the action decision process
:return: An ActionInfo object, which contains the action and any additional info from the action decision process
"""
if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
# This agent never plays while training (e.g. behavioral cloning)
@@ -705,13 +753,20 @@ class Agent(AgentInterface):
return filtered_action_info
def run_pre_network_filter_for_inference(self, state: StateType):
def run_pre_network_filter_for_inference(self, state: StateType) -> StateType:
"""
Run filters which where defined for being applied right before using the state for inference.
:param state: The state to run the filters on
:return: The filtered state
"""
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
return self.pre_network_filter.filter(dummy_env_response)[0].next_state
def get_state_embedding(self, state: dict) -> np.ndarray:
"""
Given a state, get the corresponding state embedding from the main network
:param state: a state dict
:return: a numpy embedding vector
"""
@@ -726,6 +781,7 @@ class Agent(AgentInterface):
"""
Allows agents to update the transition just before adding it to the replay buffer.
Can be useful for agents that want to tweak the reward, termination signal, etc.
:param transition: the transition to update
:return: the updated transition
"""
@@ -736,8 +792,10 @@ class Agent(AgentInterface):
Given a response from the environment, distill the observation from it and store it for later use.
The response should be a dictionary containing the performed action, the new observation and measurements,
the reward, a game over flag and any additional information necessary.
:param env_response: result of call from environment.step(action)
:return:
:return: a boolean value which determines if the agent has decided to terminate the episode after seeing the
given observation
"""
# filter the env_response
@@ -801,7 +859,12 @@ class Agent(AgentInterface):
return transition.game_over
def post_training_commands(self):
def post_training_commands(self) -> None:
"""
A function which allows adding any functionality that is required to run right after the training phase ends.
:return: None
"""
pass
def get_predictions(self, states: List[Dict[str, np.ndarray]], prediction_type: PredictionType):
@@ -809,9 +872,10 @@ class Agent(AgentInterface):
Get a prediction from the agent with regard to the requested prediction_type.
If the agent cannot predict this type of prediction_type, or if there is more than possible way to do so,
raise a ValueException.
:param states:
:param prediction_type:
:return:
:param states: The states to get a prediction for
:param prediction_type: The type of prediction to get for the states. For example, the state-value prediction.
:return: the predicted values
"""
predictions = self.networks['main'].online_network.predict_with_prediction_type(
@@ -824,6 +888,15 @@ class Agent(AgentInterface):
return list(predictions.values())[0]
def set_incoming_directive(self, action: ActionType) -> None:
"""
Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent
has another master agent that is controlling it. In such cases, the master agent can define the goals for the
slave agent, define it's observation, possible actions, etc. The directive type is defined by the agent
in-action-space.
:param action: The action that should be set as the directive
:return:
"""
if isinstance(self.in_action_space, GoalsSpace):
self.current_hrl_goal = action
elif isinstance(self.in_action_space, AttentionActionSpace):
@@ -834,6 +907,7 @@ class Agent(AgentInterface):
def save_checkpoint(self, checkpoint_id: int) -> None:
"""
Allows agents to store additional information when saving checkpoints.
:param checkpoint_id: the id of the checkpoint
:return: None
"""
@@ -842,6 +916,7 @@ class Agent(AgentInterface):
def sync(self) -> None:
"""
Sync the global network parameters to local networks
:return: None
"""
for network in self.networks.values():