1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

add reset internal state to rollout worker (#421)

This commit is contained in:
shadiendrawis
2019-11-03 14:42:51 +02:00
committed by GitHub
parent e288a552dd
commit 6ca91b9090
2 changed files with 4 additions and 3 deletions

View File

@@ -208,7 +208,7 @@ class Kubernetes(Deploy):
tty=True,
resources=k8sclient.V1ResourceRequirements(
limits={
"cpu": "40",
"cpu": "24",
"memory": "4Gi",
"nvidia.com/gpu": "1",
}
@@ -322,7 +322,7 @@ class Kubernetes(Deploy):
tty=True,
resources=k8sclient.V1ResourceRequirements(
limits={
"cpu": "8",
"cpu": "4",
"memory": "4Gi",
# "nvidia.com/gpu": "0",
}

View File

@@ -31,6 +31,7 @@ import os
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
from rl_coach.data_stores.data_store import SyncFiles
from rl_coach.core_types import RunPhase
def wait_for(wait_func, data_store=None, timeout=10):
@@ -71,7 +72,6 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
"""
wait for first checkpoint then perform rollouts using the model
"""
wait_for_trainer_ready(checkpoint_dir, data_store)
if (
graph_manager.agent_params.algorithm.distributed_coach_synchronization_type
== DistributedCoachSynchronizationType.SYNC
@@ -87,6 +87,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
with graph_manager.phase_context(RunPhase.TRAIN):
# this worker should play a fraction of the total playing steps per rollout
graph_manager.reset_internal_state(force_environment_reset=True)
act_steps = (
graph_manager.agent_params.algorithm.num_consecutive_playing_steps