Mirror of https://github.com/gryf/coach.git
RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)
* Currently this is specific to the case of discretizing a continuous action space. It can easily be adapted to other cases by feeding the kNN differently and removing the discretizing output action filter.
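For orientation before the diff: a minimal, self-contained sketch of the Wolpertinger action-selection scheme the agent below implements (the actor proposes a continuous proto-action, the k nearest discrete action embeddings are looked up, and the critic re-ranks them). The helper names (wolpertinger_select, the toy actor/critic lambdas) are illustrative stand-ins, not rl_coach API.

import numpy as np

def wolpertinger_select(observation, actor, critic, action_embeddings, k=10):
    """Pick a discrete action: the actor proposes a proto-action, the k nearest
    discrete embeddings are re-ranked by the critic, and the argmax wins.
    `actor` and `critic` are assumed callables; `action_embeddings` maps index -> embedding."""
    proto_action = actor(observation)                              # continuous proposal
    dists = np.linalg.norm(action_embeddings - proto_action, axis=1)
    nearest = np.argsort(dists)[:k]                                # k nearest discrete actions
    q_values = np.array([critic(observation, action_embeddings[i]) for i in nearest])
    return int(nearest[np.argmax(q_values)])

# toy usage: 11 discrete actions embedded uniformly in [-1, 1]
embeddings = np.linspace(-1, 1, 11).reshape(-1, 1)
actor = lambda obs: np.array([0.23])
critic = lambda obs, a: -abs(a[0] - 0.3)          # prefers actions near 0.3
print(wolpertinger_select(None, actor, critic, embeddings, k=3))  # -> 6, i.e. the action embedded at 0.2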
@@ -1003,7 +1003,7 @@ class Agent(AgentInterface):
         """
         Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent
         has another master agent that is controlling it. In such cases, the master agent can define the goals for the
-        slave agent, define it's observation, possible actions, etc. The directive type is defined by the agent
+        slave agent, define its observation, possible actions, etc. The directive type is defined by the agent
         in-action-space.

         :param action: The action that should be set as the directive
131  rl_coach/agents/wolpertinger_agent.py  Normal file
@@ -0,0 +1,131 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from typing import Union
from collections import OrderedDict
import numpy as np

from rl_coach.agents.ddpg_agent import DDPGAlgorithmParameters, DDPGActorNetworkParameters, \
    DDPGCriticNetworkParameters, DDPGAgent
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionInfo
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.memories.non_episodic.differentiable_neural_dictionary import AnnoyDictionary
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
from rl_coach.architectures.head_parameters import WolpertingerActorHeadParameters


class WolpertingerCriticNetworkParameters(DDPGCriticNetworkParameters):
    def __init__(self, use_batchnorm=False):
        super().__init__(use_batchnorm=use_batchnorm)


class WolpertingerActorNetworkParameters(DDPGActorNetworkParameters):
    def __init__(self, use_batchnorm=False):
        super().__init__()
        self.heads_parameters = [WolpertingerActorHeadParameters(batchnorm=use_batchnorm)]


class WolpertingerAlgorithmParameters(DDPGAlgorithmParameters):
    def __init__(self):
        super().__init__()
        self.action_embedding_width = 1
        self.k = 1


class WolpertingerAgentParameters(AgentParameters):
    def __init__(self, use_batchnorm=False):
        exploration_params = AdditiveNoiseParameters()
        exploration_params.noise_as_percentage_from_action_space = False

        super().__init__(algorithm=WolpertingerAlgorithmParameters(),
                         exploration=exploration_params,
                         memory=EpisodicExperienceReplayParameters(),
                         networks=OrderedDict(
                             [("actor", WolpertingerActorNetworkParameters(use_batchnorm=use_batchnorm)),
                              ("critic", WolpertingerCriticNetworkParameters(use_batchnorm=use_batchnorm))]))

    @property
    def path(self):
        return 'rl_coach.agents.wolpertinger_agent:WolpertingerAgent'


# Deep Reinforcement Learning in Large Discrete Action Spaces - https://arxiv.org/pdf/1512.07679.pdf
class WolpertingerAgent(DDPGAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent'] = None):
        super().__init__(agent_parameters, parent)

    def learn_from_batch(self, batch):
        # The replay buffer holds discrete actions, since the agent is expected to act with discrete actions through
        # the BoxDiscretization output filter. DDPG, however, trains on continuous actions, so the actions are
        # converted back to continuous values here. This duplicates the filtering that is already done before actions
        # are applied to the environment, so that conversion could be reused, e.g. by storing it in the transition's
        # info dictionary.

        output_action_filter = \
            list(self.output_filter.action_filters.values())[0]
        continuous_actions = []
        for action in batch.actions():
            continuous_actions.append(output_action_filter.filter(action))
        batch._actions = np.array(continuous_actions).squeeze()

        return super().learn_from_batch(batch)

    def train(self):
        return super().train()

    def choose_action(self, curr_state):
        if not isinstance(self.spaces.action, DiscreteActionSpace):
            raise ValueError("WolpertingerAgent works only for discrete control problems")

        # convert to batch so we can run it through the network
        tf_input_state = self.prepare_batch_for_inference(curr_state, 'actor')
        actor_network = self.networks['actor'].online_network
        critic_network = self.networks['critic'].online_network
        proto_action = actor_network.predict(tf_input_state)
        proto_action = np.expand_dims(self.exploration_policy.get_action(proto_action), 0)

        nn_action_embeddings, indices, _, _ = self.knn_tree.query(keys=proto_action, k=self.ap.algorithm.k)

        # run the k nearest actions through the critic and choose the one with the highest Q value
        critic_inputs = copy.copy(tf_input_state)
        critic_inputs['observation'] = np.tile(critic_inputs['observation'], (self.ap.algorithm.k, 1))
        critic_inputs['action'] = nn_action_embeddings[0]
        q_values = critic_network.predict(critic_inputs)[0]
        action = int(indices[0][np.argmax(q_values)])
        self.action_signal.add_sample(action)
        return ActionInfo(action=action, action_value=0)

    def init_environment_dependent_modules(self):
        super().init_environment_dependent_modules()
        self.knn_tree = self.get_initialized_knn()

    # TODO - ideally the kNN should not be defined here, but instead be defined by the user in the preset
    def get_initialized_knn(self):
        num_actions = len(self.spaces.action.actions)
        action_max_abs_range = self.spaces.action.filtered_action_space.max_abs_range if \
            (hasattr(self.spaces.action, 'filtered_action_space') and
             isinstance(self.spaces.action.filtered_action_space, BoxActionSpace)) \
            else 1.0
        keys = np.expand_dims((np.arange(num_actions) / (num_actions - 1) - 0.5) * 2, 1) * action_max_abs_range
        values = np.expand_dims(np.arange(num_actions), 1)
        knn_tree = AnnoyDictionary(dict_size=num_actions, key_width=self.ap.algorithm.action_embedding_width)
        knn_tree.add(keys, values, force_rebuild_tree=True)

        return knn_tree
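As a side note on get_initialized_knn above: each discrete action index is embedded as a one-dimensional key spread uniformly over [-action_max_abs_range, action_max_abs_range]. A plain-numpy reproduction of just that key construction (the AnnoyDictionary itself is left out), with made-up numbers for illustration:

import numpy as np

num_actions = 5
action_max_abs_range = 2.0   # e.g. a Box action space with low=-2, high=2

# same formula as in get_initialized_knn: indices spread uniformly over [-range, range]
keys = np.expand_dims((np.arange(num_actions) / (num_actions - 1) - 0.5) * 2, 1) * action_max_abs_range
values = np.expand_dims(np.arange(num_actions), 1)

print(keys.ravel())    # [-2. -1.  0.  1.  2.]
print(values.ravel())  # [0 1 2 3 4]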
@@ -108,6 +108,17 @@ class DDPGActorHeadParameters(HeadParameters):
         self.batchnorm = batchnorm


+class WolpertingerActorHeadParameters(HeadParameters):
+    def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', batchnorm: bool=True,
+                 num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
+                 loss_weight: float = 1.0, dense_layer=None):
+        super().__init__(parameterized_class_name="WolpertingerActorHead", activation_function=activation_function, name=name,
+                         dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
+                         rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
+                         loss_weight=loss_weight)
+        self.batchnorm = batchnorm
+
+
 class DNDQHeadParameters(HeadParameters):
     def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params',
                  num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
@@ -18,6 +18,7 @@ from .classification_head import ClassificationHead
 from .cil_head import RegressionHead
 from .td3_v_head import TD3VHead
 from .ddpg_v_head import DDPGVHead
+from .wolpertinger_actor_head import WolpertingerActorHead

 __all__ = [
     'CategoricalQHead',
@@ -38,6 +39,7 @@ __all__ = [
     'SACQHead',
     'ClassificationHead',
     'RegressionHead',
-    'TD3VHead'
-    'DDPGVHead'
+    'TD3VHead',
+    'DDPGVHead',
+    'WolpertingerActorHead'
 ]

59  rl_coach/architectures/tensorflow_components/heads/wolpertinger_actor_head.py  Normal file
@@ -0,0 +1,59 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf

from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import Embedding
from rl_coach.spaces import SpacesDefinition, BoxActionSpace


class WolpertingerActorHead(Head):
    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
                 batchnorm: bool=True, dense_layer=Dense, is_training=False):
        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
                         dense_layer=dense_layer, is_training=is_training)
        self.name = 'wolpertinger_actor_head'
        self.return_type = Embedding
        self.action_embedding_width = agent_parameters.algorithm.action_embedding_width
        self.batchnorm = batchnorm
        self.output_scale = self.spaces.action.filtered_action_space.max_abs_range if \
            (hasattr(self.spaces.action, 'filtered_action_space') and
             isinstance(self.spaces.action.filtered_action_space, BoxActionSpace)) \
            else None

    def _build_module(self, input_layer):
        # proto-action embedding
        pre_activation_policy_value = self.dense_layer(self.action_embedding_width)(input_layer,
                                                                                    name='actor_action_embedding')
        self.proto_action = batchnorm_activation_dropout(input_layer=pre_activation_policy_value,
                                                         batchnorm=self.batchnorm,
                                                         activation_function=self.activation_function,
                                                         dropout_rate=0,
                                                         is_training=self.is_training,
                                                         name="BatchnormActivationDropout_0")[-1]
        if self.output_scale is not None:
            self.proto_action = tf.multiply(self.proto_action, self.output_scale, name='proto_action')

        self.output = [self.proto_action]

    def __str__(self):
        result = [
            'Dense (num outputs = {})'.format(self.action_embedding_width)
        ]
        return '\n'.join(result)
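To make the output scaling in _build_module concrete: the tanh-activated embedding lies in [-1, 1], and multiplying by max_abs_range stretches it onto the filtered Box action range. A plain-numpy stand-in for that last step (no TensorFlow graph involved; the numbers are illustrative only):

import numpy as np

tanh_out = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])   # what the tanh activation can emit
max_abs_range = np.array([2.0])                    # e.g. a filtered Box space with low=-2, high=2
proto_action = tanh_out * max_abs_range            # -> [-2. -1.  0.  1.  2.]
print(proto_action)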
@@ -62,7 +62,9 @@ class AdditiveNoise(ContinuousActionExplorationPolicy):
         self.evaluation_noise = evaluation_noise
         self.noise_as_percentage_from_action_space = noise_as_percentage_from_action_space

-        if not isinstance(action_space, BoxActionSpace):
+        if not isinstance(action_space, BoxActionSpace) and \
+                (hasattr(action_space, 'filtered_action_space') and not
+                 isinstance(action_space.filtered_action_space, BoxActionSpace)):
             raise ValueError("Additive noise exploration works only for continuous controls."
                              "The given action space is of type: {}".format(action_space.__class__.__name__))

@@ -115,5 +115,8 @@ class ContinuousActionExplorationPolicy(ExplorationPolicy):
         """
         :param action_space: the action space used by the environment
         """
-        assert isinstance(action_space, BoxActionSpace) or isinstance(action_space, GoalsSpace)
+        assert isinstance(action_space, BoxActionSpace) or \
+               (hasattr(action_space, 'filtered_action_space') and
+                isinstance(action_space.filtered_action_space, BoxActionSpace)) or \
+               isinstance(action_space, GoalsSpace)
         super().__init__(action_space)

@@ -48,7 +48,8 @@ class PartialDiscreteActionSpaceMap(ActionFilter):

     def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> DiscreteActionSpace:
         self.output_action_space = output_action_space
-        self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions)
+        self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions,
+                                                      filtered_action_space=output_action_space)
         return self.input_action_space

     def filter(self, action: ActionType) -> ActionType:
@@ -57,7 +57,7 @@ class AnnoyDictionary(object):

         self.built_capacity = 0

-    def add(self, keys, values, additional_data=None):
+    def add(self, keys, values, additional_data=None, force_rebuild_tree=False):
         if not additional_data:
             additional_data = [None] * len(keys)

@@ -96,7 +96,7 @@ class AnnoyDictionary(object):
         if len(self.buffered_indices) >= self.min_update_size:
             self.min_update_size = max(self.initial_update_size, int(self.curr_size * 0.02))
             self._rebuild_index()
-        elif self.rebuild_on_every_update:
+        elif force_rebuild_tree or self.rebuild_on_every_update:
             self._rebuild_index()

         self.current_timestamp += 1

57  rl_coach/presets/Mujoco_Wolpertinger.py  Normal file
@@ -0,0 +1,57 @@
from collections import OrderedDict

from rl_coach.architectures.layers import Dense
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, EmbedderScheme
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
from rl_coach.filters.action import BoxDiscretization
from rl_coach.filters.filter import OutputFilter
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.agents.wolpertinger_agent import WolpertingerAgentParameters

####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = EnvironmentSteps(2000000)
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(20)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(3000)

#########
# Agent #
#########
agent_params = WolpertingerAgentParameters()
agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense(400)]
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(300)]
agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense(400)]
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(300)]
agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty
agent_params.output_filter = \
    OutputFilter(
        action_filters=OrderedDict([
            ('discretization', BoxDiscretization(num_bins_per_dimension=int(1e6)))
        ]),
        is_a_reference_filter=False
    )

###############
# Environment #
###############
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))

########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 500
preset_validation_params.max_episodes_to_achieve_reward = 1000
preset_validation_params.reward_test_level = 'inverted_pendulum'
preset_validation_params.trace_test_levels = ['inverted_pendulum']

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params, vis_params=VisualizationParameters(),
                                    preset_validation_params=preset_validation_params)
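With this preset in place, a run would typically be launched through the standard rl_coach command line, e.g. coach -p Mujoco_Wolpertinger -lvl inverted_pendulum (treat the exact flags as an assumption based on the usual rl_coach invocation, not something tested here). Note also that WolpertingerAgentParameters keeps the default algorithm.k = 1, so only the single nearest discrete action is scored by the critic; a minimal sketch of raising it, using only attributes defined earlier in this diff:

from rl_coach.agents.wolpertinger_agent import WolpertingerAgentParameters

agent_params = WolpertingerAgentParameters()
# defaults from WolpertingerAlgorithmParameters: k=1, action_embedding_width=1
agent_params.algorithm.k = 10   # let the critic re-rank the 10 nearest discrete actions
print(agent_params.algorithm.k, agent_params.algorithm.action_embedding_width)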
@@ -385,7 +385,8 @@ class DiscreteActionSpace(ActionSpace):
     """
     A discrete action space with action indices as actions
     """
-    def __init__(self, num_actions: int, descriptions: Union[None, List, Dict]=None, default_action: np.ndarray=None):
+    def __init__(self, num_actions: int, descriptions: Union[None, List, Dict]=None, default_action: np.ndarray=None,
+                 filtered_action_space=None):
         super().__init__(1, low=0, high=num_actions-1, descriptions=descriptions)
         # the number of actions is mapped to high
@@ -395,6 +396,9 @@ class DiscreteActionSpace(ActionSpace):
         else:
             self.default_action = default_action

+        if filtered_action_space is not None:
+            self.filtered_action_space = filtered_action_space
+
     @property
     def actions(self) -> List[ActionType]:
         return list(range(0, int(self.high[0]) + 1))
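For illustration, a minimal sketch of how the new filtered_action_space argument is meant to be used: the discrete space handed to the agent keeps a reference to the Box space it discretizes, which the exploration policy and the kNN initialization above detect via hasattr(..., 'filtered_action_space'). The BoxActionSpace constructor arguments below follow the usual rl_coach signature and should be read as an assumption, not as part of this PR.

from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace

box = BoxActionSpace(shape=1, low=-2.0, high=2.0)            # underlying continuous space
disc = DiscreteActionSpace(num_actions=11, filtered_action_space=box)

print(hasattr(disc, 'filtered_action_space'))                # True
print(disc.filtered_action_space.max_abs_range)              # [2.]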