mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Adding parameteres, checking transitions before training
This commit is contained in:
committed by
zach dwiel
parent
0f46877d7e
commit
b285a02023
@@ -563,9 +563,12 @@ class Agent(AgentInterface):
|
|||||||
|
|
||||||
if step_method.__class__ == EnvironmentEpisodes:
|
if step_method.__class__ == EnvironmentEpisodes:
|
||||||
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
|
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
|
||||||
|
should_update = should_update and self.call_memory('length') > 0
|
||||||
|
|
||||||
elif step_method.__class__ == EnvironmentSteps:
|
elif step_method.__class__ == EnvironmentSteps:
|
||||||
should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
|
should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
|
||||||
|
should_update = should_update and self.call_memory('num_transitions') > 0
|
||||||
|
|
||||||
if wait_for_full_episode:
|
if wait_for_full_episode:
|
||||||
should_update = should_update and self.current_episode_buffer.is_complete
|
should_update = should_update and self.current_episode_buffer.is_complete
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import uuid
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from configparser import ConfigParser, Error
|
||||||
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
|
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
|
||||||
from kubernetes import client as k8sclient, config as k8sconfig
|
from kubernetes import client as k8sclient, config as k8sconfig
|
||||||
from rl_coach.memories.backend.memory import MemoryBackendParameters
|
from rl_coach.memories.backend.memory import MemoryBackendParameters
|
||||||
@@ -100,8 +101,8 @@ class Kubernetes(Deploy):
|
|||||||
if not trainer_params:
|
if not trainer_params:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
trainer_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
||||||
trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
|
trainer_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)]
|
||||||
|
|
||||||
name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
|
name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
|
||||||
|
|
||||||
@@ -176,8 +177,8 @@ class Kubernetes(Deploy):
|
|||||||
if not worker_params:
|
if not worker_params:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
worker_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
worker_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
||||||
worker_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
|
worker_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)]
|
||||||
|
|
||||||
name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
|
name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
|
||||||
|
|
||||||
|
|||||||
@@ -77,19 +77,19 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('-ns', '--nfs-server',
|
parser.add_argument('-ns', '--nfs-server',
|
||||||
help="(string) Addresss of the nfs server.",
|
help="(string) Addresss of the nfs server.",
|
||||||
type=str,
|
type=str,
|
||||||
required=True)
|
required=False)
|
||||||
parser.add_argument('-np', '--nfs-path',
|
parser.add_argument('-np', '--nfs-path',
|
||||||
help="(string) Exported path for the nfs server.",
|
help="(string) Exported path for the nfs server.",
|
||||||
type=str,
|
type=str,
|
||||||
required=True)
|
required=False)
|
||||||
parser.add_argument('--s3-end-point',
|
parser.add_argument('--s3-end-point',
|
||||||
help="(string) S3 endpoint to use when S3 data store is used.",
|
help="(string) S3 endpoint to use when S3 data store is used.",
|
||||||
type=str,
|
type=str,
|
||||||
required=True)
|
required=False)
|
||||||
parser.add_argument('--s3-bucket-name',
|
parser.add_argument('--s3-bucket-name',
|
||||||
help="(string) S3 bucket name to use when S3 data store is used.",
|
help="(string) S3 bucket name to use when S3 data store is used.",
|
||||||
type=str,
|
type=str,
|
||||||
required=True)
|
required=False)
|
||||||
parser.add_argument('--num-workers',
|
parser.add_argument('--num-workers',
|
||||||
help="(string) Number of rollout workers",
|
help="(string) Number of rollout workers",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -124,22 +124,14 @@ def main():
|
|||||||
help="(string) Name of a preset to run (class name from the 'presets' directory.)",
|
help="(string) Name of a preset to run (class name from the 'presets' directory.)",
|
||||||
type=str,
|
type=str,
|
||||||
required=True)
|
required=True)
|
||||||
parser.add_argument('--checkpoint_dir',
|
parser.add_argument('--checkpoint-dir',
|
||||||
help='(string) Path to a folder containing a checkpoint to restore the model from.',
|
help='(string) Path to a folder containing a checkpoint to restore the model from.',
|
||||||
type=str,
|
type=str,
|
||||||
default='/checkpoint')
|
default='/checkpoint')
|
||||||
parser.add_argument('-r', '--redis_ip',
|
parser.add_argument('--memory-backend-params',
|
||||||
help="(string) IP or host for the redis server",
|
|
||||||
default='localhost',
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('-rp', '--redis_port',
|
|
||||||
help="(int) Port of the redis server",
|
|
||||||
default=6379,
|
|
||||||
type=int)
|
|
||||||
parser.add_argument('--memory_backend_params',
|
|
||||||
help="(string) JSON string of the memory backend params",
|
help="(string) JSON string of the memory backend params",
|
||||||
type=str)
|
type=str)
|
||||||
parser.add_argument('--data_store_params',
|
parser.add_argument('--data-store-params',
|
||||||
help="(string) JSON string of the data store params",
|
help="(string) JSON string of the data store params",
|
||||||
type=str)
|
type=str)
|
||||||
|
|
||||||
|
|||||||
@@ -48,22 +48,14 @@ def main():
|
|||||||
help="(string) Name of a preset to run (class name from the 'presets' directory.)",
|
help="(string) Name of a preset to run (class name from the 'presets' directory.)",
|
||||||
type=str,
|
type=str,
|
||||||
required=True)
|
required=True)
|
||||||
parser.add_argument('--checkpoint_dir',
|
parser.add_argument('--checkpoint-dir',
|
||||||
help='(string) Path to a folder containing a checkpoint to write the model to.',
|
help='(string) Path to a folder containing a checkpoint to write the model to.',
|
||||||
type=str,
|
type=str,
|
||||||
default='/checkpoint')
|
default='/checkpoint')
|
||||||
parser.add_argument('-r', '--redis_ip',
|
parser.add_argument('--memory-backend-params',
|
||||||
help="(string) IP or host for the redis server",
|
|
||||||
default='localhost',
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('-rp', '--redis_port',
|
|
||||||
help="(int) Port of the redis server",
|
|
||||||
default=6379,
|
|
||||||
type=int)
|
|
||||||
parser.add_argument('--memory_backend_params',
|
|
||||||
help="(string) JSON string of the memory backend params",
|
help="(string) JSON string of the memory backend params",
|
||||||
type=str)
|
type=str)
|
||||||
parser.add_argument('--data_store_params',
|
parser.add_argument('--data-store-params',
|
||||||
help="(string) JSON string of the data store params",
|
help="(string) JSON string of the data store params",
|
||||||
type=str)
|
type=str)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
Reference in New Issue
Block a user