mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Simulating the act on the trainer. (#65)

* Remove the use of daemon threads for Redis subscribe.
* Emulate act and observe on trainer side to update internal vars.
authored by Ajay Deshpande on 2018-11-15 08:38:58 -08:00
committed by Balaji Subramaniam
parent fe6857eabd
commit fde73ced13
13 changed files with 221 additions and 55 deletions
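
The gist of the change, as a standalone toy (invented names, not rl_coach's API): rollout workers act on the real environment and ship transitions to the trainer, which replays each transition only to keep its own step and episode bookkeeping in sync, without ever stepping an environment itself.

    # Illustrative only -- ToyTransition/ToyTrainer are invented for this sketch.
    from dataclasses import dataclass
    from typing import Iterable, List

    @dataclass
    class ToyTransition:
        reward: float
        game_over: bool

    class ToyTrainer:
        """Consumes transitions produced elsewhere and mirrors the bookkeeping
        a real acting loop would have done."""

        def __init__(self) -> None:
            self.total_steps = 0
            self.episodes_ended = 0
            self.reset_required = False

        def emulate_act(self, transition: ToyTransition) -> None:
            if self.reset_required:       # stand-in for reset_internal_state()
                self.reset_required = False
            self.total_steps += 1         # internal counter only, no env step
            if transition.game_over:
                self.episodes_ended += 1  # stand-in for handle_episode_ended()
                self.reset_required = True

        def fetch_from_worker(self, transitions: Iterable[ToyTransition]) -> None:
            for transition in transitions:
                self.emulate_act(transition)

    if __name__ == "__main__":
        rollout: List[ToyTransition] = [ToyTransition(1.0, False), ToyTransition(0.0, True)]
        trainer = ToyTrainer()
        trainer.fetch_from_worker(rollout)
        print(trainer.total_steps, trainer.episodes_ended)  # -> 2 1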


@@ -27,13 +27,14 @@ from rl_coach.base_parameters import iterable_to_items, TaskParameters, Distribu
     Parameters, PresetValidationParameters
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
-    StepMethod
+    StepMethod, Transition
 from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
 from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store
 from rl_coach.orchestrators.kubernetes_orchestrator import RunType
 from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles
@@ -398,9 +399,10 @@ class GraphManager(object):
         [environment.reset_internal_state(force_environment_reset) for environment in self.environments]
         [manager.reset_internal_state() for manager in self.level_managers]
 
-    def act(self, steps: PlayingStepsType) -> None:
+    def act(self, steps: PlayingStepsType, wait_for_full_episodes=False) -> None:
         """
         Do several steps of acting on the environment
+        :param wait_for_full_episodes: if set, act for at least `steps`, but make sure that the last episode is complete
         :param steps: the number of steps as a tuple of steps time and steps count
         """
         self.verify_graph_was_created()
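
A hedged usage sketch of the new keyword: `graph_manager` is assumed to be an already-created and initialized GraphManager; only the act() signature and EnvironmentSteps come from this file.

    from rl_coach.core_types import EnvironmentSteps

    # Act for at least 1000 environment steps, then keep stepping until the
    # episode in progress terminates, so no partial episode is left behind.
    # (`graph_manager` is assumed to exist; its construction is not shown here.)
    graph_manager.act(EnvironmentSteps(1000), wait_for_full_episodes=True)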
@@ -412,7 +414,8 @@ class GraphManager(object):
         # perform several steps of playing
         count_end = self.current_step_counter + steps
-        while self.current_step_counter < count_end:
+        result = None
+        while self.current_step_counter < count_end or (wait_for_full_episodes and result is not None and not result.game_over):
             # reset the environment if the previous episode was terminated
             if self.reset_required:
                 self.reset_internal_state()
@@ -624,5 +627,46 @@ class GraphManager(object):
     def should_train(self) -> bool:
         return any([manager.should_train() for manager in self.level_managers])
 
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, steps: PlayingStepsType, transition: Transition) -> None:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Do several steps of acting on the environment
+        :param steps: the number of steps as a tuple of steps time and steps count
+        """
+        self.verify_graph_was_created()
+
+        # perform several steps of playing
+        count_end = self.current_step_counter + steps
+        while self.current_step_counter < count_end:
+            # reset the environment if the previous episode was terminated
+            if self.reset_required:
+                self.reset_internal_state()
+
+            steps_begin = self.environments[0].total_steps_counter
+            self.top_level_manager.emulate_step_on_trainer(transition)
+            steps_end = self.environments[0].total_steps_counter
+
+            # add the diff between the total steps before and after stepping, such that environment initialization steps
+            # (like in Atari) will not be counted.
+            # We add at least one step so that even if no steps were made (in case no actions are taken in the training
+            # phase), the loop will end eventually.
+            self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
+
+            if transition.game_over:
+                self.handle_episode_ended()
+                self.reset_required = True
+
+    def fetch_from_worker(self, num_steps=0):
+        if hasattr(self, 'memory_backend'):
+            for transition in self.memory_backend.fetch(num_steps):
+                self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
+
+    def setup_memory_backend(self) -> None:
+        if hasattr(self.agent_params.memory, 'memory_backend_params'):
+            self.memory_backend = get_memory_backend(self.agent_params.memory.memory_backend_params)
+
     def should_stop(self) -> bool:
         return all([manager.should_stop() for manager in self.level_managers])
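
Taken together, the added methods suggest a trainer-side loop along these lines. This is a sketch only: setup_memory_backend(), fetch_from_worker() and should_stop() come from this diff, while graph creation, the num_steps value and the train() call are assumptions rather than the flow this commit wires up elsewhere.

    # Hypothetical trainer-worker loop assembled from the methods added above.
    # `graph_manager` construction/initialization is elided; train() is assumed.
    graph_manager.setup_memory_backend()             # attach the memory backend, if one is configured
    while not graph_manager.should_stop():
        graph_manager.fetch_from_worker(num_steps=1)  # pull rollout transitions, emulate act() on them
        graph_manager.train()                         # assumed training step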