mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
add reset internal state to rollout worker (#421)
This commit is contained in:
@@ -208,7 +208,7 @@ class Kubernetes(Deploy):
|
|||||||
tty=True,
|
tty=True,
|
||||||
resources=k8sclient.V1ResourceRequirements(
|
resources=k8sclient.V1ResourceRequirements(
|
||||||
limits={
|
limits={
|
||||||
"cpu": "40",
|
"cpu": "24",
|
||||||
"memory": "4Gi",
|
"memory": "4Gi",
|
||||||
"nvidia.com/gpu": "1",
|
"nvidia.com/gpu": "1",
|
||||||
}
|
}
|
||||||
@@ -322,7 +322,7 @@ class Kubernetes(Deploy):
|
|||||||
tty=True,
|
tty=True,
|
||||||
resources=k8sclient.V1ResourceRequirements(
|
resources=k8sclient.V1ResourceRequirements(
|
||||||
limits={
|
limits={
|
||||||
"cpu": "8",
|
"cpu": "4",
|
||||||
"memory": "4Gi",
|
"memory": "4Gi",
|
||||||
# "nvidia.com/gpu": "0",
|
# "nvidia.com/gpu": "0",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import os
|
|||||||
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
|
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
|
||||||
from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
|
from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
|
||||||
from rl_coach.data_stores.data_store import SyncFiles
|
from rl_coach.data_stores.data_store import SyncFiles
|
||||||
|
from rl_coach.core_types import RunPhase
|
||||||
|
|
||||||
|
|
||||||
def wait_for(wait_func, data_store=None, timeout=10):
|
def wait_for(wait_func, data_store=None, timeout=10):
|
||||||
@@ -71,7 +72,6 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
|||||||
"""
|
"""
|
||||||
wait for first checkpoint then perform rollouts using the model
|
wait for first checkpoint then perform rollouts using the model
|
||||||
"""
|
"""
|
||||||
wait_for_trainer_ready(checkpoint_dir, data_store)
|
|
||||||
if (
|
if (
|
||||||
graph_manager.agent_params.algorithm.distributed_coach_synchronization_type
|
graph_manager.agent_params.algorithm.distributed_coach_synchronization_type
|
||||||
== DistributedCoachSynchronizationType.SYNC
|
== DistributedCoachSynchronizationType.SYNC
|
||||||
@@ -87,6 +87,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
|||||||
|
|
||||||
with graph_manager.phase_context(RunPhase.TRAIN):
|
with graph_manager.phase_context(RunPhase.TRAIN):
|
||||||
# this worker should play a fraction of the total playing steps per rollout
|
# this worker should play a fraction of the total playing steps per rollout
|
||||||
|
graph_manager.reset_internal_state(force_environment_reset=True)
|
||||||
|
|
||||||
act_steps = (
|
act_steps = (
|
||||||
graph_manager.agent_params.algorithm.num_consecutive_playing_steps
|
graph_manager.agent_params.algorithm.num_consecutive_playing_steps
|
||||||
|
|||||||
Reference in New Issue
Block a user