mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Batch RL (#238)
This commit is contained in:
@@ -38,6 +38,8 @@ from rl_coach.data_stores.data_store_impl import get_data_store as data_store_cr
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
from rl_coach.core_types import TimeTypes
|
||||
|
||||
|
||||
class ScheduleParameters(Parameters):
|
||||
def __init__(self):
|
||||
@@ -119,6 +121,8 @@ class GraphManager(object):
|
||||
self.checkpoint_state_updater = None
|
||||
self.graph_logger = Logger()
|
||||
self.data_store = None
|
||||
self.is_batch_rl = False
|
||||
self.time_metric = TimeTypes.EpisodeNumber
|
||||
|
||||
def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
|
||||
self.graph_creation_time = time.time()
|
||||
@@ -445,16 +449,17 @@ class GraphManager(object):
|
||||
result = self.top_level_manager.step(None)
|
||||
steps_end = self.environments[0].total_steps_counter
|
||||
|
||||
# add the diff between the total steps before and after stepping, such that environment initialization steps
|
||||
# (like in Atari) will not be counted.
|
||||
# We add at least one step so that even if no steps were made (in case no actions are taken in the training
|
||||
# phase), the loop will end eventually.
|
||||
self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
|
||||
|
||||
if result.game_over:
|
||||
self.handle_episode_ended()
|
||||
self.reset_required = True
|
||||
|
||||
self.current_step_counter[EnvironmentSteps] += (steps_end - steps_begin)
|
||||
|
||||
# if no steps were made (can happen when no actions are taken while in the TRAIN phase, either in batch RL
|
||||
# or in imitation learning), we force end the loop, so that it will not continue forever.
|
||||
if (steps_end - steps_begin) == 0:
|
||||
break
|
||||
|
||||
def train_and_act(self, steps: StepMethod) -> None:
|
||||
"""
|
||||
Train the agent by doing several acting steps followed by several training steps continually
|
||||
@@ -472,9 +477,9 @@ class GraphManager(object):
|
||||
while self.current_step_counter < count_end:
|
||||
# The actual number of steps being done on the environment
|
||||
# is decided by the agent, though this inner loop always
|
||||
# takes at least one step in the environment. Depending on
|
||||
# internal counters and parameters, it doesn't always train
|
||||
# or save checkpoints.
|
||||
# takes at least one step in the environment (at the GraphManager level).
|
||||
# The agent might also decide to skip acting altogether.
|
||||
# Depending on internal counters and parameters, it doesn't always train or save checkpoints.
|
||||
self.act(EnvironmentSteps(1))
|
||||
self.train()
|
||||
self.occasionally_save_checkpoint()
|
||||
|
||||
Reference in New Issue
Block a user