mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
Merge branch 'master' into tf_version_bump
This commit is contained in:
@@ -175,8 +175,8 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: run gym related golden tests
|
name: run gym related golden tests
|
||||||
command: |
|
command: |
|
||||||
export PRESETS='CartPole_A3C,CartPole_Dueling_DDQN,CartPole_NStepQ,CartPole_DQN,CartPole_DFP,CartPole_PG,CartPole_NEC,CartPole_ClippedPPO,CartPole_PAL'
|
export GOLDEN_PRESETS='CartPole'
|
||||||
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-gym -tc "export PRESETS=${PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-gym_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-gym -tc "export GOLDEN_PRESETS=${GOLDEN_PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-gym_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
||||||
no_output_timeout: 30m
|
no_output_timeout: 30m
|
||||||
- run:
|
- run:
|
||||||
name: cleanup
|
name: cleanup
|
||||||
@@ -196,8 +196,8 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: run doom related golden tests
|
name: run doom related golden tests
|
||||||
command: |
|
command: |
|
||||||
export PRESETS='Doom_Basic_DQN,Doom_Basic_A3C,Doom_Health_DFP'
|
export GOLDEN_PRESETS='Doom'
|
||||||
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-doom -tc "export PRESETS=${PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-doom_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-doom -tc "export GOLDEN_PRESETS=${GOLDEN_PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-doom_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
||||||
no_output_timeout: 30m
|
no_output_timeout: 30m
|
||||||
- run:
|
- run:
|
||||||
name: cleanup
|
name: cleanup
|
||||||
@@ -217,8 +217,8 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: run mujoco related golden tests
|
name: run mujoco related golden tests
|
||||||
command: |
|
command: |
|
||||||
export PRESETS='BitFlip_DQN_HER,BitFlip_DQN,Mujoco_A3C,Mujoco_A3C_LSTM,Mujoco_PPO,Mujoco_ClippedPPO,Mujoco_DDPG'
|
export GOLDEN_PRESETS='BitFlip or Mujoco'
|
||||||
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-mujoco -tc "export PRESETS=${PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-mujoco_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-mujoco -tc "export GOLDEN_PRESETS=${GOLDEN_PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-mujoco_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
||||||
no_output_timeout: 30m
|
no_output_timeout: 30m
|
||||||
- run:
|
- run:
|
||||||
name: cleanup
|
name: cleanup
|
||||||
@@ -238,8 +238,8 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: run gym related trace tests
|
name: run gym related trace tests
|
||||||
command: |
|
command: |
|
||||||
export PRESETS='CartPole_A3C,CartPole_Dueling_DDQN,CartPole_NStepQ,CartPole_DQN,CartPole_DFP,CartPole_PG,CartPole_NEC,CartPole_ClippedPPO,CartPole_PAL'
|
export TRACE_PRESETS='CartPole_A3C,CartPole_Dueling_DDQN,CartPole_NStepQ,CartPole_DQN,CartPole_DFP,CartPole_PG,CartPole_NEC,CartPole_ClippedPPO,CartPole_PAL'
|
||||||
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-gym -tc "export PRESETS=${PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-gym_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-gym -tc "export TRACE_PRESETS=${TRACE_PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-gym_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
||||||
no_output_timeout: 30m
|
no_output_timeout: 30m
|
||||||
- run:
|
- run:
|
||||||
name: cleanup
|
name: cleanup
|
||||||
@@ -259,8 +259,8 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: run doom related trace tests
|
name: run doom related trace tests
|
||||||
command: |
|
command: |
|
||||||
export PRESETS='Doom_Basic_DQN,Doom_Basic_A3C,Doom_Health_DFP'
|
export TRACE_PRESETS='Doom_Basic_DQN,Doom_Basic_A3C,Doom_Health_DFP'
|
||||||
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-doom -tc "export PRESETS=${PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-doom_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-doom -tc "export TRACE_PRESETS=${TRACE_PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-doom_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
||||||
no_output_timeout: 30m
|
no_output_timeout: 30m
|
||||||
- run:
|
- run:
|
||||||
name: cleanup
|
name: cleanup
|
||||||
@@ -280,8 +280,8 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: run mujoco related trace tests
|
name: run mujoco related trace tests
|
||||||
command: |
|
command: |
|
||||||
export PRESETS='BitFlip_DQN_HER,BitFlip_DQN,Mujoco_A3C,Mujoco_A3C_LSTM,Mujoco_PPO,Mujoco_ClippedPPO,Mujoco_DDPG'
|
export TRACE_PRESETS='BitFlip_DQN_HER,BitFlip_DQN,Mujoco_A3C,Mujoco_A3C_LSTM,Mujoco_PPO,Mujoco_ClippedPPO,Mujoco_DDPG'
|
||||||
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-mujoco -tc "export PRESETS=${PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-mujoco_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-mujoco -tc "export TRACE_PRESETS=${TRACE_PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-mujoco_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
||||||
no_output_timeout: 30m
|
no_output_timeout: 30m
|
||||||
- run:
|
- run:
|
||||||
name: cleanup
|
name: cleanup
|
||||||
|
|||||||
@@ -762,7 +762,8 @@ class Agent(AgentInterface):
|
|||||||
# informed action
|
# informed action
|
||||||
if self.pre_network_filter is not None:
|
if self.pre_network_filter is not None:
|
||||||
# before choosing an action, first use the pre_network_filter to filter out the current state
|
# before choosing an action, first use the pre_network_filter to filter out the current state
|
||||||
curr_state = self.run_pre_network_filter_for_inference(self.curr_state)
|
update_filter_internal_state = self.phase is not RunPhase.TEST
|
||||||
|
curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
curr_state = self.curr_state
|
curr_state = self.curr_state
|
||||||
@@ -772,15 +773,18 @@ class Agent(AgentInterface):
|
|||||||
|
|
||||||
return filtered_action_info
|
return filtered_action_info
|
||||||
|
|
||||||
def run_pre_network_filter_for_inference(self, state: StateType) -> StateType:
|
def run_pre_network_filter_for_inference(self, state: StateType, update_filter_internal_state: bool=True)\
|
||||||
|
-> StateType:
|
||||||
"""
|
"""
|
||||||
Run filters which where defined for being applied right before using the state for inference.
|
Run filters which where defined for being applied right before using the state for inference.
|
||||||
|
|
||||||
:param state: The state to run the filters on
|
:param state: The state to run the filters on
|
||||||
|
:param update_filter_internal_state: Should update the filter's internal state - should not update when evaluating
|
||||||
:return: The filtered state
|
:return: The filtered state
|
||||||
"""
|
"""
|
||||||
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
||||||
return self.pre_network_filter.filter(dummy_env_response)[0].next_state
|
return self.pre_network_filter.filter(dummy_env_response,
|
||||||
|
update_internal_state=update_filter_internal_state)[0].next_state
|
||||||
|
|
||||||
def get_state_embedding(self, state: dict) -> np.ndarray:
|
def get_state_embedding(self, state: dict) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -325,7 +325,7 @@ class ClippedPPOAgent(ActorCriticAgent):
|
|||||||
self.update_log()
|
self.update_log()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def run_pre_network_filter_for_inference(self, state: StateType):
|
def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
|
||||||
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
||||||
return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
|
return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ class DQNAlgorithmParameters(AlgorithmParameters):
|
|||||||
self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000)
|
self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000)
|
||||||
self.num_consecutive_playing_steps = EnvironmentSteps(4)
|
self.num_consecutive_playing_steps = EnvironmentSteps(4)
|
||||||
self.discount = 0.99
|
self.discount = 0.99
|
||||||
|
self.supports_parameter_noise = True
|
||||||
|
|
||||||
|
|
||||||
class DQNNetworkParameters(NetworkParameters):
|
class DQNNetworkParameters(NetworkParameters):
|
||||||
|
|||||||
@@ -211,6 +211,9 @@ class AlgorithmParameters(Parameters):
|
|||||||
# Should the workers wait for full episode
|
# Should the workers wait for full episode
|
||||||
self.act_for_full_episodes = False
|
self.act_for_full_episodes = False
|
||||||
|
|
||||||
|
# Support for parameter noise
|
||||||
|
self.supports_parameter_noise = False
|
||||||
|
|
||||||
|
|
||||||
class PresetValidationParameters(Parameters):
|
class PresetValidationParameters(Parameters):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class AdditiveNoise(ExplorationPolicy):
|
|||||||
action_values_mean = action_values.squeeze()
|
action_values_mean = action_values.squeeze()
|
||||||
|
|
||||||
# step the noise schedule
|
# step the noise schedule
|
||||||
if self.phase == RunPhase.TRAIN:
|
if self.phase is not RunPhase.TEST:
|
||||||
self.noise_percentage_schedule.step()
|
self.noise_percentage_schedule.step()
|
||||||
# the second element of the list is assumed to be the standard deviation
|
# the second element of the list is assumed to be the standard deviation
|
||||||
if isinstance(action_values, list) and len(action_values) > 1:
|
if isinstance(action_values, list) and len(action_values) > 1:
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from typing import List, Dict
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
|
||||||
from rl_coach.architectures.layers import NoisyNetDense
|
from rl_coach.architectures.layers import NoisyNetDense
|
||||||
from rl_coach.base_parameters import AgentParameters, NetworkParameters
|
from rl_coach.base_parameters import AgentParameters, NetworkParameters
|
||||||
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
|
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
|
||||||
@@ -30,7 +29,8 @@ from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy,
|
|||||||
class ParameterNoiseParameters(ExplorationParameters):
|
class ParameterNoiseParameters(ExplorationParameters):
|
||||||
def __init__(self, agent_params: AgentParameters):
|
def __init__(self, agent_params: AgentParameters):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not isinstance(agent_params, DQNAgentParameters):
|
|
||||||
|
if not agent_params.algorithm.supports_parameter_noise:
|
||||||
raise ValueError("Currently only DQN variants are supported for using an exploration type of "
|
raise ValueError("Currently only DQN variants are supported for using an exploration type of "
|
||||||
"ParameterNoise.")
|
"ParameterNoise.")
|
||||||
|
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class TruncatedNormal(ExplorationPolicy):
|
|||||||
action_values_mean = action_values.squeeze()
|
action_values_mean = action_values.squeeze()
|
||||||
|
|
||||||
# step the noise schedule
|
# step the noise schedule
|
||||||
if self.phase == RunPhase.TRAIN:
|
if self.phase is not RunPhase.TEST:
|
||||||
self.noise_percentage_schedule.step()
|
self.noise_percentage_schedule.step()
|
||||||
# the second element of the list is assumed to be the standard deviation
|
# the second element of the list is assumed to be the standard deviation
|
||||||
if isinstance(action_values, list) and len(action_values) > 1:
|
if isinstance(action_values, list) and len(action_values) > 1:
|
||||||
|
|||||||
@@ -466,14 +466,14 @@ class InputFilter(Filter):
|
|||||||
if self.name is not None:
|
if self.name is not None:
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
||||||
for filter_name, filter in self._reward_filters.items():
|
for filter_name, filter in self._reward_filters.items():
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||||
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
filter.save_state_to_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
|
||||||
|
|
||||||
for observation_name, filters_dict in self._observation_filters.items():
|
for observation_name, filters_dict in self._observation_filters.items():
|
||||||
for filter_name, filter in filters_dict.items():
|
for filter_name, filter in filters_dict.items():
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||||
filter_name])
|
filter_name])
|
||||||
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
filter.save_state_to_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
|
||||||
|
|
||||||
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
|
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
|
||||||
"""
|
"""
|
||||||
@@ -486,14 +486,14 @@ class InputFilter(Filter):
|
|||||||
if self.name is not None:
|
if self.name is not None:
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
||||||
for filter_name, filter in self._reward_filters.items():
|
for filter_name, filter in self._reward_filters.items():
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||||
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
filter.restore_state_from_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
|
||||||
|
|
||||||
for observation_name, filters_dict in self._observation_filters.items():
|
for observation_name, filters_dict in self._observation_filters.items():
|
||||||
for filter_name, filter in filters_dict.items():
|
for filter_name, filter in filters_dict.items():
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||||
filter_name])
|
filter_name])
|
||||||
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
filter.restore_state_from_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
|
||||||
|
|
||||||
|
|
||||||
class NoInputFilter(InputFilter):
|
class NoInputFilter(InputFilter):
|
||||||
|
|||||||
@@ -87,3 +87,4 @@ class ObservationNormalizationFilter(ObservationFilter):
|
|||||||
|
|
||||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||||
self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||||
|
|
||||||
@@ -109,13 +109,13 @@ class SharedRunningStats(ABC):
|
|||||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_latest_checkpoint(self, checkpoint_dir: str) -> str:
|
def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
|
||||||
latest_checkpoint_id = -1
|
latest_checkpoint_id = -1
|
||||||
latest_checkpoint = ''
|
latest_checkpoint = ''
|
||||||
# get all checkpoint files
|
# get all checkpoint files
|
||||||
for fname in os.listdir(checkpoint_dir):
|
for fname in os.listdir(checkpoint_dir):
|
||||||
path = os.path.join(checkpoint_dir, fname)
|
path = os.path.join(checkpoint_dir, fname)
|
||||||
if os.path.isdir(path) or fname.split('.')[-1] != 'srs':
|
if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
|
||||||
continue
|
continue
|
||||||
checkpoint_id = int(fname.split('_')[0])
|
checkpoint_id = int(fname.split('_')[0])
|
||||||
if checkpoint_id > latest_checkpoint_id:
|
if checkpoint_id > latest_checkpoint_id:
|
||||||
@@ -189,7 +189,7 @@ class NumpySharedRunningStats(SharedRunningStats):
|
|||||||
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
|
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||||
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir)
|
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||||
|
|
||||||
if latest_checkpoint_filename == '':
|
if latest_checkpoint_filename == '':
|
||||||
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
|
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
|
||||||
|
|||||||
Reference in New Issue
Block a user