mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Cleanup and refactoring (#171)

Zach Dwiel
2019-01-15 03:04:53 -05:00
committed by Gal Leibovich
parent cd812b0d25
commit fedb4cbd7c
7 changed files with 45 additions and 33 deletions


@@ -609,35 +609,35 @@ class Agent(AgentInterface):
         :return: boolean: True if we should start a training phase
         """
-        should_update = self._should_train_helper()
-        step_method = self.ap.algorithm.num_consecutive_playing_steps
+        should_update = self._should_update()
+        steps = self.ap.algorithm.num_consecutive_playing_steps
         if should_update:
-            if step_method.__class__ == EnvironmentEpisodes:
+            if steps.__class__ == EnvironmentEpisodes:
                 self.last_training_phase_step = self.current_episode
-            if step_method.__class__ == EnvironmentSteps:
+            if steps.__class__ == EnvironmentSteps:
                 self.last_training_phase_step = self.total_steps_counter
         return should_update

-    def _should_train_helper(self):
+    def _should_update(self):
         wait_for_full_episode = self.ap.algorithm.act_for_full_episodes
-        step_method = self.ap.algorithm.num_consecutive_playing_steps
+        steps = self.ap.algorithm.num_consecutive_playing_steps

-        if step_method.__class__ == EnvironmentEpisodes:
-            should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
+        if steps.__class__ == EnvironmentEpisodes:
+            should_update = (self.current_episode - self.last_training_phase_step) >= steps.num_steps
             should_update = should_update and self.call_memory('length') > 0
-        elif step_method.__class__ == EnvironmentSteps:
-            should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
+        elif steps.__class__ == EnvironmentSteps:
+            should_update = (self.total_steps_counter - self.last_training_phase_step) >= steps.num_steps
             should_update = should_update and self.call_memory('num_transitions') > 0
             if wait_for_full_episode:
                 should_update = should_update and self.current_episode_buffer.is_complete
         else:
             raise ValueError("The num_consecutive_playing_steps parameter should be either "
-                             "EnvironmentSteps or Episodes. Instead it is {}".format(step_method.__class__))
+                             "EnvironmentSteps or Episodes. Instead it is {}".format(steps.__class__))

         return should_update
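
For reference, the renamed _should_update() only decides whether enough new experience has
accumulated since the last training phase. A minimal standalone sketch of the same rule, with
plain integer counters standing in for the agent's state (names and numbers below are
illustrative, not Coach's API):

    # Train once at least `num_steps` steps (or episodes) have passed since the
    # last training phase and the memory holds something to train on.
    def should_update(counter_now, last_training_phase_step, num_steps, memory_size):
        return (counter_now - last_training_phase_step) >= num_steps and memory_size > 0

    # e.g. with num_consecutive_playing_steps = EnvironmentSteps(2048):
    print(should_update(10240, 8192, 2048, 50000))   # True  -> start a training phase
    print(should_update(10000, 8192, 2048, 50000))   # False -> keep collecting experience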


@@ -84,7 +84,8 @@ class DFPAlgorithmParameters(AlgorithmParameters):
""" """
:param num_predicted_steps_ahead: (int) :param num_predicted_steps_ahead: (int)
Number of future steps to predict measurements for. The future steps won't be sequential, but rather jump Number of future steps to predict measurements for. The future steps won't be sequential, but rather jump
in multiples of 2. For example, if num_predicted_steps_ahead = 3, then the steps will be: t+1, t+2, t+4 in multiples of 2. For example, if num_predicted_steps_ahead = 3, then the steps will be: t+1, t+2, t+4.
The predicted steps will be [t + 2**i for i in range(num_predicted_steps_ahead)]
:param goal_vector: (List[float]) :param goal_vector: (List[float])
The goal vector will weight each of the measurements to form an optimization goal. The vector should have The goal vector will weight each of the measurements to form an optimization goal. The vector should have
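
The added docstring line pins down the exact offsets; a quick check of the formula in plain
Python (values chosen purely for illustration):

    # Future measurement targets are taken at exponentially spaced offsets from t.
    def predicted_offsets(num_predicted_steps_ahead):
        return [2 ** i for i in range(num_predicted_steps_ahead)]

    print(predicted_offsets(3))   # [1, 2, 4]            -> measurements at t+1, t+2, t+4
    print(predicted_offsets(6))   # [1, 2, 4, 8, 16, 32]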


@@ -125,22 +125,25 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         super().__init__(agent_parameters, spaces, name, global_network,
                          network_is_local, network_is_trainable)

-        def fill_return_types():
-            ret_dict = {}
-            for cls in get_all_subclasses(PredictionType):
-                ret_dict[cls] = []
-            components = self.input_embedders + [self.middleware] + self.output_heads
-            for component in components:
-                if not hasattr(component, 'return_type'):
-                    raise ValueError("{} has no return_type attribute. This should not happen.")
-                if component.return_type is not None:
-                    ret_dict[component.return_type].append(component)
-            return ret_dict
-        self.available_return_types = fill_return_types()
+        self.available_return_types = self._available_return_types()
         self.is_training = None

+    def _available_return_types(self):
+        ret_dict = {cls: [] for cls in get_all_subclasses(PredictionType)}
+
+        components = self.input_embedders + [self.middleware] + self.output_heads
+        for component in components:
+            if not hasattr(component, 'return_type'):
+                raise ValueError((
+                    "{} has no return_type attribute. Without this, it is "
+                    "unclear how this component should be used."
+                ).format(component))
+
+            if component.return_type is not None:
+                ret_dict[component.return_type].append(component)
+
+        return ret_dict
+
     def predict_with_prediction_type(self, states: Dict[str, np.ndarray],
                                      prediction_type: PredictionType) -> Dict[str, np.ndarray]:
         """


@@ -98,6 +98,10 @@ class Embedding(PredictionType):
     pass


+class Measurements(PredictionType):
+    pass
+
+
 class InputEmbedding(Embedding):
     pass
@@ -126,10 +130,6 @@ class Middleware_LSTM_Embedding(MiddlewareEmbedding):
     pass


-class Measurements(PredictionType):
-    pass
-
-
 PlayingStepsType = Union[EnvironmentSteps, EnvironmentEpisodes, Frames]


@@ -346,6 +346,10 @@ class GraphManager(object):
     @contextlib.contextmanager
     def phase_context(self, phase):
+        """
+        Create a context which temporarily sets the phase to the provided phase.
+        The previous phase is restored afterwards.
+        """
         old_phase = self.phase
         self.phase = phase
         yield
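
The new docstring describes the usual save-and-restore pattern for a context manager. A
self-contained sketch of the same idea (with a toy holder object and an added try/finally,
which Coach's version does not necessarily use):

    import contextlib

    class PhaseHolder:
        # Toy stand-in for GraphManager: it just owns a .phase attribute.
        def __init__(self, phase):
            self.phase = phase

        @contextlib.contextmanager
        def phase_context(self, phase):
            # Temporarily switch the phase, then restore the previous one,
            # even if the body of the with-block raises.
            old_phase = self.phase
            self.phase = phase
            try:
                yield
            finally:
                self.phase = old_phase

    holder = PhaseHolder("TRAIN")
    with holder.phase_context("TEST"):
        assert holder.phase == "TEST"
    assert holder.phase == "TRAIN"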
@@ -464,8 +468,11 @@ class GraphManager(object):
         count_end = self.current_step_counter + steps
         while self.current_step_counter < count_end:
-            # The actual steps being done on the environment are decided by the agents themselves.
-            # This is just an high-level controller.
+            # The actual number of steps being done on the environment
+            # is decided by the agent, though this inner loop always
+            # takes at least one step in the environment. Depending on
+            # internal counters and parameters, it doesn't always train
+            # or save checkpoints.
             self.act(EnvironmentSteps(1))
             self.train()
             self.occasionally_save_checkpoint()


@@ -271,7 +271,7 @@ class LevelManager(EnvironmentInterface):
         [agent.sync() for agent in self.agents.values()]

     def should_train(self) -> bool:
-        return any([agent._should_train_helper() for agent in self.agents.values()])
+        return any([agent._should_update() for agent in self.agents.values()])

     # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
     # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]


@@ -76,6 +76,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
     chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
     last_checkpoint = 0

+    # this worker should play a fraction of the total playing steps per rollout
     act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)
     for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
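
The new comment spells out how the playing steps are split across rollout workers; a quick
numeric check of the two expressions above (all numbers chosen purely for illustration):

    import math

    num_consecutive_playing_steps = 2048   # illustrative value
    improve_steps = 100000                 # illustrative value
    num_workers = 8

    # Each worker plays roughly 1/num_workers of every rollout ...
    act_steps = math.ceil(num_consecutive_playing_steps / num_workers)   # 256
    # ... and therefore runs improve_steps / act_steps outer iterations.
    iterations = int(improve_steps / act_steps)                          # 390
    print(act_steps, iterations)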