mirror of https://github.com/gryf/coach.git (synced 2026-02-12 20:05:45 +01:00)
commit 7f00235ed5
parent 5eac0102de
committed by: zach dwiel

    waiting for a new checkpoint if it's available
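
The commit title describes waiting until a newer checkpoint shows up before proceeding. That code path is not part of the hunks below; as a rough, hypothetical sketch of the idea (the function name, checkpoint directory layout, and poll_interval are all assumptions, not Coach's API):

import os
import time

def wait_for_new_checkpoint(checkpoint_dir, last_seen=None, poll_interval=10):
    # Hypothetical illustration of "waiting for a new checkpoint if it's
    # available"; not the actual Coach implementation.
    while True:
        candidates = [os.path.join(checkpoint_dir, name)
                      for name in os.listdir(checkpoint_dir)]
        if candidates:
            newest = max(candidates, key=os.path.getmtime)
            if newest != last_seen:
                return newest
        time.sleep(poll_interval)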
@@ -81,7 +81,7 @@ class Agent(AgentInterface):
         self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
 
         if self.ap.memory.memory_backend_params.run_type == 'trainer':
-            self.memory_backend.subscribe(self.memory)
+            self.memory_backend.subscribe(self)
         else:
             self.memory.set_memory_backend(self.memory_backend)
 
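
This hunk changes what the trainer-side memory backend subscribes to: the agent itself instead of the agent's replay memory. A minimal sketch of that publish/subscribe shape, with toy classes (MemoryBackend.publish and handle_episode are illustrative assumptions, not Coach's actual interface):

class MemoryBackend:
    # Toy stand-in for a distributed memory backend.
    def __init__(self):
        self.subscribers = []

    def subscribe(self, subscriber):
        # The trainer registers a consumer for incoming episodes.
        self.subscribers.append(subscriber)

    def publish(self, episode):
        # Fan incoming data out to every subscriber.
        for subscriber in self.subscribers:
            subscriber.handle_episode(episode)


class ToyAgent:
    def __init__(self, memory):
        self.memory = memory
        self.total_steps_counter = 0

    def handle_episode(self, episode):
        # Because the agent (not the raw memory) is subscribed, it can do
        # bookkeeping before storing the incoming data.
        self.total_steps_counter += len(episode)
        self.memory.extend(episode)


backend = MemoryBackend()
agent = ToyAgent(memory=[])
backend.subscribe(agent)             # the change: subscribe the agent, not agent.memory
backend.publish(['t1', 't2', 't3'])  # agent.total_steps_counter is now 3

Routing incoming data through the agent would also explain the later hunks in this commit, which stop polling the memory for num_transitions.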
@@ -523,8 +523,7 @@ class Agent(AgentInterface):
         Determine if online weights should be copied to the target.
         :return: boolean: True if the online weights should be copied to the target.
         """
-        if hasattr(self.ap.memory, 'memory_backend_params'):
-            self.total_steps_counter = self.call_memory('num_transitions')
+
         # update the target network of every network that has a target network
         step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
         if step_method.__class__ == TrainingSteps:
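
Both this hunk and the one at line 560 below drop the pattern of re-reading the step counter from the memory every time a scheduling decision is made. A simplified sketch of what that removed polling looked like (ToyMemory and the dispatcher are toy stand-ins; Coach's call_memory presumably forwards the named call to the agent's memory, which may live behind a backend):

class ToyMemory:
    def __init__(self):
        self.transitions = []

    def num_transitions(self):
        return len(self.transitions)


class PollingAgent:
    def __init__(self, memory):
        self.memory = memory
        self.total_steps_counter = 0

    def call_memory(self, func, args=()):
        # Simplified dispatcher: forward the named call to the memory object.
        return getattr(self.memory, func)(*args)

    def sync_step_counter(self):
        # The pattern removed by this commit: poll the memory for its size
        # on every check of the training / target-copy schedule.
        self.total_steps_counter = self.call_memory('num_transitions')


agent = PollingAgent(ToyMemory())
agent.memory.transitions.extend([1, 2, 3])
agent.sync_step_counter()
print(agent.total_steps_counter)  # 3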
@@ -546,7 +545,7 @@ class Agent(AgentInterface):
         :return: boolean: True if we should start a training phase
         """
 
-        should_update = self._should_train_helper(wait_for_full_episode)
+        should_update = self._should_train_helper(wait_for_full_episode=wait_for_full_episode)
 
         step_method = self.ap.algorithm.num_consecutive_playing_steps
 
@@ -560,10 +559,8 @@ class Agent(AgentInterface):
 
     def _should_train_helper(self, wait_for_full_episode=False):
 
-        if hasattr(self.ap.memory, 'memory_backend_params'):
-            self.total_steps_counter = self.call_memory('num_transitions')
 
         step_method = self.ap.algorithm.num_consecutive_playing_steps
 
         if step_method.__class__ == EnvironmentEpisodes:
             should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
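
For the EnvironmentEpisodes branch visible above, the decision reduces to counting how many episodes have finished since the last training phase. A tiny worked example with made-up numbers:

# Toy illustration of the EnvironmentEpisodes branch shown above.
num_steps = 4                 # e.g. num_consecutive_playing_steps = EnvironmentEpisodes(4)
last_training_phase_step = 8  # episode index when training last ran
current_episode = 12

should_update = (current_episode - last_training_phase_step) >= num_steps
print(should_update)  # True: 12 - 8 == 4, so a new training phase can start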
@@ -251,6 +251,9 @@ class ClippedPPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')
 
+    def _should_train_helper(self, wait_for_full_episode=True):
+        return super()._should_train_helper(True)
+
     def train(self):
         if self._should_train(wait_for_full_episode=True):
             for network in self.networks.values():
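
The new ClippedPPO override accepts wait_for_full_episode but always forwards True to the base helper, so the caller's value cannot switch off full-episode waiting. A stripped-down sketch of that override pattern (toy classes, not the Coach hierarchy):

class BaseAgent:
    def _should_train_helper(self, wait_for_full_episode=False):
        # Stand-in for the Agent helper above: just report what was requested.
        return wait_for_full_episode


class ClippedPPOLike(BaseAgent):
    def _should_train_helper(self, wait_for_full_episode=True):
        # Whatever the caller passes, the base helper is invoked with True.
        return super()._should_train_helper(True)


agent = ClippedPPOLike()
print(agent._should_train_helper(wait_for_full_episode=False))  # True

The PPOAgent hunk at the end of the diff adds the same kind of override, there without the wait_for_full_episode parameter at all.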
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -310,6 +310,9 @@ class PPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')
 
+    def _should_train_helper(self):
+        return super()._should_train_helper(True)
+
     def train(self):
         loss = 0
         if self._should_train(wait_for_full_episode=True):