mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
add reset internal state to rollout worker (#421)
This commit is contained in:
@@ -208,7 +208,7 @@ class Kubernetes(Deploy):
|
||||
tty=True,
|
||||
resources=k8sclient.V1ResourceRequirements(
|
||||
limits={
|
||||
"cpu": "40",
|
||||
"cpu": "24",
|
||||
"memory": "4Gi",
|
||||
"nvidia.com/gpu": "1",
|
||||
}
|
||||
@@ -322,7 +322,7 @@ class Kubernetes(Deploy):
|
||||
tty=True,
|
||||
resources=k8sclient.V1ResourceRequirements(
|
||||
limits={
|
||||
"cpu": "8",
|
||||
"cpu": "4",
|
||||
"memory": "4Gi",
|
||||
# "nvidia.com/gpu": "0",
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ import os
|
||||
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
|
||||
from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
from rl_coach.core_types import RunPhase
|
||||
|
||||
|
||||
def wait_for(wait_func, data_store=None, timeout=10):
|
||||
@@ -71,7 +72,6 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
||||
"""
|
||||
wait for first checkpoint then perform rollouts using the model
|
||||
"""
|
||||
wait_for_trainer_ready(checkpoint_dir, data_store)
|
||||
if (
|
||||
graph_manager.agent_params.algorithm.distributed_coach_synchronization_type
|
||||
== DistributedCoachSynchronizationType.SYNC
|
||||
@@ -87,6 +87,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
||||
|
||||
with graph_manager.phase_context(RunPhase.TRAIN):
|
||||
# this worker should play a fraction of the total playing steps per rollout
|
||||
graph_manager.reset_internal_state(force_environment_reset=True)
|
||||
|
||||
act_steps = (
|
||||
graph_manager.agent_params.algorithm.num_consecutive_playing_steps
|
||||
|
||||
Reference in New Issue
Block a user