1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

use phase context in GraphManager.evaluate

This commit is contained in:
Zach Dwiel
2018-10-04 11:38:45 -04:00
committed by zach dwiel
parent d3c341147e
commit 35d67cbd9b

View File

@@ -446,19 +446,17 @@ class GraphManager(object):
self.verify_graph_was_created()
if steps.num_steps > 0:
self.phase = RunPhase.TEST
self.last_evaluation_start_time = time.time()
with self.phase_context(RunPhase.TEST):
self.last_evaluation_start_time = time.time()
# reset all the levels before starting to evaluate
self.reset_internal_state(force_environment_reset=True)
self.sync_graph()
# reset all the levels before starting to evaluate
self.reset_internal_state(force_environment_reset=True)
self.sync_graph()
# act for at least `steps`, though don't interrupt an episode
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync)
self.phase = RunPhase.UNDEFINED
# act for at least `steps`, though don't interrupt an episode
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync)
def restore_checkpoint(self):
self.verify_graph_was_created()