mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
use phase context in GraphManager.evaluate
This commit is contained in:
@@ -446,19 +446,17 @@ class GraphManager(object):
|
|||||||
self.verify_graph_was_created()
|
self.verify_graph_was_created()
|
||||||
|
|
||||||
if steps.num_steps > 0:
|
if steps.num_steps > 0:
|
||||||
self.phase = RunPhase.TEST
|
with self.phase_context(RunPhase.TEST):
|
||||||
self.last_evaluation_start_time = time.time()
|
self.last_evaluation_start_time = time.time()
|
||||||
|
|
||||||
# reset all the levels before starting to evaluate
|
# reset all the levels before starting to evaluate
|
||||||
self.reset_internal_state(force_environment_reset=True)
|
self.reset_internal_state(force_environment_reset=True)
|
||||||
self.sync_graph()
|
self.sync_graph()
|
||||||
|
|
||||||
# act for at least `steps`, though don't interrupt an episode
|
# act for at least `steps`, though don't interrupt an episode
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync)
|
self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync)
|
||||||
|
|
||||||
self.phase = RunPhase.UNDEFINED
|
|
||||||
|
|
||||||
def restore_checkpoint(self):
|
def restore_checkpoint(self):
|
||||||
self.verify_graph_was_created()
|
self.verify_graph_was_created()
|
||||||
|
|||||||
Reference in New Issue
Block a user