Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 19:20:19 +01:00
Cleanup and refactoring (#171)
Committed by Gal Leibovich
parent cd812b0d25
commit fedb4cbd7c
@@ -609,35 +609,35 @@ class Agent(AgentInterface):
         :return: boolean: True if we should start a training phase
         """
-        should_update = self._should_train_helper()
+        should_update = self._should_update()

-        step_method = self.ap.algorithm.num_consecutive_playing_steps
+        steps = self.ap.algorithm.num_consecutive_playing_steps

         if should_update:
-            if step_method.__class__ == EnvironmentEpisodes:
+            if steps.__class__ == EnvironmentEpisodes:
                 self.last_training_phase_step = self.current_episode
-            if step_method.__class__ == EnvironmentSteps:
+            if steps.__class__ == EnvironmentSteps:
                 self.last_training_phase_step = self.total_steps_counter

         return should_update

-    def _should_train_helper(self):
+    def _should_update(self):
         wait_for_full_episode = self.ap.algorithm.act_for_full_episodes
-        step_method = self.ap.algorithm.num_consecutive_playing_steps
+        steps = self.ap.algorithm.num_consecutive_playing_steps

-        if step_method.__class__ == EnvironmentEpisodes:
-            should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
+        if steps.__class__ == EnvironmentEpisodes:
+            should_update = (self.current_episode - self.last_training_phase_step) >= steps.num_steps
             should_update = should_update and self.call_memory('length') > 0

-        elif step_method.__class__ == EnvironmentSteps:
-            should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
+        elif steps.__class__ == EnvironmentSteps:
+            should_update = (self.total_steps_counter - self.last_training_phase_step) >= steps.num_steps
             should_update = should_update and self.call_memory('num_transitions') > 0

             if wait_for_full_episode:
                 should_update = should_update and self.current_episode_buffer.is_complete
         else:
             raise ValueError("The num_consecutive_playing_steps parameter should be either "
-                             "EnvironmentSteps or Episodes. Instead it is {}".format(step_method.__class__))
+                             "EnvironmentSteps or Episodes. Instead it is {}".format(steps.__class__))

         return should_update
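Note: below is a minimal, self-contained sketch (not code from the repo) of how the renamed _should_update() gate behaves on the EnvironmentEpisodes branch. The ToyAgent class and its attributes are illustrative stand-ins for the agent's real counters and memory.

    # Illustrative only: train once enough episodes have passed since the last
    # training phase and the memory is non-empty, then reset the bookkeeping.
    class ToyAgent:
        def __init__(self, train_every_n_episodes):
            self.train_every_n_episodes = train_every_n_episodes  # stands in for steps.num_steps
            self.current_episode = 0
            self.last_training_phase_step = 0
            self.memory_length = 0

        def _should_update(self):
            # mirrors the EnvironmentEpisodes branch of the real method
            enough = (self.current_episode - self.last_training_phase_step) >= self.train_every_n_episodes
            return enough and self.memory_length > 0

        def end_episode(self):
            self.current_episode += 1
            self.memory_length += 1
            if self._should_update():
                self.last_training_phase_step = self.current_episode
                print("train at episode", self.current_episode)

    agent = ToyAgent(train_every_n_episodes=2)
    for _ in range(6):
        agent.end_episode()  # prints: train at episode 2, 4 and 6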
@@ -84,7 +84,8 @@ class DFPAlgorithmParameters(AlgorithmParameters):
         """
         :param num_predicted_steps_ahead: (int)
             Number of future steps to predict measurements for. The future steps won't be sequential, but rather jump
-            in multiples of 2. For example, if num_predicted_steps_ahead = 3, then the steps will be: t+1, t+2, t+4
+            in multiples of 2. For example, if num_predicted_steps_ahead = 3, then the steps will be: t+1, t+2, t+4.
+            The predicted steps will be [t + 2**i for i in range(num_predicted_steps_ahead)]

         :param goal_vector: (List[float])
             The goal vector will weight each of the measurements to form an optimization goal. The vector should have
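Note: the step pattern described by the new docstring line can be checked directly in plain Python; the values below simply reproduce the docstring's example.

    num_predicted_steps_ahead = 3  # example value from the docstring
    t = 0                          # current time step, chosen arbitrarily
    predicted_steps = [t + 2 ** i for i in range(num_predicted_steps_ahead)]
    print(predicted_steps)         # [1, 2, 4]  ->  t+1, t+2, t+4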
@@ -125,22 +125,25 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         super().__init__(agent_parameters, spaces, name, global_network,
                          network_is_local, network_is_trainable)

-        def fill_return_types():
-            ret_dict = {}
-            for cls in get_all_subclasses(PredictionType):
-                ret_dict[cls] = []
-            components = self.input_embedders + [self.middleware] + self.output_heads
-            for component in components:
-                if not hasattr(component, 'return_type'):
-                    raise ValueError("{} has no return_type attribute. This should not happen.")
-                if component.return_type is not None:
-                    ret_dict[component.return_type].append(component)
-
-            return ret_dict
-
-        self.available_return_types = fill_return_types()
+        self.available_return_types = self._available_return_types()
         self.is_training = None

+    def _available_return_types(self):
+        ret_dict = {cls: [] for cls in get_all_subclasses(PredictionType)}
+
+        components = self.input_embedders + [self.middleware] + self.output_heads
+        for component in components:
+            if not hasattr(component, 'return_type'):
+                raise ValueError((
+                    "{} has no return_type attribute. Without this, it is "
+                    "unclear how this component should be used."
+                ).format(component))
+
+            if component.return_type is not None:
+                ret_dict[component.return_type].append(component)
+
+        return ret_dict
+
     def predict_with_prediction_type(self, states: Dict[str, np.ndarray],
                                      prediction_type: PredictionType) -> Dict[str, np.ndarray]:
         """
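Note: a standalone sketch of the grouping pattern used by the new _available_return_types() method. get_all_subclasses below is a local stand-in for Coach's helper of the same name, and the component / prediction-type classes are hypothetical, so the example runs on its own.

    def get_all_subclasses(cls):
        # recursively collect every subclass of cls
        subclasses = set(cls.__subclasses__())
        for sub in cls.__subclasses__():
            subclasses |= get_all_subclasses(sub)
        return subclasses

    class PredictionType:
        pass

    class Embedding(PredictionType):
        pass

    class QActionStateValue(PredictionType):
        pass

    class Component:
        def __init__(self, name, return_type):
            self.name = name
            self.return_type = return_type

    components = [Component("embedder", Embedding), Component("head", QActionStateValue)]

    # one empty bucket per known prediction type, then fill the buckets from the components
    ret_dict = {cls: [] for cls in get_all_subclasses(PredictionType)}
    for component in components:
        if component.return_type is not None:
            ret_dict[component.return_type].append(component)

    print([c.name for c in ret_dict[Embedding]])  # ['embedder']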
@@ -98,6 +98,10 @@ class Embedding(PredictionType):
     pass


+class Measurements(PredictionType):
+    pass
+
+
 class InputEmbedding(Embedding):
     pass
@@ -126,10 +130,6 @@ class Middleware_LSTM_Embedding(MiddlewareEmbedding):
     pass


-class Measurements(PredictionType):
-    pass
-
-
 PlayingStepsType = Union[EnvironmentSteps, EnvironmentEpisodes, Frames]
@@ -346,6 +346,10 @@ class GraphManager(object):

     @contextlib.contextmanager
     def phase_context(self, phase):
+        """
+        Create a context which temporarily sets the phase to the provided phase.
+        The previous phase is restored afterwards.
+        """
         old_phase = self.phase
         self.phase = phase
         yield
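Note: the added docstring promises that the previous phase is restored afterwards; the statement doing the restore falls just outside the diff context shown above. A minimal generic sketch of the same context-manager pattern follows (ToyManager is hypothetical, and the try/finally is a defensive choice of this sketch, not necessarily Coach's exact code).

    import contextlib

    class ToyManager:
        def __init__(self):
            self.phase = "UNDEFINED"

        @contextlib.contextmanager
        def phase_context(self, phase):
            # temporarily switch the phase, restore the old one when the block exits
            old_phase = self.phase
            self.phase = phase
            try:
                yield
            finally:
                self.phase = old_phase

    manager = ToyManager()
    with manager.phase_context("TRAIN"):
        print(manager.phase)  # TRAIN
    print(manager.phase)      # UNDEFINED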
@@ -464,8 +468,11 @@ class GraphManager(object):

         count_end = self.current_step_counter + steps
         while self.current_step_counter < count_end:
-            # The actual steps being done on the environment are decided by the agents themselves.
-            # This is just an high-level controller.
+            # The actual number of steps being done on the environment
+            # is decided by the agent, though this inner loop always
+            # takes at least one step in the environment. Depending on
+            # internal counters and parameters, it doesn't always train
+            # or save checkpoints.
             self.act(EnvironmentSteps(1))
             self.train()
             self.occasionally_save_checkpoint()
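Note: a toy version (not from the repo) of the control flow the rewritten comment describes: every outer iteration advances the environment by at least one step, while training only fires when an internal gate, standing in for _should_update(), allows it.

    class ToyLoop:
        def __init__(self, train_every_n_steps):
            self.total_steps = 0
            self.train_every_n_steps = train_every_n_steps

        def act(self, steps=1):
            self.total_steps += steps  # always at least one environment step per iteration

        def train(self):
            # stand-in for the agent-side _should_update() gate
            if self.total_steps % self.train_every_n_steps == 0:
                print("training after step", self.total_steps)

    loop = ToyLoop(train_every_n_steps=4)
    for _ in range(10):
        loop.act()
        loop.train()  # prints after steps 4 and 8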
@@ -271,7 +271,7 @@ class LevelManager(EnvironmentInterface):
         [agent.sync() for agent in self.agents.values()]

     def should_train(self) -> bool:
-        return any([agent._should_train_helper() for agent in self.agents.values()])
+        return any([agent._should_update() for agent in self.agents.values()])

     # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
     # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
@@ -76,6 +76,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
     chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
     last_checkpoint = 0

+    # this worker should play a fraction of the total playing steps per rollout
     act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)

     for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
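Note: a worked example of the split described by the added comment, with illustrative numbers rather than values from any real preset.

    import math

    num_consecutive_playing_steps = 20  # playing steps scheduled per rollout (assumed)
    num_workers = 8                     # rollout workers sharing that budget (assumed)
    improve_steps = 10000               # total improve steps (assumed)

    # each worker plays its fraction of the rollout budget, rounded up
    act_steps = math.ceil(num_consecutive_playing_steps / num_workers)
    print(act_steps)                       # 3
    print(int(improve_steps / act_steps))  # 3333 outer-loop iterations per worker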