
Adding target reward and target success (#58)

* Adding target reward

* Adding target success

* Addressing comments

* Using custom_reward_threshold and target_success_rate

* Adding exit message

* Moving success rate to environment

* Making target_success_rate optional
Authored by Ajay Deshpande on 2018-11-12 15:03:43 -08:00; committed by Balaji Subramaniam
parent 0fe583186e
commit 875d6ef017
17 changed files with 162 additions and 74 deletions

View File

@@ -842,7 +842,5 @@ class Agent(AgentInterface):
for network in self.networks.values():
network.sync()
def get_success_rate(self) -> float:
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
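For context, a minimal sketch of how the two counters behind `get_success_rate` might be maintained; only the attribute names and the final ratio come from the diff above, the episode-end hook is an assumption:

```python
# Hypothetical sketch of the counters that Agent.get_success_rate() divides.
# Only the attribute names and the ratio come from the diff; the
# episode-end update hook is assumed.
class EvaluationSuccessTracker:
    def __init__(self):
        self.num_successes_across_evaluation_episodes = 0
        self.num_evaluation_episodes_completed = 0

    def end_evaluation_episode(self, episode_succeeded: bool) -> None:
        self.num_evaluation_episodes_completed += 1
        if episode_succeeded:
            self.num_successes_across_evaluation_episodes += 1

    def get_success_rate(self) -> float:
        # Note: calling this before any evaluation episode has completed
        # would divide by zero.
        return (self.num_successes_across_evaluation_episodes /
                self.num_evaluation_episodes_completed)
```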

View File

@@ -168,7 +168,7 @@ class AlgorithmParameters(Parameters):
# n-step returns
self.n_step = -1 # calculate the total return (no bootstrap, by default)
# Distributed Coach params
self.distributed_coach_synchronization_type = None
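The `None` default means presets must opt in explicitly; a minimal sketch of doing so, assuming a ClippedPPO preset and using the `DistributedCoachSynchronizationType` enum that the rollout/training workers below compare against `SYNC`:

```python
# Sketch: opting a preset into synchronous distributed Coach.
# The enum is the one checked in rollout_worker/training_worker below;
# the ClippedPPO preset is an illustrative assumption.
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.base_parameters import DistributedCoachSynchronizationType

agent_params = ClippedPPOAgentParameters()
agent_params.algorithm.distributed_coach_synchronization_type = \
    DistributedCoachSynchronizationType.SYNC
```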

View File

@@ -1,4 +1,6 @@
from enum import Enum
class DataStoreParameters(object):
def __init__(self, store_type, orchestrator_type, orchestrator_params):
@@ -6,6 +8,7 @@ class DataStoreParameters(object):
self.orchestrator_type = orchestrator_type
self.orchestrator_params = orchestrator_params
class DataStore(object):
def __init__(self, params: DataStoreParameters):
pass
@@ -24,3 +27,8 @@ class DataStore(object):
def load_from_store(self):
pass
class SyncFiles(Enum):
FINISHED = ".finished"
LOCKFILE = ".lock"
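These two marker objects coordinate trainer and workers: `.lock` guards an in-progress checkpoint upload (see the S3 data store below), and `.finished` tells workers the target success rate has been reached (see `should_stop` in the rollout-worker diff). A minimal sketch of the consumer side, assuming a local checkpoint directory:

```python
import os

from rl_coach.data_stores.data_store import SyncFiles

def training_finished(checkpoint_dir: str) -> bool:
    # Mirrors should_stop() in rollout_worker below: the trainer drops a
    # ".finished" marker into the checkpoint directory once the target
    # success rate is reached.
    return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
```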

View File

@@ -58,7 +58,7 @@ class NFSDataStore(DataStore):
pass
def deploy_k8s_nfs(self) -> bool:
name = "nfs-server"
name = "nfs-server-{}".format(uuid.uuid4())
container = k8sclient.V1Container(
name=name,
image="k8s.gcr.io/volume-nfs:0.8",
@@ -83,7 +83,7 @@ class NFSDataStore(DataStore):
security_context=k8sclient.V1SecurityContext(privileged=True)
)
template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': 'nfs-server'}),
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
spec=k8sclient.V1PodSpec(
containers=[container],
volumes=[k8sclient.V1Volume(
@@ -96,14 +96,14 @@ class NFSDataStore(DataStore):
replicas=1,
template=template,
selector=k8sclient.V1LabelSelector(
match_labels={'app': 'nfs-server'}
match_labels={'app': name}
)
)
deployment = k8sclient.V1Deployment(
api_version='apps/v1',
kind='Deployment',
metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': 'nfs-server'}),
metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': name}),
spec=deployment_spec
)
@@ -117,7 +117,7 @@ class NFSDataStore(DataStore):
k8s_core_v1_api_client = k8sclient.CoreV1Api()
svc_name = "nfs-service"
svc_name = "nfs-service-{}".format(uuid.uuid4())
service = k8sclient.V1Service(
api_version='v1',
kind='Service',
@@ -145,7 +145,7 @@ class NFSDataStore(DataStore):
return True
def create_k8s_nfs_resources(self) -> bool:
pv_name = "nfs-ckpt-pv"
pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
persistent_volume = k8sclient.V1PersistentVolume(
api_version="v1",
kind="PersistentVolume",
@@ -171,7 +171,7 @@ class NFSDataStore(DataStore):
print("Got exception: %s\n while creating the NFS PV", e)
return False
pvc_name = "nfs-ckpt-pvc"
pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4())
persistent_volume_claim = k8sclient.V1PersistentVolumeClaim(
api_version="v1",
kind="PersistentVolumeClaim",

View File

@@ -5,6 +5,7 @@ from minio.error import ResponseError
from configparser import ConfigParser, Error
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles
import os
import time
@@ -20,7 +21,6 @@ class S3DataStoreParameters(DataStoreParameters):
self.end_point = end_point
self.bucket_name = bucket_name
self.checkpoint_dir = checkpoint_dir
self.lock_file = ".lock"
class S3DataStore(DataStore):
@@ -52,9 +52,9 @@ class S3DataStore(DataStore):
def save_to_store(self):
try:
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
checkpoint_file = None
for root, dirs, files in os.walk(self.params.checkpoint_dir):
@@ -70,7 +70,7 @@ class S3DataStore(DataStore):
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
except ResponseError as e:
print("Got exception: %s\n while saving to S3", e)
@@ -80,7 +80,7 @@ class S3DataStore(DataStore):
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
while True:
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)
if next(objects, None) is None:
try:
@@ -90,6 +90,18 @@ class S3DataStore(DataStore):
break
time.sleep(10)
# Check if there's a finished file
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value)
if next(objects, None) is not None:
try:
self.mc.fget_object(
self.params.bucket_name, SyncFiles.FINISHED.value,
os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value))
)
except Exception as e:
pass
ckpt = CheckpointState()
if os.path.exists(filename):
contents = open(filename, 'r').read()
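The writer brackets the checkpoint upload with the lock object so readers never see a partial checkpoint, and the reader now also pulls down the `.finished` marker once the lock clears. A condensed sketch of the reader-side wait, assuming a configured minio client `mc`:

```python
import time

from rl_coach.data_stores.data_store import SyncFiles

def wait_for_unlocked_checkpoint(mc, bucket_name: str, poll_seconds: int = 10) -> None:
    # Condensed from S3DataStore.load_from_store above: spin until the
    # writer has removed the ".lock" object, i.e. no upload is in flight.
    while next(mc.list_objects_v2(bucket_name, SyncFiles.LOCKFILE.value), None) is not None:
        time.sleep(poll_seconds)
```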

View File

@@ -133,8 +133,8 @@ class CarlaEnvironment(Environment):
allow_braking: bool, quality: CarlaEnvironmentParameters.Quality,
cameras: List[CameraTypes], weather_id: List[int], experiment_path: str,
separate_actions_for_throttle_and_brake: bool,
num_speedup_steps: int, max_speed: float, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
num_speedup_steps: int, max_speed: float, target_success_rate: float = 1.0, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
# server configuration
self.server_height = server_height
@@ -261,6 +261,8 @@ class CarlaEnvironment(Environment):
image = self.get_rendered_image()
self.renderer.create_screen(image.shape[1], image.shape[0])
self.target_success_rate = target_success_rate
def _add_cameras(self, settings, cameras, camera_width, camera_height):
# add a front facing camera
if CameraTypes.FRONT in cameras:
@@ -461,3 +463,6 @@ class CarlaEnvironment(Environment):
image = [self.state[camera.name] for camera in self.scene.sensors]
image = np.vstack(image)
return image
def get_target_success_rate(self) -> float:
return self.target_success_rate

View File

@@ -66,10 +66,10 @@ control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING
# Environment
class ControlSuiteEnvironment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
seed: Union[None, int]=None, human_control: bool=False,
target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False,
observation_type: ObservationType=ObservationType.Measurements,
custom_reward_threshold: Union[int, float]=None, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
self.observation_type = observation_type
@@ -126,6 +126,8 @@ class ControlSuiteEnvironment(Environment):
if not self.native_rendering:
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
self.target_success_rate = target_success_rate
def _update_state(self):
self.state = {}
@@ -160,3 +162,6 @@ class ControlSuiteEnvironment(Environment):
def get_rendered_image(self):
return self.env.physics.render(camera_id=0)
def get_target_success_rate(self) -> float:
return self.target_success_rate

View File

@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -124,8 +124,8 @@ class DoomEnvironment(Environment):
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
cameras: List[CameraTypes], **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
cameras: List[CameraTypes], target_success_rate: float=1.0, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
self.cameras = cameras
@@ -196,6 +196,8 @@ class DoomEnvironment(Environment):
image = self.get_rendered_image()
self.renderer.create_screen(image.shape[1], image.shape[0])
self.target_success_rate = target_success_rate
def _update_state(self):
# extract all data from the current state
state = self.game.get_state()
@@ -227,3 +229,6 @@ class DoomEnvironment(Environment):
image = [self.state[camera.value[0]] for camera in self.cameras]
image = np.vstack(image)
return image
def get_target_success_rate(self) -> float:
return self.target_success_rate

View File

@@ -103,6 +103,9 @@ class EnvironmentParameters(Parameters):
self.default_output_filter = None
self.experiment_path = None
# Set target reward and target_success if present
self.target_success_rate = 1.0
@property
def path(self):
return 'rl_coach.environments.environment:Environment'
@@ -111,7 +114,7 @@ class EnvironmentParameters(Parameters):
class Environment(EnvironmentInterface):
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
**kwargs):
target_success_rate: float=1.0, **kwargs):
"""
:param level: The environment level. Each environment can have multiple levels
:param seed: a seed for the random number generator of the environment
@@ -166,6 +169,9 @@ class Environment(EnvironmentInterface):
if not self.native_rendering:
self.renderer = Renderer()
# Set target reward and target_success if present
self.target_success_rate = target_success_rate
@property
def action_space(self) -> Union[List[ActionSpace], ActionSpace]:
"""
@@ -469,3 +475,5 @@ class Environment(EnvironmentInterface):
"""
return np.transpose(self.state['observation'], [1, 2, 0])
def get_target_success_rate(self) -> float:
return self.target_success_rate
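Every environment now exposes a `target_success_rate` (default 1.0, i.e. stop only on perfect evaluation success) that presets can override; a minimal sketch of the configuration surface, mirroring the CartPole preset further down but with an illustrative 0.9 threshold:

```python
from rl_coach.environments.gym_environment import GymVectorEnvironment

env_params = GymVectorEnvironment(level='CartPole-v0')
env_params.custom_reward_threshold = 200
# Illustrative value: declare success once at least 90% of evaluation
# episodes meet the reward threshold (the preset below uses 1.0).
env_params.target_success_rate = 0.9
```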

View File

@@ -178,11 +178,11 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
# Environment
class GymEnvironment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
visualization_parameters)
visualization_parameters, target_success_rate)
self.random_initialization_steps = random_initialization_steps
self.max_over_num_frames = max_over_num_frames
@@ -221,7 +221,7 @@ class GymEnvironment(Environment):
try:
self.env = env_class(**self.additional_simulator_parameters)
except:
screen.error("Failed to instantiate Gym environment class %s with arguments %s" %
screen.error("Failed to instantiate Gym environment class %s with arguments %s" %
(env_class, self.additional_simulator_parameters), crash=False)
raise
else:
@@ -337,6 +337,8 @@ class GymEnvironment(Environment):
self.reward_success_threshold = self.env.spec.reward_threshold
self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold)
self.target_success_rate = target_success_rate
def _wrap_state(self, state):
if not isinstance(self.env.observation_space, gym.spaces.Dict):
return {'observation': state}
@@ -434,3 +436,6 @@ class GymEnvironment(Environment):
if self.is_mujoco_env:
self._set_mujoco_camera(0)
return image
def get_target_success_rate(self) -> float:
return self.target_success_rate

View File

@@ -107,14 +107,14 @@ class StarCraft2EnvironmentParameters(EnvironmentParameters):
# Environment
class StarCraft2Environment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
seed: Union[None, int]=None, human_control: bool=False,
target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False,
custom_reward_threshold: Union[int, float]=None,
screen_size: int=84, minimap_size: int=64,
feature_minimap_maps_to_use: List=range(7), feature_screen_maps_to_use: List=range(17),
observation_type: StarcraftObservationType=StarcraftObservationType.Features,
disable_fog: bool=False, auto_select_all_army: bool=True,
use_full_action_space: bool=False, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
self.screen_size = screen_size
self.minimap_size = minimap_size
@@ -163,11 +163,11 @@ class StarCraft2Environment(Environment):
"""
feature_screen: [height_map, visibility_map, creep, power, player_id, player_relative, unit_type, selected,
unit_hit_points, unit_hit_points_ratio, unit_energy, unit_energy_ratio, unit_shields,
unit_hit_points, unit_hit_points_ratio, unit_energy, unit_energy_ratio, unit_shields,
unit_shields_ratio, unit_density, unit_density_aa, effects]
feature_minimap: [height_map, visibility_map, creep, camera, player_id, player_relative, selected]
player: [player_id, minerals, vespene, food_cap, food_army, food_workers, idle_worker_dount,
player: [player_id, minerals, vespene, food_cap, food_army, food_workers, idle_worker_dount,
army_count, warp_gate_count, larva_count]
"""
self.screen_shape = np.array(self.env.observation_spec()[0]['feature_screen'])
@@ -192,6 +192,8 @@ class StarCraft2Environment(Environment):
self.action_space = BoxActionSpace(2, 0, self.screen_size - 1, ["X-Axis, Y-Axis"],
default_action=np.array([self.screen_size/2, self.screen_size/2]))
self.target_success_rate = target_success_rate
def _update_state(self):
timestep = 0
self.screen = self.last_result[timestep].observation.feature_screen
@@ -244,3 +246,6 @@ class StarCraft2Environment(Environment):
self.env._run_config.replay_dir = experiment_path
self.env.save_replay('replays')
super().dump_video_of_last_episode()
def get_target_success_rate(self):
return self.target_success_rate

View File

@@ -34,6 +34,7 @@ from rl_coach.logger import screen, Logger
from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
from rl_coach.data_stores.data_store import SyncFiles
class ScheduleParameters(Parameters):
@@ -458,12 +459,12 @@ class GraphManager(object):
"""
[manager.sync() for manager in self.level_managers]
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> None:
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> bool:
"""
Perform evaluation for several steps
:param steps: the number of steps as a tuple of steps time and steps count
:param keep_networks_in_sync: sync the network parameters with the global network before each episode
:return: None
:return: bool, True if the target reward and target success has been reached
"""
self.verify_graph_was_created()
@@ -478,6 +479,16 @@ class GraphManager(object):
while self.current_step_counter < count_end:
self.act(EnvironmentEpisodes(1))
self.sync()
if self.should_stop():
if self.task_parameters.checkpoint_save_dir:
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
if hasattr(self, 'data_store_params'):
data_store = get_data_store(self.data_store_params)
data_store.save_to_store()
screen.success("Reached required success rate. Exiting.")
return True
return False
def improve(self):
"""
@@ -508,7 +519,8 @@ class GraphManager(object):
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps)
if self.evaluate(self.evaluation_steps):
break
def _restore_checkpoint_tf(self, checkpoint_dir: str):
import tensorflow as tf
@@ -609,3 +621,6 @@ class GraphManager(object):
def should_train(self) -> bool:
return any([manager.should_train() for manager in self.level_managers])
def should_stop(self) -> bool:
return all([manager.should_stop() for manager in self.level_managers])
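The stop decision chains bottom-up through names introduced in this commit; condensed, with the step budget elided, the early-exit control flow looks roughly like this:

```python
# Condensed, illustrative control flow of the early-exit path:
#   Agent.get_success_rate()    -> successes / completed evaluation episodes
#   LevelManager.should_stop()  -> all agents >= environment's target rate
#   GraphManager.should_stop()  -> all level managers agree
#   GraphManager.evaluate(...)  -> writes SyncFiles.FINISHED, returns True
def improve_until_target(graph_manager):
    while True:  # the real loop is bounded by improve_steps
        graph_manager.train_and_act(graph_manager.steps_between_evaluation_periods)
        if graph_manager.evaluate(graph_manager.evaluation_steps):
            break  # target reached; the checkpoint dir now holds ".finished"
```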

View File

@@ -263,3 +263,6 @@ class LevelManager(EnvironmentInterface):
def should_train(self) -> bool:
return any([agent._should_train_helper() for agent in self.agents.values()])
def should_stop(self) -> bool:
return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])

View File

@@ -79,6 +79,7 @@ class Kubernetes(Deploy):
self.memory_backend = get_memory_backend(self.params.memory_backend_parameters)
self.params.data_store_params.orchestrator_params = {'namespace': self.params.namespace}
self.params.data_store_params.namespace = self.params.namespace
self.data_store = get_data_store(self.params.data_store_params)
if self.params.data_store_params.store_type == "s3":
@@ -137,7 +138,8 @@ class Kubernetes(Deploy):
volumes=[k8sclient.V1Volume(
name="nfs-pvc",
persistent_volume_claim=self.nfs_pvc
)]
)],
restart_policy='OnFailure'
),
)
else:
@@ -155,32 +157,30 @@ class Kubernetes(Deploy):
template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
spec=k8sclient.V1PodSpec(
containers=[container]
containers=[container],
restart_policy='OnFailure'
),
)
deployment_spec = k8sclient.V1DeploymentSpec(
replicas=trainer_params.num_replicas,
template=template,
selector=k8sclient.V1LabelSelector(
match_labels={'app': name}
)
job_spec = k8sclient.V1JobSpec(
completions=1,
template=template
)
deployment = k8sclient.V1Deployment(
api_version='apps/v1',
kind='Deployment',
job = k8sclient.V1Job(
api_version="batch/v1",
kind="Job",
metadata=k8sclient.V1ObjectMeta(name=name),
spec=deployment_spec
spec=job_spec
)
api_client = k8sclient.AppsV1Api()
api_client = k8sclient.BatchV1Api()
try:
api_client.create_namespaced_deployment(self.params.namespace, deployment)
trainer_params.orchestration_params['deployment_name'] = name
api_client.create_namespaced_job(self.params.namespace, job)
trainer_params.orchestration_params['job_name'] = name
return True
except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating deployment", e)
print("Got exception: %s\n while creating job", e)
return False
def deploy_worker(self):
@@ -217,6 +217,7 @@ class Kubernetes(Deploy):
name="nfs-pvc",
persistent_volume_claim=self.nfs_pvc
)],
restart_policy='OnFailure'
),
)
else:
@@ -234,31 +235,31 @@ class Kubernetes(Deploy):
template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
spec=k8sclient.V1PodSpec(
containers=[container]
containers=[container],
restart_policy='OnFailure'
)
)
deployment_spec = k8sclient.V1DeploymentSpec(
replicas=worker_params.num_replicas,
template=template,
selector=k8sclient.V1LabelSelector(
match_labels={'app': name}
)
)
deployment = k8sclient.V1Deployment(
api_version='apps/v1',
kind="Deployment",
metadata=k8sclient.V1ObjectMeta(name=name),
spec=deployment_spec
job_spec = k8sclient.V1JobSpec(
completions=worker_params.num_replicas,
parallelism=worker_params.num_replicas,
template=template
)
api_client = k8sclient.AppsV1Api()
job = k8sclient.V1Job(
api_version="batch/v1",
kind="Job",
metadata=k8sclient.V1ObjectMeta(name=name),
spec=job_spec
)
api_client = k8sclient.BatchV1Api()
try:
api_client.create_namespaced_deployment(self.params.namespace, deployment)
worker_params.orchestration_params['deployment_name'] = name
api_client.create_namespaced_job(self.params.namespace, job)
worker_params.orchestration_params['job_name'] = name
return True
except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating deployment", e)
print("Got exception: %s\n while creating Job", e)
return False
def worker_logs(self):
@@ -273,7 +274,7 @@ class Kubernetes(Deploy):
pod = None
try:
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
trainer_params.orchestration_params['deployment_name']
trainer_params.orchestration_params['job_name']
))
pod = pods.items[0]
@@ -324,17 +325,20 @@ class Kubernetes(Deploy):
def undeploy(self):
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
api_client = k8sclient.AppsV1Api()
delete_options = k8sclient.V1DeleteOptions()
api_client = k8sclient.BatchV1Api()
delete_options = k8sclient.V1DeleteOptions(
propagation_policy="Foreground"
)
if trainer_params:
try:
api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options)
except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting trainer", e)
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
if worker_params:
try:
api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options)
except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting workers", e)
self.memory_backend.undeploy()
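Switching from Deployments to Jobs lets pods terminate for good once training reaches the target: a Deployment would recreate an exited pod indefinitely, whereas a Job counts it toward `completions`, and `restart_policy='OnFailure'` restarts only crashed pods. A minimal sketch of the new creation call for workers, assuming a loaded kube config and a prepared pod `template`:

```python
from kubernetes import client as k8sclient

def create_worker_job(namespace: str, name: str, template, num_replicas: int) -> None:
    # Mirrors the worker path above: one completion per replica, all
    # replicas in flight at once, pods restarted only on failure.
    job = k8sclient.V1Job(
        api_version="batch/v1",
        kind="Job",
        metadata=k8sclient.V1ObjectMeta(name=name),
        spec=k8sclient.V1JobSpec(
            completions=num_replicas,
            parallelism=num_replicas,
            template=template,
        ),
    )
    k8sclient.BatchV1Api().create_namespaced_job(namespace, job)
```

The `propagation_policy="Foreground"` added to `undeploy` matters for the same reason: deleting a Job without a cascading policy would leave its pods behind.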

View File

@@ -56,6 +56,10 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
# Environment #
###############
env_params = GymVectorEnvironment(level='CartPole-v0')
env_params.custom_reward_threshold = 200
# Set the target success
env_params.target_success_rate = 1.0
########
# Test #

View File

@@ -15,6 +15,7 @@ from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchroniza
from rl_coach.core_types import EnvironmentSteps, RunPhase
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles
def has_checkpoint(checkpoint_dir):
@@ -68,6 +69,10 @@ def get_latest_checkpoint(checkpoint_dir):
return int(rel_path.split('_Step')[0])
def should_stop(checkpoint_dir):
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
"""
wait for first checkpoint then perform rollouts using the model
@@ -87,12 +92,17 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
if should_stop(checkpoint_dir):
break
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
while new_checkpoint < last_checkpoint + 1:
if should_stop(checkpoint_dir):
break
if data_store:
data_store.load_from_store()
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
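Both the acting loop and the SYNC wait loop now poll for the marker file, so a worker blocked waiting on a fresh checkpoint can still exit; condensed, with step bookkeeping elided (all names are from this file):

```python
def rollout_loop(graph_manager, checkpoint_dir, data_store, act_steps):
    # Condensed sketch of the patched rollout_worker loop: the ".finished"
    # marker is consulted both before acting and while waiting (in SYNC
    # mode) for the trainer to publish a newer checkpoint.
    last_checkpoint = get_latest_checkpoint(checkpoint_dir)
    while not should_stop(checkpoint_dir):
        graph_manager.act(EnvironmentSteps(num_steps=act_steps))
        new_checkpoint = get_latest_checkpoint(checkpoint_dir)
        while new_checkpoint < last_checkpoint + 1:
            if should_stop(checkpoint_dir):
                break
            if data_store:
                data_store.load_from_store()
            new_checkpoint = get_latest_checkpoint(checkpoint_dir)
        last_checkpoint = new_checkpoint
```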

View File

@@ -40,8 +40,9 @@ def training_worker(graph_manager, checkpoint_dir):
graph_manager.phase = core_types.RunPhase.UNDEFINED
if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
graph_manager.evaluate(graph_manager.evaluation_steps)
eval_offset += 1
if graph_manager.evaluate(graph_manager.evaluation_steps):
break
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
graph_manager.save_checkpoint()