1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Merge branch 'master' into tf_version_bump

This commit is contained in:
Scott Leishman
2018-12-21 10:58:41 -05:00
committed by GitHub
11 changed files with 41 additions and 32 deletions

View File

@@ -175,8 +175,8 @@ jobs:
- run:
name: run gym related golden tests
command: |
export PRESETS='CartPole_A3C,CartPole_Dueling_DDQN,CartPole_NStepQ,CartPole_DQN,CartPole_DFP,CartPole_PG,CartPole_NEC,CartPole_ClippedPPO,CartPole_PAL'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-gym -tc "export PRESETS=${PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-gym_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
export GOLDEN_PRESETS='CartPole'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-gym -tc "export GOLDEN_PRESETS=${GOLDEN_PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-gym_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
no_output_timeout: 30m
- run:
name: cleanup
@@ -196,8 +196,8 @@ jobs:
- run:
name: run doom related golden tests
command: |
export PRESETS='Doom_Basic_DQN,Doom_Basic_A3C,Doom_Health_DFP'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-doom -tc "export PRESETS=${PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-doom_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
export GOLDEN_PRESETS='Doom'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-doom -tc "export GOLDEN_PRESETS=${GOLDEN_PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-doom_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
no_output_timeout: 30m
- run:
name: cleanup
@@ -217,8 +217,8 @@ jobs:
- run:
name: run mujoco related golden tests
command: |
export PRESETS='BitFlip_DQN_HER,BitFlip_DQN,Mujoco_A3C,Mujoco_A3C_LSTM,Mujoco_PPO,Mujoco_ClippedPPO,Mujoco_DDPG'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-mujoco -tc "export PRESETS=${PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-mujoco_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
export GOLDEN_PRESETS='BitFlip or Mujoco'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn golden-test-mujoco -tc "export GOLDEN_PRESETS=${GOLDEN_PRESETS} && make golden_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-mujoco_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
no_output_timeout: 30m
- run:
name: cleanup
@@ -238,8 +238,8 @@ jobs:
- run:
name: run gym related trace tests
command: |
export PRESETS='CartPole_A3C,CartPole_Dueling_DDQN,CartPole_NStepQ,CartPole_DQN,CartPole_DFP,CartPole_PG,CartPole_NEC,CartPole_ClippedPPO,CartPole_PAL'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-gym -tc "export PRESETS=${PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-gym_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
export TRACE_PRESETS='CartPole_A3C,CartPole_Dueling_DDQN,CartPole_NStepQ,CartPole_DQN,CartPole_DFP,CartPole_PG,CartPole_NEC,CartPole_ClippedPPO,CartPole_PAL'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-gym -tc "export TRACE_PRESETS=${TRACE_PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-gym_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
no_output_timeout: 30m
- run:
name: cleanup
@@ -259,8 +259,8 @@ jobs:
- run:
name: run doom related trace tests
command: |
export PRESETS='Doom_Basic_DQN,Doom_Basic_A3C,Doom_Health_DFP'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-doom -tc "export PRESETS=${PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-doom_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
export TRACE_PRESETS='Doom_Basic_DQN,Doom_Basic_A3C,Doom_Health_DFP'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-doom -tc "export TRACE_PRESETS=${TRACE_PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-doom_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
no_output_timeout: 30m
- run:
name: cleanup
@@ -280,8 +280,8 @@ jobs:
- run:
name: run mujoco related trace tests
command: |
export PRESETS='BitFlip_DQN_HER,BitFlip_DQN,Mujoco_A3C,Mujoco_A3C_LSTM,Mujoco_PPO,Mujoco_ClippedPPO,Mujoco_DDPG'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-mujoco -tc "export PRESETS=${PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-mujoco_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
export TRACE_PRESETS='BitFlip_DQN_HER,BitFlip_DQN,Mujoco_A3C,Mujoco_A3C_LSTM,Mujoco_PPO,Mujoco_ClippedPPO,Mujoco_DDPG'
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn trace-test-mujoco -tc "export TRACE_PRESETS=${TRACE_PRESETS} && make trace_tests_without_docker" -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach-mujoco_environment:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
no_output_timeout: 30m
- run:
name: cleanup

View File

@@ -762,7 +762,8 @@ class Agent(AgentInterface):
# informed action
if self.pre_network_filter is not None:
# before choosing an action, first use the pre_network_filter to filter out the current state
curr_state = self.run_pre_network_filter_for_inference(self.curr_state)
update_filter_internal_state = self.phase is not RunPhase.TEST
curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
else:
curr_state = self.curr_state
@@ -772,15 +773,18 @@ class Agent(AgentInterface):
return filtered_action_info
def run_pre_network_filter_for_inference(self, state: StateType) -> StateType:
def run_pre_network_filter_for_inference(self, state: StateType, update_filter_internal_state: bool=True)\
-> StateType:
"""
Run filters which where defined for being applied right before using the state for inference.
:param state: The state to run the filters on
:param update_filter_internal_state: Should update the filter's internal state - should not update when evaluating
:return: The filtered state
"""
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
return self.pre_network_filter.filter(dummy_env_response)[0].next_state
return self.pre_network_filter.filter(dummy_env_response,
update_internal_state=update_filter_internal_state)[0].next_state
def get_state_embedding(self, state: dict) -> np.ndarray:
"""

View File

@@ -325,7 +325,7 @@ class ClippedPPOAgent(ActorCriticAgent):
self.update_log()
return None
def run_pre_network_filter_for_inference(self, state: StateType):
def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state

View File

@@ -36,6 +36,7 @@ class DQNAlgorithmParameters(AlgorithmParameters):
self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000)
self.num_consecutive_playing_steps = EnvironmentSteps(4)
self.discount = 0.99
self.supports_parameter_noise = True
class DQNNetworkParameters(NetworkParameters):

View File

@@ -211,6 +211,9 @@ class AlgorithmParameters(Parameters):
# Should the workers wait for full episode
self.act_for_full_episodes = False
# Support for parameter noise
self.supports_parameter_noise = False
class PresetValidationParameters(Parameters):
def __init__(self,

View File

@@ -88,7 +88,7 @@ class AdditiveNoise(ExplorationPolicy):
action_values_mean = action_values.squeeze()
# step the noise schedule
if self.phase == RunPhase.TRAIN:
if self.phase is not RunPhase.TEST:
self.noise_percentage_schedule.step()
# the second element of the list is assumed to be the standard deviation
if isinstance(action_values, list) and len(action_values) > 1:

View File

@@ -18,7 +18,6 @@ from typing import List, Dict
import numpy as np
from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.architectures.layers import NoisyNetDense
from rl_coach.base_parameters import AgentParameters, NetworkParameters
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
@@ -30,7 +29,8 @@ from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy,
class ParameterNoiseParameters(ExplorationParameters):
def __init__(self, agent_params: AgentParameters):
super().__init__()
if not isinstance(agent_params, DQNAgentParameters):
if not agent_params.algorithm.supports_parameter_noise:
raise ValueError("Currently only DQN variants are supported for using an exploration type of "
"ParameterNoise.")

View File

@@ -92,7 +92,7 @@ class TruncatedNormal(ExplorationPolicy):
action_values_mean = action_values.squeeze()
# step the noise schedule
if self.phase == RunPhase.TRAIN:
if self.phase is not RunPhase.TEST:
self.noise_percentage_schedule.step()
# the second element of the list is assumed to be the standard deviation
if isinstance(action_values, list) and len(action_values) > 1:

View File

@@ -466,14 +466,14 @@ class InputFilter(Filter):
if self.name is not None:
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
for filter_name, filter in self._reward_filters.items():
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.save_state_to_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
for observation_name, filters_dict in self._observation_filters.items():
for filter_name, filter in filters_dict.items():
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
filter_name])
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
filter.save_state_to_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
"""
@@ -486,14 +486,14 @@ class InputFilter(Filter):
if self.name is not None:
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
for filter_name, filter in self._reward_filters.items():
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.restore_state_from_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
for observation_name, filters_dict in self._observation_filters.items():
for filter_name, filter in filters_dict.items():
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
filter_name])
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
filter.restore_state_from_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
class NoInputFilter(InputFilter):

View File

@@ -87,3 +87,4 @@ class ObservationNormalizationFilter(ObservationFilter):
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)

View File

@@ -109,13 +109,13 @@ class SharedRunningStats(ABC):
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
pass
def get_latest_checkpoint(self, checkpoint_dir: str) -> str:
def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
latest_checkpoint_id = -1
latest_checkpoint = ''
# get all checkpoint files
for fname in os.listdir(checkpoint_dir):
path = os.path.join(checkpoint_dir, fname)
if os.path.isdir(path) or fname.split('.')[-1] != 'srs':
if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
continue
checkpoint_id = int(fname.split('_')[0])
if checkpoint_id > latest_checkpoint_id:
@@ -189,7 +189,7 @@ class NumpySharedRunningStats(SharedRunningStats):
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir)
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
if latest_checkpoint_filename == '':
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")