Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 19:20:19 +01:00
Cleanup and refactoring (#171)
This commit is contained in:
committed by Gal Leibovich
parent cd812b0d25
commit fedb4cbd7c
@@ -609,35 +609,35 @@ class Agent(AgentInterface):
         :return: boolean: True if we should start a training phase
         """
-        should_update = self._should_train_helper()
+        should_update = self._should_update()

-        step_method = self.ap.algorithm.num_consecutive_playing_steps
+        steps = self.ap.algorithm.num_consecutive_playing_steps

         if should_update:
-            if step_method.__class__ == EnvironmentEpisodes:
+            if steps.__class__ == EnvironmentEpisodes:
                 self.last_training_phase_step = self.current_episode
-            if step_method.__class__ == EnvironmentSteps:
+            if steps.__class__ == EnvironmentSteps:
                 self.last_training_phase_step = self.total_steps_counter

         return should_update

-    def _should_train_helper(self):
+    def _should_update(self):
         wait_for_full_episode = self.ap.algorithm.act_for_full_episodes
-        step_method = self.ap.algorithm.num_consecutive_playing_steps
+        steps = self.ap.algorithm.num_consecutive_playing_steps

-        if step_method.__class__ == EnvironmentEpisodes:
-            should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
+        if steps.__class__ == EnvironmentEpisodes:
+            should_update = (self.current_episode - self.last_training_phase_step) >= steps.num_steps
             should_update = should_update and self.call_memory('length') > 0

-        elif step_method.__class__ == EnvironmentSteps:
-            should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
+        elif steps.__class__ == EnvironmentSteps:
+            should_update = (self.total_steps_counter - self.last_training_phase_step) >= steps.num_steps
             should_update = should_update and self.call_memory('num_transitions') > 0

             if wait_for_full_episode:
                 should_update = should_update and self.current_episode_buffer.is_complete
         else:
             raise ValueError("The num_consecutive_playing_steps parameter should be either "
-                             "EnvironmentSteps or Episodes. Instead it is {}".format(step_method.__class__))
+                             "EnvironmentSteps or Episodes. Instead it is {}".format(steps.__class__))

         return should_update

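For readers following the rename, the sketch below restates the gating idea behind `_should_update()` as a self-contained example. The `EnvironmentSteps`/`EnvironmentEpisodes` classes and the numbers here are simplified stand-ins, not Coach's real `Agent` implementation.

# Minimal sketch of the training-gate logic (hypothetical stand-ins, not Coach's classes).
class EnvironmentSteps:
    def __init__(self, num_steps):
        self.num_steps = num_steps


class EnvironmentEpisodes:
    def __init__(self, num_steps):
        self.num_steps = num_steps


def should_update(steps, current_episode, total_steps_counter, last_training_phase_step):
    # Train every `steps.num_steps` episodes or environment steps,
    # depending on which schedule type was configured.
    if isinstance(steps, EnvironmentEpisodes):
        return (current_episode - last_training_phase_step) >= steps.num_steps
    elif isinstance(steps, EnvironmentSteps):
        return (total_steps_counter - last_training_phase_step) >= steps.num_steps
    raise ValueError("num_consecutive_playing_steps should be EnvironmentSteps "
                     "or EnvironmentEpisodes, got {}".format(type(steps)))


# e.g. with a schedule of "every 4 episodes", training starts once 4 episodes have passed
print(should_update(EnvironmentEpisodes(4), current_episode=7,
                    total_steps_counter=0, last_training_phase_step=3))  # True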
@@ -84,7 +84,8 @@ class DFPAlgorithmParameters(AlgorithmParameters):
         """
         :param num_predicted_steps_ahead: (int)
             Number of future steps to predict measurements for. The future steps won't be sequential, but rather jump
-            in multiples of 2. For example, if num_predicted_steps_ahead = 3, then the steps will be: t+1, t+2, t+4
+            in multiples of 2. For example, if num_predicted_steps_ahead = 3, then the steps will be: t+1, t+2, t+4.
+            The predicted steps will be [t + 2**i for i in range(num_predicted_steps_ahead)]

         :param goal_vector: (List[float])
             The goal vector will weight each of the measurements to form an optimization goal. The vector should have
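The added docstring line can be sanity-checked directly; a tiny sketch with a hypothetical timestep `t`:

# The predicted measurement offsets jump in powers of 2, as the docstring states.
num_predicted_steps_ahead = 3
t = 0  # hypothetical current timestep
predicted_steps = [t + 2**i for i in range(num_predicted_steps_ahead)]
print(predicted_steps)  # [1, 2, 4] -> t+1, t+2, t+4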
@@ -125,22 +125,25 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         super().__init__(agent_parameters, spaces, name, global_network,
                          network_is_local, network_is_trainable)

-        def fill_return_types():
-            ret_dict = {}
-            for cls in get_all_subclasses(PredictionType):
-                ret_dict[cls] = []
+        self.available_return_types = self._available_return_types()
+        self.is_training = None
+
+    def _available_return_types(self):
+        ret_dict = {cls: [] for cls in get_all_subclasses(PredictionType)}

         components = self.input_embedders + [self.middleware] + self.output_heads
         for component in components:
             if not hasattr(component, 'return_type'):
-                raise ValueError("{} has no return_type attribute. This should not happen.")
+                raise ValueError((
+                    "{} has no return_type attribute. Without this, it is "
+                    "unclear how this component should be used."
+                ).format(component))

             if component.return_type is not None:
                 ret_dict[component.return_type].append(component)

         return ret_dict

-        self.available_return_types = fill_return_types()
-        self.is_training = None
-
     def predict_with_prediction_type(self, states: Dict[str, np.ndarray],
                                      prediction_type: PredictionType) -> Dict[str, np.ndarray]:
         """
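The extracted `_available_return_types()` is essentially a group-by over component return types. Below is a self-contained sketch of that pattern; the `Component` class, the example subclasses, and the `get_all_subclasses` helper are simplified stand-ins for Coach's real hierarchy.

# Group components by their declared return type (hypothetical stand-ins).
class PredictionType: pass
class Embedding(PredictionType): pass
class QActionStateValue(PredictionType): pass


def get_all_subclasses(cls):
    # Recursively collect every subclass of `cls`.
    subclasses = set()
    for sub in cls.__subclasses__():
        subclasses.add(sub)
        subclasses |= get_all_subclasses(sub)
    return subclasses


class Component:
    def __init__(self, name, return_type):
        self.name = name
        self.return_type = return_type


components = [Component("embedder", Embedding), Component("q_head", QActionStateValue)]

# One empty bucket per prediction type, then sort each component into its bucket.
ret_dict = {cls: [] for cls in get_all_subclasses(PredictionType)}
for component in components:
    if component.return_type is not None:
        ret_dict[component.return_type].append(component)

print([c.name for c in ret_dict[Embedding]])  # ['embedder']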
@@ -98,6 +98,10 @@ class Embedding(PredictionType):
     pass


+class Measurements(PredictionType):
+    pass
+
+
 class InputEmbedding(Embedding):
     pass

@@ -126,10 +130,6 @@ class Middleware_LSTM_Embedding(MiddlewareEmbedding):
     pass


-class Measurements(PredictionType):
-    pass
-
-
 PlayingStepsType = Union[EnvironmentSteps, EnvironmentEpisodes, Frames]


@@ -346,6 +346,10 @@ class GraphManager(object):

     @contextlib.contextmanager
     def phase_context(self, phase):
+        """
+        Create a context which temporarily sets the phase to the provided phase.
+        The previous phase is restored afterwards.
+        """
         old_phase = self.phase
         self.phase = phase
         yield
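The new docstring describes a set-and-restore pattern. A minimal sketch of that pattern with a stand-in holder class (plain strings instead of Coach's phase enum, and a `try`/`finally` added for safety, which the real method may not use):

# Temporarily set a phase and restore it afterwards (simplified stand-in).
import contextlib


class PhaseHolder:
    def __init__(self):
        self.phase = "HEATUP"

    @contextlib.contextmanager
    def phase_context(self, phase):
        old_phase = self.phase
        self.phase = phase
        try:
            yield
        finally:
            # Restore the previous phase even if the body raises.
            self.phase = old_phase


holder = PhaseHolder()
with holder.phase_context("TRAIN"):
    print(holder.phase)  # TRAIN
print(holder.phase)      # HEATUP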
@@ -464,8 +468,11 @@ class GraphManager(object):

         count_end = self.current_step_counter + steps
         while self.current_step_counter < count_end:
-            # The actual steps being done on the environment are decided by the agents themselves.
-            # This is just an high-level controller.
+            # The actual number of steps being done on the environment
+            # is decided by the agent, though this inner loop always
+            # takes at least one step in the environment. Depending on
+            # internal counters and parameters, it doesn't always train
+            # or save checkpoints.
             self.act(EnvironmentSteps(1))
             self.train()
             self.occasionally_save_checkpoint()
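A toy sketch of what the reworded comment describes: the loop always takes one environment step, but training only fires when internal counters allow it. The counter names and the `train_every` value are hypothetical, not taken from GraphManager.

# Always act once per iteration; only train when the counter threshold is reached.
current_step_counter = 0
last_training_phase_step = 0
train_every = 4  # hypothetical: train once per 4 environment steps

count_end = current_step_counter + 10
while current_step_counter < count_end:
    # Take exactly one environment step.
    current_step_counter += 1
    # Training only happens when the internal counters say so.
    if current_step_counter - last_training_phase_step >= train_every:
        last_training_phase_step = current_step_counter
        print("train at step", current_step_counter)  # steps 4 and 8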
@@ -271,7 +271,7 @@ class LevelManager(EnvironmentInterface):
         [agent.sync() for agent in self.agents.values()]

     def should_train(self) -> bool:
-        return any([agent._should_train_helper() for agent in self.agents.values()])
+        return any([agent._should_update() for agent in self.agents.values()])

     # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
     # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
@@ -76,6 +76,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
    chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
    last_checkpoint = 0

+    # this worker should play a fraction of the total playing steps per rollout
    act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)

    for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
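The new comment in `rollout_worker` can be illustrated with a quick numeric check; the step and worker counts below are hypothetical.

# Each rollout worker plays its share of the playing steps, rounded up.
import math

num_consecutive_playing_steps = 100  # hypothetical total playing steps per rollout
num_workers = 3

act_steps = math.ceil(num_consecutive_playing_steps / num_workers)
print(act_steps)  # 34 -> each worker plays roughly a third of the rollout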