Simulating the act on the trainer. (#65)

* Remove the use of daemon threads for Redis subscribe.
* Emulate act and observe on trainer side to update internal vars.

committed by Balaji Subramaniam
parent fe6857eabd · commit fde73ced13
@@ -27,13 +27,14 @@ from rl_coach.base_parameters import iterable_to_items, TaskParameters, Distribu
     Parameters, PresetValidationParameters
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
-    StepMethod
+    StepMethod, Transition
 from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
 from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store
 from rl_coach.orchestrators.kubernetes_orchestrator import RunType
 from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles

@@ -398,9 +399,10 @@ class GraphManager(object):
         [environment.reset_internal_state(force_environment_reset) for environment in self.environments]
         [manager.reset_internal_state() for manager in self.level_managers]

-    def act(self, steps: PlayingStepsType) -> None:
+    def act(self, steps: PlayingStepsType, wait_for_full_episodes=False) -> None:
         """
         Do several steps of acting on the environment
+        :param wait_for_full_episodes: if set, act for at least `steps`, but make sure that the last episode is complete
         :param steps: the number of steps as a tuple of steps time and steps count
         """
         self.verify_graph_was_created()
@@ -412,7 +414,8 @@ class GraphManager(object):

         # perform several steps of playing
         count_end = self.current_step_counter + steps
-        while self.current_step_counter < count_end:
+        result = None
+        while self.current_step_counter < count_end or (wait_for_full_episodes and result is not None and not result.game_over):
             # reset the environment if the previous episode was terminated
             if self.reset_required:
                 self.reset_internal_state()
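
The new `wait_for_full_episodes` flag is aimed at the distributed flow, where cutting off the last episode mid-way would leave incomplete experience on the rollout side. A minimal usage sketch, assuming an already-created GraphManager instance named `graph_manager`; the caller shown here is an assumption and not part of this diff:

```python
from rl_coach.core_types import EnvironmentSteps

# Hypothetical rollout-worker call (not part of this diff): act for at least 1000
# environment steps, but keep stepping until the last episode has ended
# (result.game_over), so only complete episodes are collected.
graph_manager.act(EnvironmentSteps(1000), wait_for_full_episodes=True)
```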
@@ -624,5 +627,46 @@ class GraphManager(object):
     def should_train(self) -> bool:
         return any([manager.should_train() for manager in self.level_managers])

+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, steps: PlayingStepsType, transition: Transition) -> None:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Do several steps of acting on the environment
+        :param steps: the number of steps as a tuple of steps time and steps count
+        """
+        self.verify_graph_was_created()
+
+        # perform several steps of playing
+        count_end = self.current_step_counter + steps
+        while self.current_step_counter < count_end:
+            # reset the environment if the previous episode was terminated
+            if self.reset_required:
+                self.reset_internal_state()
+
+            steps_begin = self.environments[0].total_steps_counter
+            self.top_level_manager.emulate_step_on_trainer(transition)
+            steps_end = self.environments[0].total_steps_counter
+
+            # add the diff between the total steps before and after stepping, such that environment initialization steps
+            # (like in Atari) will not be counted.
+            # We add at least one step so that even if no steps were made (in case no actions are taken in the training
+            # phase), the loop will end eventually.
+            self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
+
+            if transition.game_over:
+                self.handle_episode_ended()
+                self.reset_required = True
+
+    def fetch_from_worker(self, num_steps=0):
+        if hasattr(self, 'memory_backend'):
+            for transition in self.memory_backend.fetch(num_steps):
+                self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
+
+    def setup_memory_backend(self) -> None:
+        if hasattr(self.agent_params.memory, 'memory_backend_params'):
+            self.memory_backend = get_memory_backend(self.agent_params.memory.memory_backend_params)
+
+    def should_stop(self) -> bool:
+        return all([manager.should_stop() for manager in self.level_managers])
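
The new methods are meant to be driven from the trainer worker's own loop rather than from inside GraphManager. A minimal sketch of such a trainer-side loop, assuming an already-created GraphManager instance named `graph_manager`; the loop structure and the training step are assumptions, not part of this diff:

```python
# Hypothetical trainer-worker loop (not part of this diff).
graph_manager.setup_memory_backend()   # attach the memory backend, if one is configured

while not graph_manager.should_stop():
    # Pull transitions published by rollout workers and replay each one through
    # emulate_act_on_trainer(), updating step counters and episode bookkeeping
    # without touching a real environment.
    graph_manager.fetch_from_worker(num_steps=1)

    # A regular training step would follow here; its exact form is outside this commit.
```

This pull-based flow is what replaces the removed daemon thread: instead of a background Redis subscriber pushing transitions, the trainer now fetches them explicitly on its own thread of control.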