mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
Ignoring redis sub if testing
This commit is contained in:
committed by
zach dwiel
parent
7f00235ed5
commit
0e121c5762
@@ -435,7 +435,7 @@ class GraphManager(object):
|
|||||||
if steps.num_steps > 0:
|
if steps.num_steps > 0:
|
||||||
self.phase = RunPhase.TRAIN
|
self.phase = RunPhase.TRAIN
|
||||||
self.reset_internal_state(force_environment_reset=True)
|
self.reset_internal_state(force_environment_reset=True)
|
||||||
#TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
|
# TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
|
||||||
# episodes in memory
|
# episodes in memory
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from kubernetes import client
|
|||||||
|
|
||||||
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
|
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
|
||||||
from rl_coach.core_types import Transition, Episode
|
from rl_coach.core_types import Transition, Episode
|
||||||
|
from rl_coach.core_types import RunPhase
|
||||||
|
|
||||||
|
|
||||||
class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
|
class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
|
||||||
@@ -148,7 +149,9 @@ class RedisSub(threading.Thread):
|
|||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
for message in self.pubsub.listen():
|
for message in self.pubsub.listen():
|
||||||
if message and 'data' in message:
|
if message and 'data' in message and self.agent.phase != RunPhase.TEST or self.agent.ap.task_parameters.evaluate_only:
|
||||||
|
if self.agent.phase == RunPhase.TEST:
|
||||||
|
print(self.agent.phase)
|
||||||
try:
|
try:
|
||||||
obj = pickle.loads(message['data'])
|
obj = pickle.loads(message['data'])
|
||||||
if type(obj) == Transition:
|
if type(obj) == Transition:
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ def training_worker(graph_manager, checkpoint_dir):
|
|||||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||||
graph_manager.evaluate(graph_manager.evaluation_steps)
|
graph_manager.evaluate(graph_manager.evaluation_steps)
|
||||||
graph_manager.save_checkpoint()
|
graph_manager.save_checkpoint()
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
Reference in New Issue
Block a user