diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index f8fcab6..5ddc48e 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -563,9 +563,12 @@ class Agent(AgentInterface): if step_method.__class__ == EnvironmentEpisodes: should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps + should_update = should_update and self.call_memory('length') > 0 elif step_method.__class__ == EnvironmentSteps: should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps + should_update = should_update and self.call_memory('num_transitions') > 0 + if wait_for_full_episode: should_update = should_update and self.current_episode_buffer.is_complete else: diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py index d690623..229a26e 100644 --- a/rl_coach/orchestrators/kubernetes_orchestrator.py +++ b/rl_coach/orchestrators/kubernetes_orchestrator.py @@ -3,6 +3,7 @@ import uuid import json import time from typing import List +from configparser import ConfigParser, Error from rl_coach.orchestrators.deploy import Deploy, DeployParameters from kubernetes import client as k8sclient, config as k8sconfig from rl_coach.memories.backend.memory import MemoryBackendParameters @@ -100,8 +101,8 @@ class Kubernetes(Deploy): if not trainer_params: return False - trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)] - trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)] + trainer_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)] + trainer_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)] name = "{}-{}".format(trainer_params.run_type, uuid.uuid4()) @@ -176,8 +177,8 @@ class Kubernetes(Deploy): if not worker_params: return False - worker_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)] - worker_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)] + worker_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)] + worker_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)] name = "{}-{}".format(worker_params.run_type, uuid.uuid4()) diff --git a/rl_coach/orchestrators/start_training.py b/rl_coach/orchestrators/start_training.py index dbf9df0..aea8237 100644 --- a/rl_coach/orchestrators/start_training.py +++ b/rl_coach/orchestrators/start_training.py @@ -77,19 +77,19 @@ if __name__ == '__main__': parser.add_argument('-ns', '--nfs-server', help="(string) Addresss of the nfs server.", type=str, - required=True) + required=False) parser.add_argument('-np', '--nfs-path', help="(string) Exported path for the nfs server.", type=str, - required=True) + required=False) parser.add_argument('--s3-end-point', help="(string) S3 endpoint to use when S3 data store is used.", type=str, - required=True) + required=False) parser.add_argument('--s3-bucket-name', help="(string) S3 bucket name to use when S3 data store is used.", type=str, - required=True) + required=False) parser.add_argument('--num-workers', help="(string) Number of rollout workers", type=int, diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index 8efa7c2..744c92e 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -124,22 +124,14 @@ def main(): help="(string) Name of a preset to run (class name from the 'presets' directory.)", type=str, required=True) - parser.add_argument('--checkpoint_dir', + parser.add_argument('--checkpoint-dir', help='(string) Path to a folder containing a checkpoint to restore the model from.', type=str, default='/checkpoint') - parser.add_argument('-r', '--redis_ip', - help="(string) IP or host for the redis server", - default='localhost', - type=str) - parser.add_argument('-rp', '--redis_port', - help="(int) Port of the redis server", - default=6379, - type=int) - parser.add_argument('--memory_backend_params', + parser.add_argument('--memory-backend-params', help="(string) JSON string of the memory backend params", type=str) - parser.add_argument('--data_store_params', + parser.add_argument('--data-store-params', help="(string) JSON string of the data store params", type=str) diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index 1c79726..ba15df1 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -48,22 +48,14 @@ def main(): help="(string) Name of a preset to run (class name from the 'presets' directory.)", type=str, required=True) - parser.add_argument('--checkpoint_dir', + parser.add_argument('--checkpoint-dir', help='(string) Path to a folder containing a checkpoint to write the model to.', type=str, default='/checkpoint') - parser.add_argument('-r', '--redis_ip', - help="(string) IP or host for the redis server", - default='localhost', - type=str) - parser.add_argument('-rp', '--redis_port', - help="(int) Port of the redis server", - default=6379, - type=int) - parser.add_argument('--memory_backend_params', + parser.add_argument('--memory-backend-params', help="(string) JSON string of the memory backend params", type=str) - parser.add_argument('--data_store_params', + parser.add_argument('--data-store-params', help="(string) JSON string of the data store params", type=str) args = parser.parse_args()