
RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)

* Currently this is specific to the case of discretizing a continuous action space. It can easily be adapted to other cases by feeding the kNN differently and removing the use of a discretizing output action filter
This commit is contained in:
Gal Leibovich
2019-09-08 12:53:49 +03:00
committed by GitHub
parent fc50398544
commit 138ced23ba
46 changed files with 1193 additions and 51 deletions
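
For orientation, a minimal sketch of the selection step the new agent implements below (not part of this commit; actor, critic and knn are illustrative callables): the actor proposes a continuous proto-action, a kNN index returns the k closest discrete-action embeddings, and the critic re-ranks those candidates.

    import numpy as np

    def wolpertinger_select(state, actor, critic, knn, k=5):
        # 1) continuous proto-action from the actor (exploration noise is assumed to be added by the caller)
        proto_action = actor(state)                 # shape: (action_embedding_width,)
        # 2) k nearest discrete-action embeddings and their discrete action indices
        embeddings, indices = knn(proto_action, k)  # shapes: (k, width), (k,)
        # 3) critic re-ranks the candidates; act with the highest-Q discrete action
        q_values = np.array([critic(state, e) for e in embeddings])
        return int(indices[np.argmax(q_values)])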


@@ -1003,7 +1003,7 @@ class Agent(AgentInterface):
"""
Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent
has another master agent that is controlling it. In such cases, the master agent can define the goals for the
slave agent, define it's observation, possible actions, etc. The directive type is defined by the agent
slave agent, define its observation, possible actions, etc. The directive type is defined by the agent
in-action-space.
:param action: The action that should be set as the directive


@@ -0,0 +1,131 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from typing import Union
from collections import OrderedDict
import numpy as np
from rl_coach.agents.ddpg_agent import DDPGAlgorithmParameters, DDPGActorNetworkParameters, \
DDPGCriticNetworkParameters, DDPGAgent
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionInfo
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.memories.non_episodic.differentiable_neural_dictionary import AnnoyDictionary
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
from rl_coach.architectures.head_parameters import WolpertingerActorHeadParameters
class WolpertingerCriticNetworkParameters(DDPGCriticNetworkParameters):
def __init__(self, use_batchnorm=False):
super().__init__(use_batchnorm=use_batchnorm)
class WolpertingerActorNetworkParameters(DDPGActorNetworkParameters):
def __init__(self, use_batchnorm=False):
super().__init__()
self.heads_parameters = [WolpertingerActorHeadParameters(batchnorm=use_batchnorm)]
class WolpertingerAlgorithmParameters(DDPGAlgorithmParameters):
def __init__(self):
super().__init__()
self.action_embedding_width = 1
self.k = 1
class WolpertingerAgentParameters(AgentParameters):
def __init__(self, use_batchnorm=False):
exploration_params = AdditiveNoiseParameters()
exploration_params.noise_as_percentage_from_action_space = False
super().__init__(algorithm=WolpertingerAlgorithmParameters(),
exploration=exploration_params,
memory=EpisodicExperienceReplayParameters(),
networks=OrderedDict(
[("actor", WolpertingerActorNetworkParameters(use_batchnorm=use_batchnorm)),
("critic", WolpertingerCriticNetworkParameters(use_batchnorm=use_batchnorm))]))
@property
def path(self):
return 'rl_coach.agents.wolpertinger_agent:WolpertingerAgent'
# Deep Reinforcement Learning in Large Discrete Action Spaces - https://arxiv.org/pdf/1512.07679.pdf
class WolpertingerAgent(DDPGAgent):
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent'] = None):
super().__init__(agent_parameters, parent)
def learn_from_batch(self, batch):
# The replay buffer holds the actions in discrete form, since the agent is expected to act with discrete
# actions through the BoxDiscretization output filter. DDPG, however, needs to work on continuous actions,
# so we convert back to continuous actions here. This duplicates work that is already done when applying
# actions to the environment; we might want to reuse that conversion, e.g. by storing it in the transition's
# info dictionary.
output_action_filter = \
list(self.output_filter.action_filters.values())[0]
continuous_actions = []
for action in batch.actions():
continuous_actions.append(output_action_filter.filter(action))
batch._actions = np.array(continuous_actions).squeeze()
return super().learn_from_batch(batch)
def train(self):
return super().train()
def choose_action(self, curr_state):
if not isinstance(self.spaces.action, DiscreteActionSpace):
raise ValueError("WolpertingerAgent works only for discrete control problems")
# convert to batch so we can run it through the network
tf_input_state = self.prepare_batch_for_inference(curr_state, 'actor')
actor_network = self.networks['actor'].online_network
critic_network = self.networks['critic'].online_network
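# the actor proposes a continuous proto-action; exploration noise is added to it, and the kNN index
# returns the k nearest discrete-action embeddings together with their discrete action indices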
proto_action = actor_network.predict(tf_input_state)
proto_action = np.expand_dims(self.exploration_policy.get_action(proto_action), 0)
nn_action_embeddings, indices, _, _ = self.knn_tree.query(keys=proto_action, k=self.ap.algorithm.k)
# run the k candidate actions through the critic and choose the one with the highest Q value
critic_inputs = copy.copy(tf_input_state)
critic_inputs['observation'] = np.tile(critic_inputs['observation'], (self.ap.algorithm.k, 1))
critic_inputs['action'] = nn_action_embeddings[0]
q_values = critic_network.predict(critic_inputs)[0]
action = int(indices[0][np.argmax(q_values)])
self.action_signal.add_sample(action)
return ActionInfo(action=action, action_value=0)
def init_environment_dependent_modules(self):
super().init_environment_dependent_modules()
self.knn_tree = self.get_initialized_knn()
# TODO - ideally the knn should not be defined here, but somehow be defined by the user in the preset
def get_initialized_knn(self):
num_actions = len(self.spaces.action.actions)
action_max_abs_range = self.spaces.action.filtered_action_space.max_abs_range if \
(hasattr(self.spaces.action, 'filtered_action_space') and
isinstance(self.spaces.action.filtered_action_space, BoxActionSpace)) \
else 1.0
keys = np.expand_dims((np.arange(num_actions) / (num_actions - 1) - 0.5) * 2, 1) * action_max_abs_range
values = np.expand_dims(np.arange(num_actions), 1)
knn_tree = AnnoyDictionary(dict_size=num_actions, key_width=self.ap.algorithm.action_embedding_width)
knn_tree.add(keys, values, force_rebuild_tree=True)
return knn_tree
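
To make the key construction above concrete, the discrete action indices are mapped to evenly spaced points in [-action_max_abs_range, action_max_abs_range]. A quick check with illustrative values:

    import numpy as np
    num_actions, action_max_abs_range = 5, 1.0
    keys = np.expand_dims((np.arange(num_actions) / (num_actions - 1) - 0.5) * 2, 1) * action_max_abs_range
    print(keys.ravel())  # [-1.  -0.5  0.   0.5  1. ]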


@@ -108,6 +108,17 @@ class DDPGActorHeadParameters(HeadParameters):
self.batchnorm = batchnorm
class WolpertingerActorHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', batchnorm: bool=True,
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
loss_weight: float = 1.0, dense_layer=None):
super().__init__(parameterized_class_name="WolpertingerActorHead", activation_function=activation_function, name=name,
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
loss_weight=loss_weight)
self.batchnorm = batchnorm
class DNDQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params',
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,


@@ -18,6 +18,7 @@ from .classification_head import ClassificationHead
from .cil_head import RegressionHead
from .td3_v_head import TD3VHead
from .ddpg_v_head import DDPGVHead
from .wolpertinger_actor_head import WolpertingerActorHead
__all__ = [
'CategoricalQHead',
@@ -38,6 +39,7 @@ __all__ = [
'SACQHead',
'ClassificationHead',
'RegressionHead',
'TD3VHead'
'DDPGVHead'
'TD3VHead',
'DDPGVHead',
'WolpertingerActorHead'
]


@@ -0,0 +1,59 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import Embedding
from rl_coach.spaces import SpacesDefinition, BoxActionSpace
class WolpertingerActorHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
batchnorm: bool=True, dense_layer=Dense, is_training=False):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer, is_training=is_training)
self.name = 'wolpertinger_actor_head'
self.return_type = Embedding
self.action_embedding_width = agent_parameters.algorithm.action_embedding_width
self.batchnorm = batchnorm
self.output_scale = self.spaces.action.filtered_action_space.max_abs_range if \
(hasattr(self.spaces.action, 'filtered_action_space') and
isinstance(self.spaces.action.filtered_action_space, BoxActionSpace)) \
else None
def _build_module(self, input_layer):
# proto-action embedding
pre_activation_policy_value = self.dense_layer(self.action_embedding_width)(input_layer,
name='actor_action_embedding')
self.proto_action = batchnorm_activation_dropout(input_layer=pre_activation_policy_value,
batchnorm=self.batchnorm,
activation_function=self.activation_function,
dropout_rate=0,
is_training=self.is_training,
name="BatchnormActivationDropout_0")[-1]
if self.output_scale is not None:
self.proto_action = tf.multiply(self.proto_action, self.output_scale, name='proto_action')
self.output = [self.proto_action]
def __str__(self):
result = [
'Dense (num outputs = {})'.format(self.action_embedding_width)
]
return '\n'.join(result)
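
Conceptually, the head above computes proto_action = output_scale * tanh(batchnorm(dense(x))), with the tanh keeping the embedding inside the action range before rescaling. A rough numpy sketch of the inference-time computation (an illustration only; W and b are hypothetical weights and the learned batch-norm statistics are ignored):

    import numpy as np

    def proto_action(middleware_output, W, b, output_scale=1.0):
        # dense layer down to action_embedding_width, tanh activation, rescale to the action range
        return output_scale * np.tanh(middleware_output @ W + b)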


@@ -62,7 +62,9 @@ class AdditiveNoise(ContinuousActionExplorationPolicy):
self.evaluation_noise = evaluation_noise
self.noise_as_percentage_from_action_space = noise_as_percentage_from_action_space
if not isinstance(action_space, BoxActionSpace):
if not isinstance(action_space, BoxActionSpace) and \
(hasattr(action_space, 'filtered_action_space') and not
isinstance(action_space.filtered_action_space, BoxActionSpace)):
raise ValueError("Additive noise exploration works only for continuous controls. "
"The given action space is of type: {}".format(action_space.__class__.__name__))


@@ -115,5 +115,8 @@ class ContinuousActionExplorationPolicy(ExplorationPolicy):
"""
:param action_space: the action space used by the environment
"""
assert isinstance(action_space, BoxActionSpace) or isinstance(action_space, GoalsSpace)
assert isinstance(action_space, BoxActionSpace) or \
(hasattr(action_space, 'filtered_action_space') and
isinstance(action_space.filtered_action_space, BoxActionSpace)) or \
isinstance(action_space, GoalsSpace)
super().__init__(action_space)


@@ -48,7 +48,8 @@ class PartialDiscreteActionSpaceMap(ActionFilter):
def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> DiscreteActionSpace:
self.output_action_space = output_action_space
self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions)
self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions,
filtered_action_space=output_action_space)
return self.input_action_space
def filter(self, action: ActionType) -> ActionType:


@@ -57,7 +57,7 @@ class AnnoyDictionary(object):
self.built_capacity = 0
def add(self, keys, values, additional_data=None):
def add(self, keys, values, additional_data=None, force_rebuild_tree=False):
if not additional_data:
additional_data = [None] * len(keys)
@@ -96,7 +96,7 @@ class AnnoyDictionary(object):
if len(self.buffered_indices) >= self.min_update_size:
self.min_update_size = max(self.initial_update_size, int(self.curr_size * 0.02))
self._rebuild_index()
elif self.rebuild_on_every_update:
elif force_rebuild_tree or self.rebuild_on_every_update:
self._rebuild_index()
self.current_timestamp += 1
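
The new force_rebuild_tree flag lets a caller build the index right after a one-off bulk insert (as the Wolpertinger agent does above) instead of waiting for min_update_size buffered items. A hedged usage sketch, with illustrative sizes:

    import numpy as np
    from rl_coach.memories.non_episodic.differentiable_neural_dictionary import AnnoyDictionary

    knn = AnnoyDictionary(dict_size=10, key_width=1)
    keys = np.expand_dims(np.linspace(-1.0, 1.0, 10), 1)
    values = np.expand_dims(np.arange(10), 1)
    knn.add(keys, values, force_rebuild_tree=True)  # index is usable immediately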


@@ -0,0 +1,57 @@
from collections import OrderedDict
from rl_coach.architectures.layers import Dense
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, EmbedderScheme
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
from rl_coach.filters.action import BoxDiscretization
from rl_coach.filters.filter import OutputFilter
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.agents.wolpertinger_agent import WolpertingerAgentParameters
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = EnvironmentSteps(2000000)
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(20)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(3000)
#########
# Agent #
#########
agent_params = WolpertingerAgentParameters()
agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense(400)]
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(300)]
agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense(400)]
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(300)]
agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty
agent_params.output_filter = \
OutputFilter(
action_filters=OrderedDict([
('discretization', BoxDiscretization(num_bins_per_dimension=int(1e6)))
]),
is_a_reference_filter=False
)
###############
# Environment #
###############
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 500
preset_validation_params.max_episodes_to_achieve_reward = 1000
preset_validation_params.reward_test_level = 'inverted_pendulum'
preset_validation_params.trace_test_levels = ['inverted_pendulum']
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=VisualizationParameters(),
preset_validation_params=preset_validation_params)
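
Assuming this preset is saved under rl_coach/presets (the file name is not shown in this view), it can be launched with the usual coach command line, selecting the MuJoCo level explicitly, e.g. coach -p <preset_name> -lvl inverted_pendulum.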


@@ -385,7 +385,8 @@ class DiscreteActionSpace(ActionSpace):
"""
A discrete action space with action indices as actions
"""
def __init__(self, num_actions: int, descriptions: Union[None, List, Dict]=None, default_action: np.ndarray=None):
def __init__(self, num_actions: int, descriptions: Union[None, List, Dict]=None, default_action: np.ndarray=None,
filtered_action_space=None):
super().__init__(1, low=0, high=num_actions-1, descriptions=descriptions)
# the number of actions is mapped to high
@@ -395,6 +396,9 @@ class DiscreteActionSpace(ActionSpace):
else:
self.default_action = default_action
if filtered_action_space is not None:
self.filtered_action_space = filtered_action_space
@property
def actions(self) -> List[ActionType]:
return list(range(0, int(self.high[0]) + 1))
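
A small usage sketch of the new filtered_action_space argument (illustrative; the BoxActionSpace arguments are assumed to follow its shape/low/high constructor):

    from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace

    continuous = BoxActionSpace(1, low=-1.0, high=1.0)
    discrete = DiscreteActionSpace(num_actions=11, filtered_action_space=continuous)
    # the discrete space now remembers the continuous space it discretizes, which is what the
    # exploration-policy checks above rely on to accept it for continuous-control exploration
    assert isinstance(discrete.filtered_action_space, BoxActionSpace)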