mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
graph_manager:heatup uses total_steps_counters looping mechanism like other loops. graph_manager:act no longer needs to return any values
This commit is contained in:
@@ -30,12 +30,8 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T
|
|||||||
from rl_coach.environments.environment import Environment
|
from rl_coach.environments.environment import Environment
|
||||||
from rl_coach.level_manager import LevelManager
|
from rl_coach.level_manager import LevelManager
|
||||||
from rl_coach.logger import screen, Logger
|
from rl_coach.logger import screen, Logger
|
||||||
<<<<<<< HEAD
|
|
||||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||||
=======
|
|
||||||
from rl_coach.utils import set_cpu
|
|
||||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
from rl_coach.data_stores.data_store_impl import get_data_store
|
||||||
>>>>>>> Make distributed coach work end-to-end.
|
|
||||||
|
|
||||||
|
|
||||||
class ScheduleParameters(Parameters):
|
class ScheduleParameters(Parameters):
|
||||||
@@ -305,9 +301,7 @@ class GraphManager(object):
|
|||||||
"""
|
"""
|
||||||
self.verify_graph_was_created()
|
self.verify_graph_was_created()
|
||||||
|
|
||||||
steps_copy = copy.copy(steps)
|
if steps.num_steps > 0:
|
||||||
|
|
||||||
if steps_copy.num_steps > 0:
|
|
||||||
self.phase = RunPhase.HEATUP
|
self.phase = RunPhase.HEATUP
|
||||||
screen.log_title("{}: Starting heatup".format(self.name))
|
screen.log_title("{}: Starting heatup".format(self.name))
|
||||||
self.heatup_start_time = time.time()
|
self.heatup_start_time = time.time()
|
||||||
@@ -315,11 +309,10 @@ class GraphManager(object):
|
|||||||
# reset all the levels before starting to heatup
|
# reset all the levels before starting to heatup
|
||||||
self.reset_internal_state(force_environment_reset=True)
|
self.reset_internal_state(force_environment_reset=True)
|
||||||
|
|
||||||
# act on the environment
|
|
||||||
# act for at least steps, though don't interrupt an episode
|
# act for at least steps, though don't interrupt an episode
|
||||||
while steps_copy.num_steps > 0:
|
count_end = self.total_steps_counters[self.phase][EnvironmentSteps] + steps.num_steps
|
||||||
steps_done, _ = self.act(steps_copy, continue_until_game_over=True, return_on_game_over=True)
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
steps_copy.num_steps -= steps_done
|
self.act(steps, continue_until_game_over=True, return_on_game_over=True)
|
||||||
|
|
||||||
# training phase
|
# training phase
|
||||||
self.phase = RunPhase.UNDEFINED
|
self.phase = RunPhase.UNDEFINED
|
||||||
@@ -380,7 +373,7 @@ class GraphManager(object):
|
|||||||
# perform several steps of playing
|
# perform several steps of playing
|
||||||
result = None
|
result = None
|
||||||
|
|
||||||
initial_count = self.total_steps_counters[self.phase][steps.__class__]
|
initial_count = self.total_steps_counters[self.phase][EnvironmentSteps]
|
||||||
count_end = initial_count + steps.num_steps
|
count_end = initial_count + steps.num_steps
|
||||||
|
|
||||||
# The assumption here is that the total_steps_counters are each updated when an event
|
# The assumption here is that the total_steps_counters are each updated when an event
|
||||||
|
|||||||
Reference in New Issue
Block a user