diff --git a/docs/_modules/rl_coach/agents/agent.html b/docs/_modules/rl_coach/agents/agent.html
index 49d4a8a..0f566aa 100644
--- a/docs/_modules/rl_coach/agents/agent.html
+++ b/docs/_modules/rl_coach/agents/agent.html
@@ -756,6 +756,9 @@
        if self.phase != RunPhase.TEST:
            if isinstance(self.memory, EpisodicExperienceReplay):
+                if self.ap.algorithm.override_episode_rewards_with_the_last_transition_reward:
+                    for t in self.current_episode_buffer.transitions:
+                        t.reward = self.current_episode_buffer.transitions[-1].reward
                self.call_memory('store_episode', self.current_episode_buffer)
            elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
                for transition in self.current_episode_buffer.transitions:
@@ -910,7 +913,8 @@
        # update counters
        self.training_iteration += 1

        if self.pre_network_filter is not None:
-            batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
+            update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+            batch = self.pre_network_filter.filter(batch, update_internal_state=update_internal_state, deep_copy=False)

        # if the batch returned empty then there are not enough samples in the replay buffer -> skip
        # training step
@@ -1020,7 +1024,8 @@
        # informed action
        if self.pre_network_filter is not None:
            # before choosing an action, first use the pre_network_filter to filter out the current state
-            update_filter_internal_state = self.phase is not RunPhase.TEST
+            update_filter_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference and \
+                                           self.phase is not RunPhase.TEST
            curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
        else:
@@ -1048,6 +1053,10 @@
        :return: The filtered state
        """
        dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
+
+        # TODO actually we only want to run the observation filters. No point in running the reward filters as the
+        # filtered reward is being ignored anyway (and it might unnecessarily affect the reward filters' internal
+        # state).
        return self.pre_network_filter.filter(dummy_env_response, update_internal_state=update_filter_internal_state)[0].next_state
@@ -1177,7 +1186,7 @@
""" Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent has another master agent that is controlling it. In such cases, the master agent can define the goals for the
- slave agent, define it's observation, possible actions, etc. The directive type is defined by the agent
+ slave agent, define its observation, possible actions, etc. The directive type is defined by the agent in-action-space. :param action: The action that should be set as the directive
diff --git a/docs/_modules/rl_coach/agents/clipped_ppo_agent.html b/docs/_modules/rl_coach/agents/clipped_ppo_agent.html
index 9a3f3b1..ba808c2 100644
--- a/docs/_modules/rl_coach/agents/clipped_ppo_agent.html
+++ b/docs/_modules/rl_coach/agents/clipped_ppo_agent.html
@@ -295,7 +295,9 @@
        self.optimization_epochs = 10
        self.normalization_stats = None
        self.clipping_decay_schedule = ConstantSchedule(1)
-        self.act_for_full_episodes = True
+        self.act_for_full_episodes = True
+        self.update_pre_network_filters_state_on_train = True
+        self.update_pre_network_filters_state_on_inference = False


class ClippedPPOAgentParameters(AgentParameters):
@@ -486,7 +488,9 @@
            network.set_is_training(True)

        dataset = self.memory.transitions
-        dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+        dataset = self.pre_network_filter.filter(dataset, deep_copy=False,
+                                                 update_internal_state=update_internal_state)
        batch = Batch(dataset)

        for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
@@ -512,7 +516,9 @@
    def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
        dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
-        return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference
+        return self.pre_network_filter.filter(
+            dummy_env_response, update_internal_state=update_internal_state)[0].next_state

    def choose_action(self, curr_state):
        self.ap.algorithm.clipping_decay_schedule.step()
diff --git a/docs/_modules/rl_coach/agents/wolpertinger_agent.html b/docs/_modules/rl_coach/agents/wolpertinger_agent.html
new file mode 100644
index 0000000..67dd6cd
--- /dev/null
+++ b/docs/_modules/rl_coach/agents/wolpertinger_agent.html
@@ -0,0 +1,356 @@
+
+rl_coach.agents.wolpertinger_agent — Reinforcement Learning Coach 0.12.0 documentation
+
+Source code for rl_coach.agents.wolpertinger_agent
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+from typing import Union
+from collections import OrderedDict
+import numpy as np
+
+from rl_coach.agents.ddpg_agent import DDPGAlgorithmParameters, DDPGActorNetworkParameters, \
+    DDPGCriticNetworkParameters, DDPGAgent
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.core_types import ActionInfo
+from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
+from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
+from rl_coach.memories.non_episodic.differentiable_neural_dictionary import AnnoyDictionary
+from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
+from rl_coach.architectures.head_parameters import WolpertingerActorHeadParameters
+
+
+class WolpertingerCriticNetworkParameters(DDPGCriticNetworkParameters):
+    def __init__(self, use_batchnorm=False):
+        super().__init__(use_batchnorm=use_batchnorm)
+
+
+class WolpertingerActorNetworkParameters(DDPGActorNetworkParameters):
+    def __init__(self, use_batchnorm=False):
+        super().__init__()
+        self.heads_parameters = [WolpertingerActorHeadParameters(batchnorm=use_batchnorm)]
+
+
+class WolpertingerAlgorithmParameters(DDPGAlgorithmParameters):
+    def __init__(self):
+        super().__init__()
+        self.action_embedding_width = 1
+        self.k = 1
+
+
+class WolpertingerAgentParameters(AgentParameters):
+    def __init__(self, use_batchnorm=False):
+        exploration_params = AdditiveNoiseParameters()
+        exploration_params.noise_as_percentage_from_action_space = False
+
+        super().__init__(algorithm=WolpertingerAlgorithmParameters(),
+                         exploration=exploration_params,
+                         memory=EpisodicExperienceReplayParameters(),
+                         networks=OrderedDict(
+                             [("actor", WolpertingerActorNetworkParameters(use_batchnorm=use_batchnorm)),
+                              ("critic", WolpertingerCriticNetworkParameters(use_batchnorm=use_batchnorm))]))
+
+    @property
+    def path(self):
+        return 'rl_coach.agents.wolpertinger_agent:WolpertingerAgent'
+
+
+# Deep Reinforcement Learning in Large Discrete Action Spaces - https://arxiv.org/pdf/1512.07679.pdf
+class WolpertingerAgent(DDPGAgent):
+    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
+        super().__init__(agent_parameters, parent)
+
+    def learn_from_batch(self, batch):
+        # The replay buffer holds the actions in their discrete form, as the agent is expected to act with discrete
+        # actions through the BoxDiscretization output filter. But DDPG needs to work on continuous actions, so we
+        # convert back to continuous actions here. This is actually a duplicate, since this filtering is also done
+        # before applying actions on the environment, so we might want to somehow reuse that conversion. Maybe this
+        # information can be held in the info dictionary of the transition.
+        output_action_filter = \
+            list(self.output_filter.action_filters.values())[0]
+        continuous_actions = []
+        for action in batch.actions():
+            continuous_actions.append(output_action_filter.filter(action))
+        batch._actions = np.array(continuous_actions).squeeze()
+
+        return super().learn_from_batch(batch)
+
+    def train(self):
+        return super().train()
+
+    def choose_action(self, curr_state):
+        if not isinstance(self.spaces.action, DiscreteActionSpace):
+            raise ValueError("WolpertingerAgent works only for discrete control problems")
+
+        # convert to batch so we can run it through the network
+        tf_input_state = self.prepare_batch_for_inference(curr_state, 'actor')
+        actor_network = self.networks['actor'].online_network
+        critic_network = self.networks['critic'].online_network
+        proto_action = actor_network.predict(tf_input_state)
+        proto_action = np.expand_dims(self.exploration_policy.get_action(proto_action), 0)
+
+        nn_action_embeddings, indices, _, _ = self.knn_tree.query(keys=proto_action, k=self.ap.algorithm.k)
+
+        # now move the actions through the critic and choose the one with the highest q value
+        critic_inputs = copy.copy(tf_input_state)
+        critic_inputs['observation'] = np.tile(critic_inputs['observation'], (self.ap.algorithm.k, 1))
+        critic_inputs['action'] = nn_action_embeddings[0]
+        q_values = critic_network.predict(critic_inputs)[0]
+        action = int(indices[0][np.argmax(q_values)])
+        self.action_signal.add_sample(action)
+        return ActionInfo(action=action, action_value=0)
+
+    def init_environment_dependent_modules(self):
+        super().init_environment_dependent_modules()
+        self.knn_tree = self.get_initialized_knn()
+
+    # TODO - ideally the knn should not be defined here, but somehow be defined by the user in the preset
+    def get_initialized_knn(self):
+        num_actions = len(self.spaces.action.actions)
+        action_max_abs_range = self.spaces.action.filtered_action_space.max_abs_range if \
+            (hasattr(self.spaces.action, 'filtered_action_space') and
+             isinstance(self.spaces.action.filtered_action_space, BoxActionSpace)) \
+            else 1.0
+        keys = np.expand_dims((np.arange(num_actions) / (num_actions - 1) - 0.5) * 2, 1) * action_max_abs_range
+        values = np.expand_dims(np.arange(num_actions), 1)
+        knn_tree = AnnoyDictionary(dict_size=num_actions, key_width=self.ap.algorithm.action_embedding_width)
+        knn_tree.add(keys, values, force_rebuild_tree=True)
+
+        return knn_tree
\ No newline at end of file
diff --git a/docs/_modules/rl_coach/base_parameters.html b/docs/_modules/rl_coach/base_parameters.html
index 60aac7f..045363d 100644
--- a/docs/_modules/rl_coach/base_parameters.html
+++ b/docs/_modules/rl_coach/base_parameters.html
@@ -396,6 +396,14 @@
        # Support for parameter noise
        self.supports_parameter_noise = False

+        # Retroactively override all the episode rewards with the last reward in the episode
+        # (sometimes useful for problems with sparse, end-of-episode rewards)
+        self.override_episode_rewards_with_the_last_transition_reward = False
+
+        # Filters - TODO consider creating a FilterParameters class and initialize the filters with it
+        self.update_pre_network_filters_state_on_train = False
+        self.update_pre_network_filters_state_on_inference = True
+
class NFSDataStore(CheckpointDataStore):
    """
    An implementation of data store which uses NFS for storing policy checkpoints when using Coach in distributed
    mode. The policy checkpoints are written by the trainer and read by the rollout worker.
diff --git a/docs/_modules/rl_coach/data_stores/s3_data_store.html b/docs/_modules/rl_coach/data_stores/s3_data_store.html
index 64fc0b9..c2e4b9b 100644
--- a/docs/_modules/rl_coach/data_stores/s3_data_store.html
+++ b/docs/_modules/rl_coach/data_stores/s3_data_store.html
@@ -198,7 +198,8 @@
#
-from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
+from rl_coach.data_stores.data_store import DataStoreParameters
+from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore
from minio import Minio
from minio.error import ResponseError
from configparser import ConfigParser, Error
@@ -222,7 +223,7 @@
        self.expt_dir = expt_dir

-
class S3DataStore(CheckpointDataStore):
    """
    An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.
    The policy checkpoints are written by the trainer and read by the rollout worker.
diff --git a/docs/_modules/rl_coach/exploration_policies/additive_noise.html b/docs/_modules/rl_coach/exploration_policies/additive_noise.html
index 44e2dc3..92eb352 100644
--- a/docs/_modules/rl_coach/exploration_policies/additive_noise.html
+++ b/docs/_modules/rl_coach/exploration_policies/additive_noise.html
@@ -245,7 +245,9 @@
        self.evaluation_noise = evaluation_noise
        self.noise_as_percentage_from_action_space = noise_as_percentage_from_action_space

-        if not isinstance(action_space, BoxActionSpace):
+        if not isinstance(action_space, BoxActionSpace) and \
+                (hasattr(action_space, 'filtered_action_space') and not
+                 isinstance(action_space.filtered_action_space, BoxActionSpace)):
            raise ValueError("Additive noise exploration works only for continuous controls."
                             "The given action space is of type: {}".format(action_space.__class__.__name__))
diff --git a/docs/_modules/rl_coach/exploration_policies/exploration_policy.html b/docs/_modules/rl_coach/exploration_policies/exploration_policy.html
index f5ffd7d..faa4583 100644
--- a/docs/_modules/rl_coach/exploration_policies/exploration_policy.html
+++ b/docs/_modules/rl_coach/exploration_policies/exploration_policy.html
@@ -298,7 +298,10 @@
""" :param action_space: the action space used by the environment """
- assertisinstance(action_space,BoxActionSpace)orisinstance(action_space,GoalsSpace)
+ assertisinstance(action_space,BoxActionSpace)or \
+ (hasattr(action_space,'filtered_action_space')and
+ isinstance(action_space.filtered_action_space,BoxActionSpace))or \
+ isinstance(action_space,GoalsSpace)super().__init__(action_space)
diff --git a/docs/_modules/rl_coach/exploration_policies/truncated_normal.html b/docs/_modules/rl_coach/exploration_policies/truncated_normal.html
index 7d33198..56d84e7 100644
--- a/docs/_modules/rl_coach/exploration_policies/truncated_normal.html
+++ b/docs/_modules/rl_coach/exploration_policies/truncated_normal.html
@@ -271,9 +271,6 @@
        else:
            action_values_std = current_noise

-        # scale the noise to the action space range
-        action_values_std = current_noise * (self.action_space.high - self.action_space.low)
-
        # extract the mean values
        if isinstance(action_values, list):
            # the action values are expected to be a list with the action mean and optionally the action stdev
diff --git a/docs/_modules/rl_coach/filters/action/partial_discrete_action_space_map.html b/docs/_modules/rl_coach/filters/action/partial_discrete_action_space_map.html
index 0fae9ce..4e4837b 100644
--- a/docs/_modules/rl_coach/filters/action/partial_discrete_action_space_map.html
+++ b/docs/_modules/rl_coach/filters/action/partial_discrete_action_space_map.html
@@ -231,7 +231,8 @@
    def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> DiscreteActionSpace:
        self.output_action_space = output_action_space
-        self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions)
+        self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions,
+                                                      filtered_action_space=output_action_space)
        return self.input_action_space

    def filter(self, action: ActionType) -> ActionType:
diff --git a/docs/_modules/rl_coach/memories/backend/redis.html b/docs/_modules/rl_coach/memories/backend/redis.html
index c904f4d..7567160 100644
--- a/docs/_modules/rl_coach/memories/backend/redis.html
+++ b/docs/_modules/rl_coach/memories/backend/redis.html
@@ -261,11 +261,18 @@
"""if'namespace'notinself.params.orchestrator_params:self.params.orchestrator_params['namespace']="default"
- fromkubernetesimportclient
+ fromkubernetesimportclient,configcontainer=client.V1Container(name=self.redis_server_name,image='redis:4-alpine',
+ resources=client.V1ResourceRequirements(
+ limits={
+ "cpu":"8",
+ "memory":"4Gi"
+ # "nvidia.com/gpu": "0",
+ }
+ ),)template=client.V1PodTemplateSpec(metadata=client.V1ObjectMeta(labels={'app':self.redis_server_name}),
@@ -288,8 +295,10 @@
spec=deployment_spec)
+ config.load_kube_config()api_client=client.AppsV1Api()try:
+ print(self.params.orchestrator_params)api_client.create_namespaced_deployment(self.params.orchestrator_params['namespace'],deployment)exceptclient.rest.ApiExceptionase:print("Got exception: %s\n while creating redis-server",e)
diff --git a/docs/_modules/rl_coach/memories/non_episodic/differentiable_neural_dictionary.html b/docs/_modules/rl_coach/memories/non_episodic/differentiable_neural_dictionary.html
index 2a5d1f9..063e0e4 100644
--- a/docs/_modules/rl_coach/memories/non_episodic/differentiable_neural_dictionary.html
+++ b/docs/_modules/rl_coach/memories/non_episodic/differentiable_neural_dictionary.html
@@ -240,7 +240,7 @@
        self.built_capacity = 0

-    def add(self, keys, values, additional_data=None):
+    def add(self, keys, values, additional_data=None, force_rebuild_tree=False):
        if not additional_data:
            additional_data = [None] * len(keys)
@@ -279,7 +279,7 @@
        if len(self.buffered_indices) >= self.min_update_size:
            self.min_update_size = max(self.initial_update_size, int(self.curr_size * 0.02))
            self._rebuild_index()
-        elif self.rebuild_on_every_update:
+        elif force_rebuild_tree or self.rebuild_on_every_update:
            self._rebuild_index()

        self.current_timestamp += 1
diff --git a/docs/_modules/rl_coach/orchestrators/kubernetes_orchestrator.html b/docs/_modules/rl_coach/orchestrators/kubernetes_orchestrator.html
index b8c99db..71d144f 100644
--- a/docs/_modules/rl_coach/orchestrators/kubernetes_orchestrator.html
+++ b/docs/_modules/rl_coach/orchestrators/kubernetes_orchestrator.html
@@ -307,6 +307,11 @@
"""self.memory_backend.deploy()
+
+ ifself.params.data_store_params.store_type=="redis":
+ self.data_store.params.redis_address=self.memory_backend.params.redis_address
+ self.data_store.params.redis_port=self.memory_backend.params.redis_port
+
ifnotself.data_store.deploy():returnFalseifself.params.data_store_params.store_type=="nfs":
@@ -329,6 +334,8 @@
        trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]

        name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())

+        # TODO: instead of defining each container and template spec from scratch, load the default
+        # configuration and modify it as necessary depending on the store type
        if self.params.data_store_params.store_type == "nfs":
            container = k8sclient.V1Container(
                name=name,
@@ -354,7 +361,7 @@
                    restart_policy='Never'),
            )
-        else:
+        elif self.params.data_store_params.store_type == "s3":
            container = k8sclient.V1Container(
                name=name,
                image=trainer_params.image,
@@ -373,6 +380,34 @@
                    restart_policy='Never'),
            )
+        elif self.params.data_store_params.store_type == "redis":
+            container = k8sclient.V1Container(
+ name=name,
+ image=trainer_params.image,
+ command=trainer_params.command,
+ args=trainer_params.arguments,
+ image_pull_policy='Always',
+ stdin=True,
+ tty=True,
+ resources=k8sclient.V1ResourceRequirements(
+ limits={
+ "cpu":"40",
+ "memory":"4Gi",
+ "nvidia.com/gpu":"1",
+ }
+ ),
+ )
+            template = k8sclient.V1PodTemplateSpec(
+                metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
+ spec=k8sclient.V1PodSpec(
+ containers=[container],
+ restart_policy='Never'
+ ),
+ )
+        else:
+            raise ValueError("unexpected store_type {}. expected 's3', 'nfs', 'redis'".format(
+                self.params.data_store_params.store_type
+            ))

        job_spec = k8sclient.V1JobSpec(
            completions=1,
@@ -404,12 +439,17 @@
        if not worker_params:
            return False

+        # At this point, the memory backend and data store have been deployed and, in the process,
+        # these parameters have been updated to include things like the hostname and port where the
+        # service can be found.
        worker_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
        worker_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
        worker_params.command += ['--num_workers', '{}'.format(worker_params.num_replicas)]

        name = "{}-{}".format(worker_params.run_type, uuid.uuid4())

+        # TODO: instead of defining each container and template spec from scratch, load the default
+        # configuration and modify it as necessary depending on the store type
        if self.params.data_store_params.store_type == "nfs":
            container = k8sclient.V1Container(
                name=name,
@@ -435,7 +475,7 @@
                    restart_policy='Never'),
            )
-        else:
+        elif self.params.data_store_params.store_type == "s3":
            container = k8sclient.V1Container(
                name=name,
                image=worker_params.image,
@@ -454,6 +494,32 @@
                    restart_policy='Never')
            )
+        elif self.params.data_store_params.store_type == "redis":
+            container = k8sclient.V1Container(
+ name=name,
+ image=worker_params.image,
+ command=worker_params.command,
+ args=worker_params.arguments,
+ image_pull_policy='Always',
+ stdin=True,
+ tty=True,
+ resources=k8sclient.V1ResourceRequirements(
+ limits={
+ "cpu":"8",
+ "memory":"4Gi",
+ # "nvidia.com/gpu": "0",
+ }
+ ),
+ )
+            template = k8sclient.V1PodTemplateSpec(
+                metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
+ spec=k8sclient.V1PodSpec(
+ containers=[container],
+ restart_policy='Never'
+ )
+ )
+ else:
+            raise ValueError('unexpected store type {}'.format(self.params.data_store_params.store_type))

        job_spec = k8sclient.V1JobSpec(
            completions=worker_params.num_replicas,
diff --git a/docs/_modules/rl_coach/spaces.html b/docs/_modules/rl_coach/spaces.html
index 2e890d9..30472c3 100644
--- a/docs/_modules/rl_coach/spaces.html
+++ b/docs/_modules/rl_coach/spaces.html
@@ -568,7 +568,8 @@
""" A discrete action space with action indices as actions """
- def__init__(self,num_actions:int,descriptions:Union[None,List,Dict]=None,default_action:np.ndarray=None):
+ def__init__(self,num_actions:int,descriptions:Union[None,List,Dict]=None,default_action:np.ndarray=None,
+ filtered_action_space=None):super().__init__(1,low=0,high=num_actions-1,descriptions=descriptions)# the number of actions is mapped to high
@@ -578,6 +579,9 @@
else:self.default_action=default_action
+ iffiltered_action_spaceisnotNone:
+ self.filtered_action_space=filtered_action_space
+
@propertydefactions(self)->List[ActionType]:returnlist(range(0,int(self.high[0])+1))
diff --git a/docs/_sources/components/agents/index.rst.txt b/docs/_sources/components/agents/index.rst.txt
index ca21713..c958768 100644
--- a/docs/_sources/components/agents/index.rst.txt
+++ b/docs/_sources/components/agents/index.rst.txt
@@ -21,8 +21,6 @@ A detailed description of those algorithms can be found by navigating to each of
imitation/cil
policy_optimization/cppo
policy_optimization/ddpg
- policy_optimization/td3
- policy_optimization/sac
other/dfp
value_optimization/double_dqn
value_optimization/dqn
@@ -36,6 +34,10 @@ A detailed description of those algorithms can be found by navigating to each of
policy_optimization/ppo
value_optimization/rainbow
value_optimization/qr_dqn
+ policy_optimization/sac
+ policy_optimization/td3
+ policy_optimization/wolpertinger
+
.. autoclass:: rl_coach.base_parameters.AgentParameters
diff --git a/docs/_sources/components/agents/policy_optimization/wolpertinger.rst.txt b/docs/_sources/components/agents/policy_optimization/wolpertinger.rst.txt
new file mode 100644
index 0000000..5aa57d2
--- /dev/null
+++ b/docs/_sources/components/agents/policy_optimization/wolpertinger.rst.txt
@@ -0,0 +1,56 @@
+Wolpertinger
+=============
+
+**Actions space:** Discrete
+
+**References:** `Deep Reinforcement Learning in Large Discrete Action Spaces <https://arxiv.org/pdf/1512.07679.pdf>`_
+
+Network Structure
+-----------------
+
+.. image:: /_static/img/design_imgs/wolpertinger.png
+ :align: center
+
+Algorithm Description
+---------------------
+Choosing an action
+++++++++++++++++++
+
+Pass the current states through the actor network, and get a proto action :math:`\mu`.
+While in the training phase, use a continuous exploration policy, such as gaussian noise,
+to add exploration noise to the proto action. Then, pass the proto action to a k-NN tree to find actual valid
+action candidates, which are in the surrounding neighborhood of the proto action. Those actions are then passed to the
+critic to evaluate their goodness, and eventually the discrete index of the action with the highest Q value is chosen.
+When testing, the same flow is used, but no exploration noise is added.
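+
+A minimal sketch of this action-selection flow (with hypothetical ``actor``, ``critic`` and ``knn_tree`` handles that
+mirror the agent implementation's ``predict`` and ``query`` calls):
+
+.. code-block:: python
+
+    import numpy as np
+
+    def wolpertinger_choose_action(actor, critic, knn_tree, observation, k, noise_std=0.0):
+        # 1. proto action from the actor (continuous)
+        proto_action = actor.predict({'observation': observation})
+        # 2. optional exploration noise (training phase only)
+        if noise_std > 0:
+            proto_action = proto_action + np.random.normal(0.0, noise_std, proto_action.shape)
+        # 3. k nearest valid action embeddings around the proto action
+        embeddings, indices, _, _ = knn_tree.query(keys=proto_action, k=k)
+        # 4. evaluate the candidates with the critic and return the discrete index with the highest Q value
+        q_values = critic.predict({'observation': np.tile(observation, (k, 1)),
+                                   'action': embeddings[0]})[0]
+        return int(indices[0][np.argmax(q_values)])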
+
+Training the network
+++++++++++++++++++++
+
+Training the network is exactly the same as in DDPG. Unlike when choosing an action, the proto action is not passed
+through the k-NN tree; it is passed directly to the critic.
+
+Start by sampling a batch of transitions from the experience replay.
+
+* To train the **critic network**, use the following targets:
+
+ :math:`y_t=r(s_t,a_t )+\gamma \cdot Q(s_{t+1},\mu(s_{t+1} ))`
+
+ First run the actor target network, using the next states as the inputs, and get :math:`\mu (s_{t+1} )`.
+ Next, run the critic target network using the next states and :math:`\mu (s_{t+1} )`, and use the output to
+ calculate :math:`y_t` according to the equation above. To train the network, use the current states and actions
+ as the inputs, and :math:`y_t` as the targets.
+
+* To train the **actor network**, use the following equation:
+
+ :math:`\nabla_{\theta^\mu} J \approx E_{s_t \sim \rho^\beta} [\nabla_a Q(s,a)|_{s=s_t, a=\mu(s_t)} \cdot \nabla_{\theta^\mu} \mu(s)|_{s=s_t}]`
+
+ Use the actor's online network to get the action mean values using the current states as the inputs.
+ Then, use the critic online network in order to get the gradients of the critic output with respect to the
+ action mean values :math:`\nabla _a Q(s,a)|_{s=s_t,a=\mu(s_t ) }`.
+ Using the chain rule, calculate the gradients of the actor's output, with respect to the actor weights,
+ given :math:`\nabla_a Q(s,a)`. Finally, apply those gradients to the actor network.
+
+After every training step, do a soft update of the critic and actor target networks' weights from the online networks.
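+
+A minimal sketch of the critic target computation implied by the equation above (hypothetical network handles and
+array names; terminal transitions are masked, as in DDPG):
+
+.. code-block:: python
+
+    def critic_targets(actor_target, critic_target, rewards, next_states, game_overs, discount):
+        # y_t = r(s_t, a_t) + gamma * Q(s_{t+1}, mu(s_{t+1})), zeroing the bootstrap term on episode ends
+        mu_next = actor_target.predict({'observation': next_states})
+        q_next = critic_target.predict({'observation': next_states, 'action': mu_next}).squeeze()
+        return rewards + discount * (1.0 - game_overs) * q_next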
+
+
+.. autoclass:: rl_coach.agents.wolpertinger_agent.WolpertingerAlgorithmParameters
\ No newline at end of file
diff --git a/docs/components/agents/index.html b/docs/components/agents/index.html
index 357caad..7c90e3f 100644
--- a/docs/components/agents/index.html
+++ b/docs/components/agents/index.html
@@ -117,8 +117,6 @@
Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several input states, stack all
observations together, measurements together, etc.
@@ -652,7 +654,7 @@ dependent on those values, by calling init_environment_dependent_modules
set_incoming_directive(action: Union[int, float, numpy.ndarray, List]) → None
Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent
has another master agent that is controlling it. In such cases, the master agent can define the goals for the
-slave agent, define it’s observation, possible actions, etc. The directive type is defined by the agent
+slave agent, define its observation, possible actions, etc. The directive type is defined by the agent
in-action-space.
Pass the current states through the actor network, and get a proto action \(\mu\).
+While in the training phase, use a continuous exploration policy, such as gaussian noise,
+to add exploration noise to the proto action. Then, pass the proto action to a k-NN tree to find actual valid
+action candidates, which are in the surrounding neighborhood of the proto action. Those actions are then passed to the
+critic to evaluate their goodness, and eventually the discrete index of the action with the highest Q value is chosen.
+When testing, the same flow is used, but no exploration noise is added.
Training the network is exactly the same as in DDPG. Unlike when choosing an action, the proto action is not passed
+through the k-NN tree; it is passed directly to the critic.
+
Start by sampling a batch of transitions from the experience replay.
+
+
To train the critic network, use the following targets:
First run the actor target network, using the next states as the inputs, and get \(\mu (s_{t+1} )\).
+Next, run the critic target network using the next states and \(\mu (s_{t+1} )\), and use the output to
+calculate \(y_t\) according to the equation above. To train the network, use the current states and actions
+as the inputs, and \(y_t\) as the targets.
+
+
To train the actor network, use the following equation:
Use the actor’s online network to get the action mean values using the current states as the inputs.
+Then, use the critic online network in order to get the gradients of the critic output with respect to the
+action mean values \(\nabla _a Q(s,a)|_{s=s_t,a=\mu(s_t ) }\).
+Using the chain rule, calculate the gradients of the actor’s output, with respect to the actor weights,
+given \(\nabla_a Q(s,a)\). Finally, apply those gradients to the actor network.
+
+
+
After every training step, do a soft update of the critic and actor target networks’ weights from the online networks.
\ No newline at end of file
diff --git a/docs/components/spaces.html b/docs/components/spaces.html
index a62653c..753e334 100644
--- a/docs/components/spaces.html
+++ b/docs/components/spaces.html
@@ -442,7 +442,7 @@ The actions will be in the form:
Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several input states, stack all
observations together, measurements together, etc.
diff --git a/docs_raw/source/_static/img/algorithms.png b/docs_raw/source/_static/img/algorithms.png
index 6c00f21..0849ad7 100644
Binary files a/docs_raw/source/_static/img/algorithms.png and b/docs_raw/source/_static/img/algorithms.png differ
diff --git a/docs_raw/source/_static/img/design_imgs/wolpertinger.png b/docs_raw/source/_static/img/design_imgs/wolpertinger.png
new file mode 100644
index 0000000..e7f9b37
Binary files /dev/null and b/docs_raw/source/_static/img/design_imgs/wolpertinger.png differ
diff --git a/docs_raw/source/algorithms.xml b/docs_raw/source/algorithms.xml
index e6f68c9..377709c 100644
--- a/docs_raw/source/algorithms.xml
+++ b/docs_raw/source/algorithms.xml
@@ -1 +1 @@
-7V1bk5s2FP41O9M+JIMQ18e9NduZbCdp0jR5ysgg26SAXMC73v76CnMVEphdkLGzbB5sDhJgne/7dCQdkQt4HezeRWizvicu9i9Uxd1dwJsLVQVAt+hHannKLIZuZ4ZV5Ll5ocrwyfsP50Ylt249F8dMwYQQP/E2rNEhYYidhLGhKCKPbLEl8dm7btAKc4ZPDvJ569+em6wzq6Urlf0Oe6t1cWeg5GcCVBTODfEaueSxZoK3F/A6IiTJvgW7a+ynjVe0S1bvt5az5YNFOEz6VPh+fW+gL8GXH+p7HP5156+j3dc3mp4/XPJU/GLs0gbID0mUrMmKhMi/raxXEdmGLk4vq9Cjqsx7QjbUCKjxB06Sp9ybaJsQalongZ+fxTsv+ZpWf6vnR99qZ252+ZX3B0/FQZhET7VK6eG3+rmq2v6oqBcnEfmn9F16j+wXpz+ztSVzU0y2kYM7mi8ttYckilY46SgItdLjlCqYBJg+I60YYR8l3gP7JCjH7Kosl1e9jCL0VCuwIV6YxLUrf0gNtEBOP6jl2MvJpxsNiDTLm1ZXefole4LiqPZTKtMeds+AYP6jH5C/zZuBg2QFuNR/j2svwZ82aO+ZRyo7LLiWnu9fE59E+7pwqaf/SiDUzhj7v7QGCZOaPftLaxSMTXGEIifHs8qhSi1R9YCjBO+6ccWjoKigs+4ChXQ8Vsqj5i5R1jXVgYXoiJBT894LnKNwzpj14Rn60FceBqrDIB+rHAHJIsbRA701CYdxkWEQT0DHwoulkE0NErsIW0uHYSEYiXRU9N7qDO0MnnW2zpNOlcY5wDX6zLn+nIM9OTe0Rx7kYyjo9AyfPu7Vgn5ZJfsmyQxp78Sgwfh3S4oTb+K9Py9pAaBvdtXJ4iq/h5ttWv02WGDXxVFxVfrU2YXZm1Fz7QHkUd9SF3Df9x6mvo4tV2Oob4zEfMVqMB9AnvpAhzz3LVnc146Fi3vPdX38SBtrRkQNETaDBxoBC/Cg8njQZOFBHzk4PgMfADYILrnG+EAQBJfkHd0JhgRSQhEpb24+vKMnL52EtrOq3GHkDmEndULScD/j8ZCEuOHa3IR8bxXSQ4c6jfYZ8Cp1qecg/zI/EezlowVvbCzCQG4EgABDacg2FNBUEyBEXsSm/ewR24jxmdkzPrOmjM/MVsr3ZbgtInirbiyiUgNwGGMxvZtF/9gG9EZkm9AQL87uGm6D78hJB24xX/6AltQLyutjXNNeKPwgQNDHLJfYcCQN+lRWQVRL0NGropkWXZaEzBOxQ0TF6ikq9pSiYo0fRwhV5jMK754VN7w+/utqM4aA8C2vAYZo3gfKkgB7loABEmD3lAAwaWBhcxrwlfe673ubuC26rvkOxZtszXPp7VIMXLE8HZ2IYzDPZId30OSHd1C4xCGLdXBm3QDWaT1ZZ05JOmBO6WNQ83Dl70M+ZjxcOXwKHwPjHKIr0DlNcziUskSh1GXbWOpPFK56DtM47L2aeRlz4nkZYxjvwSvXdtB3VJXNSh87u6UctOdw06wD2S2ws7yk7BZ+0JfNz1BbgNFrXF/XtIkX2IsLzylHKrDYcFyzBSlHotUWaEjzjtraj/deXVFEfblDQgclrRMjLx2EncGYy26k9QHeyaZglVsaAeHYOX9nsKxpaKwPBLGRcGkZyPOCiGg/uRcaCbGmIMMSiDIsAZCW8WHNMeqQbK++6V5Z8DdZvld7wtcJLSjOC4Rt05S6LVgcMARhkaz1QW0eyg6Sib7zlNCYVCba8/9mmThBmYAaO7SdXCaKCHKWiZfJhN5TJrRJp7qhKC10lomTlYlmNGEIZsCOKxPDtpi8epnouyCmK2JcHEkm2lfEZpk4QZloRhOTy4Q2RxPH2R0+dP/ni9bPnr07XIFd5eWsn2nGDMH5BQWTvqBAH7a7YxoIqqeDQa1vZqY+iQyqjSVQ07QPyGBneUkYPJtVgiNAaShCDkLAMhurPRl081qNoGqMTo7PCp4kUBcG5WXlkN9uBOZwvmux0WSFQrC7/Lij/mF7jKbpyMzT6ciKPJgT7chgMatTdkwHginV7CovqSMT5V7JkLrxZiAGPVeXaM4SCUEzOXB6kTybASevdSewu1vvm49hqFJCOQ2wotY3lHu22AJ2MsTSDoitAbrKSxJbCa8ZEqrcvRceELXXvuHUMliZsy3BdtNipuEo20314YMOAERg+DjOZYqrxNsSOPWes2amnyhIMREu4vQj+0UH8diZ/vvTpeFr7MyZJdh3KdztLG1vjioJfoEX/jIBBGkrKB8FFVRxhV9nfLKRoAobkaAIoeZREdqeZVG68fMNlWfly0X6EM1OrfH6pkWrL/n9gCM0KLAbgXUxyVTfjCdI+i/W/MZvzh6v0RK+le7YDWexEaQhyBHXRB21tA0xZ7P+cJIjkhzPR9rEKNh12Nhh1Xx35EgDkuZuyOJF9G0DEq3xXM8tf+TdlvSwelV+Vrz6Dwfg7f8=
\ No newline at end of file
+7V1bk5s2FP41O9M+ZAdJXB/31t3OZDtJk6bJUwaDbJMCcgHvevvrK4ywEYiL18jgLJsHm4MEWOf7Ph1JR+QC3QSb+8heLR+Ji/0LqLibC3R7ASFQoU4/UstLZjGgmhkWkeeyQnvDJ+8/zIwKs649F8dcwYQQP/FWvNEhYYidhLPZUUSe+WJz4vN3XdkLXDF8cmy/av3bc5NlZjU1ZW9/wN5imd8ZKOxMYOeFmSFe2i55LpjQ3QW6iQhJsm/B5gb7aePl7ZLV+63m7O7BIhwmXSp8v3nU7S/Blx/wPQ7/evCX0ebrO1VjD5e85L8Yu7QB2CGJkiVZkND27/bW64isQxenl1Xo0b7Me0JW1Aio8QdOkhfmTXudEGpaJoHPzuKNl3xNq19q7Ohb4czthl15e/CSH4RJ9FKolB5+K57bV9se5fXiJCL/7HyX3qPacKwtY7KOHNzQWmmpLQLtaIGThoKIgTxty8ItmGPuMQkwfUZaIMK+nXhPPNhshtnFrhyrehVF9kuhwIp4YRIXrvwhNdACjH5IZdhj5NP0EkTK5Q2zqTz9kj1BflT4KXvTFnYHQJD96CfbX7NmqEByD7jUf89LL8GfVvbWVc9UdnhwzT3fvyE+ibZ10VxL/+2AUDijb//SGiRMCvbsL62RMzbFkR05DM+wgirYhKonHCV404iC/KzGuwvk0vG8Vx7IXKIsC6qDctERIafgvVc4R6k4Y9KHen3oKg99q8NRPoYVApJZjKMnemsSHsdFjkFVAjomns2FbCqR2LWxOXc4FoKeSEdF71LjaKdXWWdpVdJBaZwDlUafONfGpVbO9d4jH+VjJOj0dJ8+//WMflkk2ybJDGnvxKFB/3dN8hPv4q0/r2gBoK02+5P5VX4PV+u0+l0ww66Lo/yq9KmzC/M3o+bCA8ijvglnaNv3tlNfw6arctTXe2K+YpaYD1CV+kBDVe6bsrivngoXj57r+viZtt+EiAIiLA4PNAIW4AFW8aDKwoPWc3B8Bj4AfBC84xrnA0EQvCNv707QJZASiUh5e/vhnp68chLazlB5wLZ7DDtpiycl93MeD0mIS65lJtv3FiE9dKgfaZ+BrlP/eY7tX7ETwVY+avDGxyIc5HoACNCVkmwjAU1VAULkRWzqzx6xvT4+MzrGZ+ao4jOjlvJdGW6JCF6rG7NopwE4jLGY3uWif6wDeiOyTmiIF2d3DdfBd9tJB25xtXyLlhQLyutjXMOaKdVBgKCPmc+x7kga9EFeQaAp6OihaKZFkyUh00TsAaJidhQVa1SiYvYfRwhV5rMdPhwUN7w9/muwHEMgdFnVAF0074NkSYA1SUB3CbA6SgAYV2BhVTTga9Xrvu+t4rrouuA7O15la55zb5Ni4Jrnae9E7IN5Bj+8Q0Z1eIeESxyyWIcm1nVnndqRdcaoSAeMIX0MCh7e+7vNx5yH9w4/gY+BfpbRFWicpmkPpUxRKHVVN5b60w4XHYdpFey9mXkZY+B5Gf043oO3pe2g66gqn4QeOLtlN2hncFPNluwW1FheUnZLddCXzc9QW4Dtt7i+rqoDL7DnF55SjiAw+XBctQQpR6LVFqRL8w6s7cc7r64oor7cIaFjJ7UTI68dhJ3BmMsqpfWBqpMNwSq3NAKivnP+zmBZU1d5HwhiI+HSMpDnBRHRfnIvlBJiDUGGJRBlWAIgLePDnGLUA7K9uqZ75bHeSAanqD7ha0QLitMCYd00pWYJFgd0QVgka31QnYayh8hE13lKpI9LJurz/yaZGKFMIJUf2g4uE3kEOclEJ5nQOsqEOq6pbiRKC51kYrQyUY4mdMEM2Gll4rgtJm9NJrouiOUTZWORifoVsUkmRigT5WhicJlQp2jiAJnIp8/aBx297/981frZwbvDFdRUXs76mapPEJxeUDDoCwq043Z3DANBOFyKWNfMTG0cMghLS6CGYbXIYGN5SRg8m1UCGVDqGSGtEDCN0mpPhmVWqxRU9dHJVbOCBwnUhUH5rnJY3W4EpnC+abHR4IVCsLv8tKP+4/YYDdORGYN1ZPlo/lw6MpTP6uw6ppZgChpN5SV1ZKLcKxlS198MxFHP1SSak0QiUE4OHF4kz2bAWdW60+/u1rrmY+jwNKGcCnhR6xrKHSy2gJ8MMdUWsdVBU3lJYivhNUNClXv0whZRe+sbTk2dlznLFGw3zWcaTrLdVDt+0AGACAwf+7lMfpV4vQNOsecsmOmnHaSYCGdx+pH9olY8Nqb//nRp+Co/c2YK9l0KdztL25sDJcEv8MJfBoAgbQXlo6ACFFf4dcInHwlCVIoERQg1TorQ+iyLnRs/31J5Vr5cpA9R7tRKr2+a1fqyuh+whwYFVimwzieZipvxBEn/+Zpf/83Z4TVawrfSnbrhTD6C1AU54qqoo5a2IeZs1h/GMCJh8B1qE6Ng12Fph1X53ZE9DUjKuyHzF9HXDUjU0nMdWv7Euy3p4f5V+Vnx/X84gO7+Bw==7Vxbt5o4FP415/F0kQQCPOq5tJ1e1plpV6edl1kRojIHiYNYtb++QYkQEi+nEPGcqg/KzgXI/vaXnZ0NV+hmsnydkun4AwtpfAWtcHmFbq8gBDbE/CeXrDYSF9obwSiNwqJSKfgU/aCF0Cqk8yikM6lixlicRVNZGLAkoUEmyUiasoVcbchi+axTMqKK4FNAYlX6dxRm443Uc6xS/oZGo7E4M7CKkgkRlQvBbExCtqiI0N0VukkZyzb/JssbGueDJ8Ylip0vDPcH797f/fEwsd8FD9eL601n909psr2FlCbZL3f9o2e/S1b35Os/N29Xt70VfvsQi66/k3hejFdxr9lKDGDK5klI807AFeovxlFGP01JkJcuOGS4bJxN4qJ4GMXxDYtZum6Lhk7+5fJZlrJHWinB60/egiVZRb755C3EaFv8gKRBgSy47UxoM5cUd0HTjC5raj8wZmCrSG4BlE1olq54O9GLgEkBfiCwsCihBHEhG1dghASKSAHf0bbvUkX8T6GlJ2hMVRANOeCLQ5ZmYzZiCYnvSmm/VGE+mGWd94xNC8X9R7NsVYwxmWdMVitdRtnXvPkrpzj6Vim5XRY9rw9W4iDht1tplB9+q5aVzdZHop2sXbDVbn6bv6BbPlRsngZ0z5CigpZIOqLZnnqOHispjUkWfZcvrnW9I8VS2WBG0+/81CxpZrSSqamWGnh0MNSaXc3aQ0K9YSCZKzBpnRC5rxzJPrFqnr6jWic0ZZz2xTjbNk7nSOPEXRqno5lGccwvtz/gf0bZepw2gny+kyCC/58zUXA9Wyu5xysAZ7osC0Uvb5PpPG9+NxnQMKSp6JVf9aZj+WRcXLkAcxzhwQFaz+aHOcKhXmhLHIFNUoTl1SgCIJUjgINUkvBMkQQ+FVg+RGEY0wUfwQtMDsHEl0ACoaMBCVRBYpsCiduyY/5cFQNkB3xrlZJiNA741sxb14xnwHyRznxvbx9e88JekPHBh9YbSsImdsw1k9UwIcEgYQmt6bsQkTgaJfww4JrkUw7q53qO+BK7VxRM1kSzA4SyfyPh0BRqALZqrI80Bm
1rYGPMM/Rfumdo2g8URHDQEQSwS09QXKaGHo5lA19HBjs5ZpBu+YImM6qngnrVj/MJPxGbZ9ybnG3Omswn/5IgX0zO1PoHeKda0dwkFbr+wFIXIZpJajikODjlQhTKbAN9zSwFtWEiUzAEiiZeGN+cfiUK4LEMZHfKQLB9B0VLSZ9J8uZJDsmFLNbk4MC6c4LQK9U9wbrAFTIFGnThi9b5wj6WL/xO+cJW+OKrCoY4jqazXS5+RaVkNt3s4g2jZQ6NvmzTrRutMSt15YUnctUpHZ10RncuFtq2heJjDdTSY+VEBoq71Dyo6L1EwSHNS3ovYXA2mgfu8/DldGHArS932G/zdH5bb9cq7y+SjI5cQCqI/L2jS27H0SXcLLy0ZoP5ZCrqF5o5+/mCpFkvTxEqL3ktu4/y0VvXqWWX+L0bt39/Cn7xj+QX22vIL+umfBTIqlJhyqIkm1V6fsgFFe/GkwFse7XcoVp94FuN6mNrf/369dTq8z+bOywtZDtUDbZYmq15ujWaFwF+3K1f5SuT6zRlXDn8XMFvnVUj27rdcVIN1IXUm2jmhSUpQqvuj2AANAElBHQuiXBd2s+JcZ8xvZr0SUzT7zbN+mA+Y6c5U6LjBjtlwNKtbwKWBCTbGZn+1cjWMw1kIVAPN7sW1rCDq0mBMkfpajarycn2fNNbalFGz3I0qgEYqLrBxnj7spbshreRdyRv2ztAdaJE9N0JUO1mODwhn+HANmSj60rUXAlwyYfYT2ueX5tzPKjzSPU5EaaIzW7B4QA6hPzZTjf8l0xyBCSDWf6zuc6DG+x7vZvfYsWKgLxi9Xxft5+OT+jg2DoHp8ZvX66UpM6nZG8aM15QX0564pHPanjbVkfThoZG02229/lyXZKuInz2sa4KbOqqtBLexv7+8DOy/X31zYSfbROJ5JeHhtpgQNGPgAP0Na6LpxKgY2o6cZs9WHghwNYJ8Ngtjm4IcPvWhrMmQHX/5UKAZ0GACIP6jgJ0teu3U5Kg03y/9kSMp+HaswgwHclZdtN81FacNnAgZ8CGe+sbyhm4bGqd10TsHJvDiZom2jTjrstD32c3y0HhVYt9MREg7Oyhb+xd2OWs2EW8T8Z8Gl8z3KgZMo/XHz9yyeeUch6wuOIf59NmVt5R1kzeiQC0COAbf4DT0b0kRLdZYczZxeqE8TklURIlI0WNamy4zZxqWQG7bTEmAxr3SfA4WreXtp/yj0EmB9iurVecrp/3d9XNpjzv/qI9jfHhejKya0p3/LB8X+NmTVC+9RLd/QQ=
\ No newline at end of file
diff --git a/docs_raw/source/components/agents/index.rst b/docs_raw/source/components/agents/index.rst
index ca21713..c958768 100644
--- a/docs_raw/source/components/agents/index.rst
+++ b/docs_raw/source/components/agents/index.rst
@@ -21,8 +21,6 @@ A detailed description of those algorithms can be found by navigating to each of
imitation/cil
policy_optimization/cppo
policy_optimization/ddpg
- policy_optimization/td3
- policy_optimization/sac
other/dfp
value_optimization/double_dqn
value_optimization/dqn
@@ -36,6 +34,10 @@ A detailed description of those algorithms can be found by navigating to each of
policy_optimization/ppo
value_optimization/rainbow
value_optimization/qr_dqn
+ policy_optimization/sac
+ policy_optimization/td3
+ policy_optimization/wolpertinger
+
.. autoclass:: rl_coach.base_parameters.AgentParameters
diff --git a/docs_raw/source/components/agents/policy_optimization/wolpertinger.rst b/docs_raw/source/components/agents/policy_optimization/wolpertinger.rst
new file mode 100644
index 0000000..5aa57d2
--- /dev/null
+++ b/docs_raw/source/components/agents/policy_optimization/wolpertinger.rst
@@ -0,0 +1,56 @@
+Wolpertinger
+=============
+
+**Actions space:** Discrete
+
+**References:** `Deep Reinforcement Learning in Large Discrete Action Spaces <https://arxiv.org/pdf/1512.07679.pdf>`_
+
+Network Structure
+-----------------
+
+.. image:: /_static/img/design_imgs/wolpertinger.png
+ :align: center
+
+Algorithm Description
+---------------------
+Choosing an action
+++++++++++++++++++
+
+Pass the current states through the actor network, and get a proto action :math:`\mu`.
+While in the training phase, use a continuous exploration policy, such as gaussian noise,
+to add exploration noise to the proto action. Then, pass the proto action to a k-NN tree to find actual valid
+action candidates, which are in the surrounding neighborhood of the proto action. Those actions are then passed to the
+critic to evaluate their goodness, and eventually the discrete index of the action with the highest Q value is chosen.
+When testing, the same flow is used, but no exploration noise is added.
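+
+A minimal sketch of this action-selection flow (with hypothetical ``actor``, ``critic`` and ``knn_tree`` handles that
+mirror the agent implementation's ``predict`` and ``query`` calls):
+
+.. code-block:: python
+
+    import numpy as np
+
+    def wolpertinger_choose_action(actor, critic, knn_tree, observation, k, noise_std=0.0):
+        # 1. proto action from the actor (continuous)
+        proto_action = actor.predict({'observation': observation})
+        # 2. optional exploration noise (training phase only)
+        if noise_std > 0:
+            proto_action = proto_action + np.random.normal(0.0, noise_std, proto_action.shape)
+        # 3. k nearest valid action embeddings around the proto action
+        embeddings, indices, _, _ = knn_tree.query(keys=proto_action, k=k)
+        # 4. evaluate the candidates with the critic and return the discrete index with the highest Q value
+        q_values = critic.predict({'observation': np.tile(observation, (k, 1)),
+                                   'action': embeddings[0]})[0]
+        return int(indices[0][np.argmax(q_values)])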
+
+Training the network
+++++++++++++++++++++
+
+Training the network is exactly the same as in DDPG. Unlike when choosing an action, the proto action is not passed
+through the k-NN tree; it is passed directly to the critic.
+
+Start by sampling a batch of transitions from the experience replay.
+
+* To train the **critic network**, use the following targets:
+
+ :math:`y_t=r(s_t,a_t )+\gamma \cdot Q(s_{t+1},\mu(s_{t+1} ))`
+
+ First run the actor target network, using the next states as the inputs, and get :math:`\mu (s_{t+1} )`.
+ Next, run the critic target network using the next states and :math:`\mu (s_{t+1} )`, and use the output to
+ calculate :math:`y_t` according to the equation above. To train the network, use the current states and actions
+ as the inputs, and :math:`y_t` as the targets.
+
+* To train the **actor network**, use the following equation:
+
+ :math:`\nabla_{\theta^\mu} J \approx E_{s_t \sim \rho^\beta} [\nabla_a Q(s,a)|_{s=s_t, a=\mu(s_t)} \cdot \nabla_{\theta^\mu} \mu(s)|_{s=s_t}]`
+
+ Use the actor's online network to get the action mean values using the current states as the inputs.
+ Then, use the critic online network in order to get the gradients of the critic output with respect to the
+ action mean values :math:`\nabla _a Q(s,a)|_{s=s_t,a=\mu(s_t ) }`.
+ Using the chain rule, calculate the gradients of the actor's output, with respect to the actor weights,
+ given :math:`\nabla_a Q(s,a)`. Finally, apply those gradients to the actor network.
+
+After every training step, do a soft update of the critic and actor target networks' weights from the online networks.
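+
+A minimal sketch of the critic target computation implied by the equation above (hypothetical network handles and
+array names; terminal transitions are masked, as in DDPG):
+
+.. code-block:: python
+
+    def critic_targets(actor_target, critic_target, rewards, next_states, game_overs, discount):
+        # y_t = r(s_t, a_t) + gamma * Q(s_{t+1}, mu(s_{t+1})), zeroing the bootstrap term on episode ends
+        mu_next = actor_target.predict({'observation': next_states})
+        q_next = critic_target.predict({'observation': next_states, 'action': mu_next}).squeeze()
+        return rewards + discount * (1.0 - game_overs) * q_next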
+
+
+.. autoclass:: rl_coach.agents.wolpertinger_agent.WolpertingerAlgorithmParameters
\ No newline at end of file
diff --git a/docs_raw/source/diagrams.xml b/docs_raw/source/diagrams.xml
index 15f067f..9b5e64a 100644
--- a/docs_raw/source/diagrams.xml
+++ b/docs_raw/source/diagrams.xml
@@ -1 +1 @@
-7V1bd6M6sv41eUwWQlwfc+l07zPdPdm7Z87MnJeziE0cZmPwYJxL//qRMMKgkgHbEoa0MrN627K5mK+qVPWpqnSBb5dvn7Ng9fwtnYfxhWnM3y7w3YVpIux75D905H074tj+dmCRRfPyS7uBH9HPsBw0ytFNNA/XjS/maRrn0ao5OEuTJJzljbEgy9LX5tee0rh51VWwCMHAj1kQw9F/RPP8eTvq2cZu/EsYLZ7ZlZFRfvIYzP5cZOkmKa93YeKn4m/78TJg5yq/v34O5ulrbQh/usC3WZrm21fLt9swps+WPbbtcfd7Pq3uOwuTvNcBGG8PeQniTcjuubiz/J09jeL3hPQIdIFvXp+jPPyxCmb001eCPxl7zpdx+fEiDtb06Rvk9SxdRrPy9TrP0j/D2zROs+Ks2Jl54eNT9Ql7zpiMPEVxXPvmPAi9pxkdT5O8FBbTKN/XvmcUf2Q8iKNFQsbi8Cmnb7NZeZRD3sFHVD61lzDLw7faUPnIPofpMsyzd/IV9qnllfi9M7llEv+6ExeLCcVzTVSwUz7foBTRRXX2HUzkRYmUGDXHUwjaPFg/V8e1IvjJof/rg2CpA+RCWTCPwh1qSZqEMoB15QBL8GngahqOAXG1BLg6zDacgqtvaFxzSUi6TQ2lNhogiUVIIksGkgggGZB5ahjL6rt3hutKxg+ZLfjJAMy3rmwOMse/8vDuzwYAVhpTB1AKfma3JobJ/Jr6GuTdjEJDAanjVc3t9AnFwWMY31TeQe0h3hT/E6PV8/GTp569/5NeiTzC8u2/9s67hnt/fdMGWh5kizBvCnI4b3hMEMYaKrYAFDaWhXGQRy9NP0uEVHmFhzQid1dJSWV7mYxULiM7xzrdZLOwPKzu+YAz2by8YYM71/ZBgHMR1IP32tdW9AvrQ24aG633ZvvuiQcgu/0AcEvcAeTF9lfutKXCupcCuT0USE9lPacyzinBfacykxfoY0whUulsDg9eqyFtCJb0sMFs6lwVTtZgdBxRzODLmNE0ikpQRH5PFC1+pjpKGZ1fHkZJoR7mZ0CBNmIsMqoy/Es4O27WIRlIwvw1zf4s+C76C7NwHtG4wSDRQ5Qm5MU8IgBEj5vt27HGEwMGDthpAimaHZXFCZA9+3uBY/5M/w3fVuQnByVyqzSOZu8M2nWwXMX0S09ZuqyOKAEnEGqku5F2B0TaAkjPsjDIKWhBUlfQKHmi+KaP/w63mptQMLfqHeUV/Hma7QSlOnj75SApwov5PKLDQbw7a/CYbnJwmJYNIBsWHlA24KRcaXdAY16CaqHjFcybJHpKs2X8zut/9Y11AZh5S8fiNFkUv4T8EPIJuT9qHcKnYBM3poZSRorbWGupqIF7DoMBw6adwaB4E6lYR0we1lEB8VYIyLPPmYmYFy+YnZlX30nCVzrBJC9UvsL1Kk2ogRkr5COgF40m2eNbA4oC80g62GDyU3KOQmw8e/bUag+5HGKkw4w8oJCM39AHE82C+Lr8YBnN5/QyQiHYiYnRl2jpy4YU3yt/oWBx8lSFNgW+npgIkQCipoQbs91IGGFsN0UCORzWvQlhgzsRH69LYoP5G+7idpEHaGpObE9jalGPDIBfQq7NjynXyOfkh2cTzyTXtqFYrmHElm2KOPw5KIKxYoz6YVFSeF8O8dmJrCSP61UlMpOfnVuVSfbsjPlFd8MF0zMScmoy5me7F97PxAffrDTaElxq1ANtUxXaLkB79pym6xZa5HRIyzW9j+pc8xyKKQiXK69IOqBwgSp9XIfZCwxrp6mb54yTPAGSniIg2Tkm4k++RXnNnSTv/lXeg2RPc+uuXdQjqLrzyfJwR+J98kGH5drHeZ+8V+jwIibJ++Rv2Ebt3id/X9z3T/Y+K5XVWrBPC9AE1ICLnSyeDTxWDVysRg34oOrsagBJM2654kPM7cNGWbbPZS46AhqUGWzp07umixgMozJVvCZ7xybGYq62xT85K7YhWgfaD0jidKx5aQPSw4BwomJC8yFaCpNiPiBL00yF2C5ssiXPKKk+mG2y4jfvPnwJsih4jPW65wHrnsgwINjKFj5NyNIE83kFKTEQWZTQ6YEq9GuQFR+l1ef1NXINcV/WdVeNWodYtCwqBeIe6eEf1h+ohzPMsjWWlNxR+wjIOD6esa6c2p9rNX2GgdZOkeF11Ln47QecHODgDxLnSxL+0Qq6e7Sgc2caavm0U7T55ZfDDwDNATp/uylZeWAsSaZZ6nyDfEg+q7aZMZdu8lWRF8sOH6u3cEbPzx4y4w3DuO1XNIsjozNtxM3aTQGRZSOdk8nNExgDDANMzqbwBAJnS6JkGqbk/IEHb2E8b0gLI0gAqDEE4eqiZA/CVbRO50UCCEv6WGwrc1mYpNmjw9gjnlGwHQA7UkU/YxhuNvmjRl59jT5ahsu0vJRW5548gqD5jTIeAfcoUP4VPAZvVB6DxcuEeWwa6mCBFAj5zx7mMPWoifa80O57Pf9ImH9ElduuQCVkTD/W2drljab+XlEbBVHBpriPgpRqXhifclHCzo9YgxBhlTVr9XWocPAylI3P1tvLghFiUQTwQYzxOfNMkWV1waoq7dTq0RllRP6jurwU2xyVBwm8K8sQisipaSrI8lvPOyQHZQlqv9OEatWe6YVcs+jucEE7COczWubPmG49p/SfU8QSMMScwpRO1AIijopy/vQJwN5o/aAJioMBt9zzAQ7px4GgG00gIKkRF5+5jxwB7STsxCWjLZ4Naac4DLJkp52lQZbuG37wYrKKXa9gFawOIFMRTexAymWzmm/TBktnqRbAaZPb2+S65+vJ7Ewq1ebUErHygY3Fi7f8phwcn17On8nnZWNIT53txjANmTpPsWIJz6hjTGKUHMdDlmMb2OMdCg9JiTj5BjG7zSFUp4J5VvsKhu2ceoDV0S8cRvGS+4U7kJSL05Qu4j+llJRNNsuLerOOYomfxlB5tBQ0u5smdTfoOgqQAUfAwKsqA3FE4ZMTU8zm0UsDS+c/G7ob0g2F9LJE55qG1dR/rj4lr8qsju1Z1ivalXM7FqcEhXKc3Fj9o9pwcV02qsXpUHHim0cgT9QYW5E4udDd19uttMJl+i7Y/uJ8rr0LXXvxkoxGsA1ByzwfgioXxn/x/T5swcSsbusqFy6Na1t6qCba3vk0EXrSIX1dkF8fwq8ZlscUgNuZu6BqkduFi9zlPMk6x+u2wseBzIX2WFASr6xzoesCzCixwn5gCUbT4NafWg3ZHUPk1ikiLuXAhRzRwRxWkOWMFStvsBi7j+jPLC8o4M3IYPkVerv/DvP8vQQ82OQpGUqz/Dld0F0PvtL4vy6weyizVkk5kfxiU3Gd/HL7FvX0ZrV6t5SHYQ6UnPGwlepyXJiPMhL+0bI4jw1UQPXfldDjZxtXUcc0eNMdHKDD858HH8AqzPb/et7zbR5wMsvoTat50TnoflekbuNajLJsTkyq7bgPVjf+TI6ixgVAslmVQm/dOfgA3KGd4JawZErfg85qtFxlqe7oenGMi8pvfWF00nWqohDvg5RqqbSiIr/RHJcV5QXKOb
b+2+U7/4MzKXNZuowi+I2HHmB2NZjhb8mU3GDGF7ksqlbKqub4Z10r++iZcYKd7oekG/wevTkmlLp6CDu/S1712iz5AcmrBl/HZmBYDW+JdkGwWI+Uk7CEZC95YitaQLAMkmAhqEobn681BqXctUJrbWsgmo+l6OTZ9nV2Zl74+NRHJ+dB6D3NZCgfkrTwQjDiQEOCTshYlHWAeIbmKNRgnU8cvoSxUu2TOzuOQ/dA5wnBiqc63esRzqjRPc98xE6v+XBuh97cOl735KDE5VUiQUa/J0joZ609TttCU5T9rcoPJqpLnfyP5AaPQ9H53hyIcQkDJImRKV4reh+UEERpSEWHLq1W9KkpOqgbGlbRz+ZNT0vRMdjy1oUoqVP0IXPAtaKrUXSMgKJ7Ayp6j+bzOmzuDptNUcm1srCZuIEANh03y4ibTVGNtaq4ueJa9DTbBRNnI7GA3lA2zSIdOH8ATedDMowH9KeRqRX9qMB5WEUfcqVYK/pAgfOwiq4Zsl4ogcAZmwMGzkgzZNNXdBA4Yzxg4IzOxpD5jouDXooeIqLq7plndJvPx0eWYIkZsTxi+bo+JEkWJi9RliZLrfEqNB65GMoSnDYQm0rkK71my47BzeYne9OCqInZMhkdChHSbJkU/bP5VD3TEWifMrbM1GxZT5h4x0hgJJX51qZmyz6ApvO5nNgc0Ldmeq0VvR0lvtZiYEXXbNn0Fd3h9wIaVtE1W9YLJb4+jCg6REmdomu2bPqK7jpA0d0BFf1sgfO02DLHAAwHFmYkqGLLTBgqa7ZsmhrvIMC8YkEOqTq2DMN4HdAuxksUvmo8+9kGzgmwBP1hke+qQhPG9aBQUqPZn8vGCHDZw+IJfW9NisogRZHAN1dHimKRc65DqG5SFDkDhlBYtFqsQ6iJaTpPtyF3wBAKn22/tGkpOiBFh1V0UaCrFX1aig5I0WEVXVfT90IJkKLIgSgpU3RLtBitFX1aig5IUcT6gw+h6JZopXwQRa+eaP82Un2bBe7ShxwlYTO/abVrwKAZC6Mtfq+x4zAzAWaf6izmIPANv4GtYlSBHrqsLVeDChHYcjnllBZcyr5eiPDUrb/2pGIDqtkVJPm6glULOZYUkiBRstrQLpMEilxAZU3WgarkR4o19TnUfFFqtmiLYRn7rCALhrTpJte4deubD8jjgZHrsZ47aAPjVgz2tSmG9vvgfRQq0zOSnsQEcq8hGHCP1b5NiRHY08Xo2dv9iCa9yIJRb/q4DrMX8jTShAU5j5kwsqLAX64LSaCBFXJXb9vIiIutsvA1yOaSTnawc0CE6774a/HTBnYTkNuiNk2FM5Ek42ViXkIFLrxoczYproItitvHa7n2tl4/sb16NfeOxW5ZtiXLbnWfSqLdsmEQH8yONFnaxEgyMVVn6POYGBtaFLm7kXXuvTDhzch6yMvxVs82RS5cX1MofTuy6n4agSvRSvK86Km3HIQRpwvytIcJhU4rNGwFT1EI2/QmdssCdUKQ9V9pUEdylB0rUPaWnQe1rvf2cAQ7yFSD59D1HsUOI/I8dyJngqB5v6SeDJotAI1ZybH4qtg3wP4jx+5XKDqZwVsmmf5qD+aGSNmKvoyWAX3glYP4lQrcQ7qOCv8W3z2meZ4uWyWx8jp5JzOnFuImWK/CGUXnKXqj5vGmuOQ1GzXYCHn9nOerdeEs35P/L6L8efN4RSwpefOd0gRJ8ON9nYfLNRmYpcHsmfw3C17pAwnWBZV4Hy0X5N8gXqQZOXy5vlolCznTkMf1EdxtIt1YixbQcXL4ONtXPAt9KI8TcbNQ28rXqQaNVcg3DJqHZM9CxaGH7gYG2l+aruO0WiuPWzDgD7g4dbMuxEJ1LcnTkGRjopLs4/YDJEgy0pI8JUnu62SOTpI95TZZEKhoSR6vJEuPcQeSZGSwJHd1oqyartGiLFWU+66uy6drHEtLypQkxZqq0at2AVBn9FQsSGlRVibKzmRF2VI+fztalKckytKXVgcTZddTLcquFuUpibI/VVE2OxyMHkd4ZtcRrDBmzxES1MXT6jIldfEmqy6Wcj4N5jr/b/HSNP66yqNl9LNMeubkfYzFD62SJX8XJ4NHE8GCTSTK7JNSJ+bqxamR2iDXEtggZyRTtsVPp04nEcpbLUe2DXL16tSURNk9X46qq1d/piQp/khW5A83ejbLYlJn9PTqz5RE2Z3s/G07hx7hsp0D1Qk/zD8WNznQYQa2ePjMQcMMvV4zVjPlimbckWRzAjPVSXVg3kxJpzpcvV4zKVGWnm8xlChbLKBVJ8qwW8lDGkezd03bHTyfDkzb6aWDsRohNR4+NA+O7V35NtEaY/svn6fg2VeeTXxwc/uv27zAnvqrg80acPLdkufba9YcPibmjjjdrLEM/ppZu/v9+0BG7MRq4yGNmG/wy0IetGGiqkMpJoxVDNVQ+v2PSw2UqBaPd2o9wb7grjKgTADUTZrm5LEFqxWBpdYVZNvjo8CQH9SoAvXDdjeqyBDBasmAFXaIvNPaJ1yiRRZfzzyspYQ017dvtxooCBSoGyJOEKxaVgcUbNr5/ZMGSgCU5XSbPnUwwa1Lvl/+yMMVnLV+v/waBlkSJQs9o/XAFRlmN66+CFdbBq4COuP6qwauD3A2tJwAOEvQ7AG3BK79gYMdPL5f32vgegBX9XdrA06dJYVpf7dBHtJuJDPa70tHBkeqo8+xOYPaUV/An9w/aOB6AGfzbRpFtLAydfQhpVLx+Z/L7RI0jH2sqs+XyrAYrNGLXLRlsRQcTYCJZvdHwe77omZu/kiqG8GCYWeNC2aZkMpqXHxdcz4eUZ6egFqA+W4T0D5XRBx32r8HNyTh/eapJPY09CGTdD3LCdKmcZtFOe2dq6fxzmncwpx5E03jbCld/iwOaaa7u4fPGrkeyGG+7J4EQEMyuT6kkm7jaFUsdxkPD3/VIPYB0eVBdLEAROYiyAcR0koaur6WE65NDqh/Jlv11G7jCNzGegRkGoIIyDTOVstlGrrqb0qSwrJgzx6KoCbDg42ymGR/rGw6rUecHCubbPFMi/IkRNkYC+1zsCibCKkWZV2WOClRHklrk8NF2e44QoIoawZzUqI8mgrbQ0XZ4WIuBaIMWcXPJJ59/hYkwWIiWxMPzUPxKGLBQqCqMp/KxainXgdUWU3jj68aLpg34dhcpi5GBgRMmDkhBS9IG37RQIn0ii1xt8HENsmSDxMkBr9t4jy6DHT3ADEd6PIZ8ELAWAKmfMBgntmPMH66XMXBu4YLZiFx25wPawQRTB/7lKzD5SP5gRorkGrE6nBaFUuUeSsFK5gx9uX6Vi+R9JnCTOhrsPZ8QyyRsA1ga9Ddb3T5lsjX4DVsUJhgmR1trZAmlyyRQwMGFh85wExTMH2p8jVYglDdOfz644vGCSqW7XTjxHbako8TJDO+XW8TbDRSPfx3gRFUp1OQx/j77Q0diIILWsEC6+6Yy7jWzkg3wL5l8hXKng/xFbmRMjbrNZHEjR9+RSZ8/8LAqQnrrHlDcxF+6JXLlrbQNpdV6xvcZfom6Pr8Rr4eKzKWn
59rIsgg3W3CuKgTNrZtFJxgSS1P8rheFWjxVuw1ImJiGg+f/tAWrscUZoCMNIGB85UZOF+egeuzCGj5DeN3aVwZ9Ol8YAtIf+pDmEUEmzDbK7AHGj8sMn4jWes2+A5yXtkRab+tZMZ8zxFdxRCGa18h37Bcp/y3eTbXtq9qn5pHm+G2i3hW20UkWujKxPDZ94YmuPqmb3NFkD47SRdvwqF6lMU1Re2JCvzKifOLnjj7wWg4fGwgAtIRUcxSgJTY+lnHBjJjA1ag1Jgex1IrCEpTfdTRWNIycOsRXdMj5htTeu6RoQjCnnlFgnKEXMdyDXBey73afegxL0XFJCixTazWPfW6N5Jt7g/XPYz91iM6dY8qDK993tXRtboImALMzq9AzTD0NW/C5+AlSjPaBce4jVPdQqynu+IgTpJErRaVtd6oGuhoizlYTm/DBLL0SuXts2Hjap4u7G1tHKe9E7c/0k7crsN3uZHdiZs1Q68Zxt+WUa43Fdhj+6we3aPUZRsz3asXvf+m+yf2Qc7l+xWIZi22obR84GDa8R9BlDySuUCD1w0e4rt9+QYS9SlQ1G/PxHozj/F4I80FBF/kokgvdj9+9dT1r5zdH2+DbOcKgU8Pdm+Q7QpOU1tlbbmIvCDr5ufd9/T+t/88/e354eZrtP6U/vN/Lk0Wdmp3XbaCFL//7q26WfruvXzXy5Wveuicm82w+J2nu7uOmVbrEd1dx/grHt91zOLJQzZbDKViFkznv77Vay+9PAsLdkASlDkp8uf3wKn7pQxvS8mj2ZlS+ub9om5XD7SynR1J2xX5/PaYzwnossc22FbgQHvs8Vc82h7bGOxcoawL5Nv/3zyns/zP+PEv69mPhff5xx+Pl1W0UK+F0xkNvcyxjc7a1UyMJ1NLpQZZG13e6O6ZnY4yyBO0qCD57OS+utwqdn+LCm5FnYdrfwujIHh/uV78w/tp/Uz+D8V/EVrUvxED9mtYVGA+BRLe4uBisDWaMEvMEVhU83SLKsZToov7K9rNk6wjZwo7WYI2jTxD3kN7BpGoKfiJOUf8FT2XczQO6E/Oeaa+yZ3qaDtK3mZpmte/XrQNSuch/cZ/AQ==
\ No newline at end of file
+7V1be6M40v41uUwehDhe5tDpmd3u3kz37Dez380+xCYOOxi8GCed/vUrYYRBJQO2JQxpZffpsWVzMG9VqepVVekC3y6/f8yC1fPndB7GF6Yx/36B7y5ME1mmQ/5DR962I66LtwOLLJqXX9oNfIt+hOWgUY5uonm4bnwxT9M4j1bNwVmaJOEsb4wFWZa+Nr/2lMbNq66CRQgGvs2CGI7+Ec3z5+2oZxu78V/CaPHMroyM8pPHYPbXIks3SXm9CxM/FX/bj5cBO1f5/fVzME9fa0P4wwW+zdI0375afr8NY/ps2WPbHne/59PqvrMwyXsdgEtgXoJ4E7J7Lu4sf2NPo/g9IT0CXeCb1+coD7+tghn99JXgT8ae82VcfryIgzV9+gZ5PUuX0ax8vc6z9K/wNo3TrDgrdmZe+PhUfcKeMyYjT1Ec1745D0LvaUbH0yQvhcU0yve17xnFHxkP4miRkLE4fMrp22xWHuWQd/ARlU/tJczy8HttqHxkH8N0GebZG/lKJeBeid8bk1vf2w687sTFYkLxXBMV7JTPNyhFdFGdfQcTeVEiJUbN8RSCNg/Wz9VxrQh+cOj/+iBY6gC5UBbMo3CHWpImoQxgXTnAEnwauJqGY0BcLQGuDrMNp+DqGxrXXBKSblNDqY0GSGIRksiSgSQCSAZknhrGsvruneG6kvFDZgt+MgDzrSubg8zxrzy8+7MBgJXG1AGUgp/ZrYlhMr+mvgZ5N6PQUEDqeFVzO31CcfAYxjeVd1B7iDfF/8Ro9Xz85Klnb3/SK5FHWL79195513Dvr2/aQMuDbBHmTUEO5w2PCcJYQ8UWgMLGsjAO8uil6WeJkCqv8JBG5O4qKalsL5ORymVk51inm2wWlofVPR9wJpuXN2xw59o+CHAugnrwVvvain5hfchNY6P13mzfPfEAZLcfAG6JO4C82P7KnbZUWPdSILeHAumprOdUxjkluO9UZvICfYwpRCqdzeHBazWkDcGSHjaYTZ2rwskajI4jihl8GTOaRlEJisjviaLFz1RHKaPz08MoKdTD/Awo0EaMRUZVhn8JZ8fNOiQDSZi/ptlfBd9Ff2EWziMaNxgkeojShLyYRwSA6HGzfTvWeGLAwAE7TSBFs6OyOAGyZ/8scMyf6b/h9xX5yUGJ3CqNo9kbg3YdLFcx/dJTli6rI0rACYQa6W6k3QGRtgDSsywMcgpakNQVNEqeKL7p43/CreYmFMytekd5BX+eZjtBqQ7efjlIivBiPo/ocBDvzho8ppscHKZlA8iGhQeUDTgpV9od0JiXoFroeAXzJome0mwZv/H6X31jXQBm3tKxOE0WxS8hP4R8Qu6PWofwKdjEjamhlJHiNtZaKmrgnsNgwLBpZzAo3kQq1hGTh3VUQLwVAvLsc2Yi5sULZmfm1XeS8JVOMMkLla9wvUoTamDGCvkI6EWjSfb41oCiwDySDjaY/JScoxAbz549tdpDLocY6TAjDygk4zf0wUSzIL4uP1hG8zm9jFAIdmJi9CVa+rIhxffKXyhYnDxVoU2BrycmQiSAqCnhxmw3EkYY202RQA6HdW9C2OBOxMfrkthg/oa7uF3kAZqaE9vTmFrUIwPgp5Br833KNfI5+eHZxDPJtW0olmsYsWWbIg5/DopgrBijfliUFN6XQ3x2IivJ43pViczkZ+dWZZI9O2N+0d1wwfSMhJyajPnZ7oX3M/HBNyuNtgSXGvVA21SFtgvQnj2n6bqFFjkd0nJN77061zyHYgrC5corkg4oXKBKH9dh9gLD2mnq5jnjJE+ApKcISHaOifiT36O85k6Sd/8q70Gyp7l11y7qEVTd+WR5uCPxPvmgw3Lt47xP3it0eBGT5H3yN2yjdu+Tvy/u+yd7n5XKai3YpwVoAmrAxU4WzwYeqwYuVqMGfFB1djWApBm3XPEu5vZhoyzb5zIXHQENygy29Old00UMhlGZKl6TvWMTYzFX2+KfnBXbEK0D7QckcTrWvLQB6WFAOFExofkQLYVJMR+QpWmmQmwXNtmSZ5RUH8w2WfGbdx++BFkUPMZ63fOAdU9kGBBsZQufJmRpgvm8gpQYiCxK6PRAFfo1yIqP0urz+hq5hrgv67qrRq1DLFoWlQJxj/Twd+sP1MMZZtkaS0ruqH0EZBwfz1hXTu3PtZo+w0Brp8jwOupc/PYDTg5w8DuJ8yUJ/2gF3T1a0LkzDbV82ina/PLL4QeA5gCdv92UrDwwliTTLHW+QT4kn1XbzJhLN/mqyItlh4/VWzij52cPmfGGYdz2M5rFkdGZNuJm7aaAyLKRzsnk5gmMAYYBJmdTeAKBsyVRMg1Tcv7Ag7cwnjekhREkANQYgnB1UbIH4Spap/MiAYQlfSy2lbksTNLs0WHsEc8o2A6AHaminzEMN5v8USOvvkYfLcNlWl5Kq3NPHkHQ/EYZj4B7FCj/DB6DNyqPweJlwjw2DXWwQAqE/GcPc5h61ER7Xmj3vZ5/JMw/osptV6ASMqYf62zt8kZTf6+ojYKoYFPcR0FKNS+MT7koYedHrEGIsMqatfo6VDh4GcrGZ+vtZcEIsSgCeCfG+Jx5psiyumBVlXZq9eiMMiL/UV1eim2OyoME3pVlCEXk1DQVZPmt5x2Sg7IEtd9pQrVqz/RCrll0d7igHYTzGS3zZ0y3nlP6zyliCRhiTmFKJ2oBEUdFOX/6BGBvtH7QBMXBgFvu+QCH9ONA0I0mEJDUiIvP3EeOgHYSduKS0RbPhrRTHAZZstPO0iBL9w3feTFZxa5XsApWB5CpiCZ2IOWyWc23aYOls1QL4LTJ7W1y3fP1ZHYmlWpzaolY+cDG4sVbflMOjk8v58/k87IxpKfOdmOYhkydp1ixhGfUMSYxSo7jIcuxDezxDoWHpEScfIOY3eYQqlPBPKt9BcN2Tj3A6ugXDqN4yf3CHUjKxWlKF/GfUkrKJpvlRb1ZR7HET2OoPFoKmt1Nk7obdB0FyIAjYOBVlYE4ovDJiSlm8+ilgaXz3w3dDemGQnpZonNNw2rqP1efkldlVsf2LOsV7cq5HYtTgkI5Tm6s/lFtuLguG9XidKg48c0jkCdqjK1InFzo7uvtVlrhMn0XbH9xPtfeha69eElGI9iGoGWeD0GVC+M/+X4ftmBiVrd1lQuXxrUtPVQTbe98mgg96ZC+Lsivd+HXDMtjCsDtzF1QtcjtwkXucp5kneN1W+HjQOZCeywoiVfWudB1AWaUWGE/sASjaXDrT62G7I4hcusUEZdy4EKO6GAOK8hyxoqVN1iM3Uf0Z5YXFPBmZLD8Cr3d/4R5/lYCHmzylAylWf6cLuiuB59o/F8X2D2UWauknEh+sam4Tn65fYt6erNavVvKwzAHSs542Ep1OS7MRxkJ/2hZnMcGKqD670ro8bONq6hjGrzpDg7Q4fnPgw9gFWb7fz3v+TYPOJll9KbVvOgcdL8rUrdxLUZZNicm1XbcB6sbfyZHUeMCINmsSqG37hx8AO7QTnBLWDKl70FnNVquslR3dL04xkXlt74wOuk6VVGI905KtVRaUZHfaI7LivIC5R
diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index 3db0aaf..5d12e0b 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -1003,7 +1003,7 @@ class Agent(AgentInterface):
"""
Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent
has another master agent that is controlling it. In such cases, the master agent can define the goals for the
- slave agent, define it's observation, possible actions, etc. The directive type is defined by the agent
+ slave agent, define its observation, possible actions, etc. The directive type is defined by the agent
in-action-space.
:param action: The action that should be set as the directive
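For readers less familiar with Coach's hierarchy support, a minimal usage sketch of the directive mechanism described above (assuming the docstring belongs to Agent.set_incoming_directive, whose signature is outside this hunk; the master_agent/slave_agent names are made up for illustration):

    # hypothetical two-level hierarchy: the master picks a goal, the slave follows it
    directive = master_agent.choose_action(curr_state).action    # e.g. a goal in the slave's in-action-space
    slave_agent.set_incoming_directive(directive)                 # interpreted according to the slave's in-action-space
    slave_action = slave_agent.act()                              # the slave now acts towards that directive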
diff --git a/rl_coach/agents/wolpertinger_agent.py b/rl_coach/agents/wolpertinger_agent.py
new file mode 100644
index 0000000..a16b9e9
--- /dev/null
+++ b/rl_coach/agents/wolpertinger_agent.py
@@ -0,0 +1,131 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+from typing import Union
+from collections import OrderedDict
+import numpy as np
+
+from rl_coach.agents.ddpg_agent import DDPGAlgorithmParameters, DDPGActorNetworkParameters, \
+ DDPGCriticNetworkParameters, DDPGAgent
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.core_types import ActionInfo
+from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
+from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
+from rl_coach.memories.non_episodic.differentiable_neural_dictionary import AnnoyDictionary
+from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
+from rl_coach.architectures.head_parameters import WolpertingerActorHeadParameters
+
+
+class WolpertingerCriticNetworkParameters(DDPGCriticNetworkParameters):
+ def __init__(self, use_batchnorm=False):
+ super().__init__(use_batchnorm=use_batchnorm)
+
+
+class WolpertingerActorNetworkParameters(DDPGActorNetworkParameters):
+ def __init__(self, use_batchnorm=False):
+ super().__init__()
+ self.heads_parameters = [WolpertingerActorHeadParameters(batchnorm=use_batchnorm)]
+
+
+class WolpertingerAlgorithmParameters(DDPGAlgorithmParameters):
+ def __init__(self):
+ super().__init__()
+ self.action_embedding_width = 1
+ self.k = 1
+
+
+class WolpertingerAgentParameters(AgentParameters):
+ def __init__(self, use_batchnorm=False):
+ exploration_params = AdditiveNoiseParameters()
+ exploration_params.noise_as_percentage_from_action_space = False
+
+ super().__init__(algorithm=WolpertingerAlgorithmParameters(),
+ exploration=exploration_params,
+ memory=EpisodicExperienceReplayParameters(),
+ networks=OrderedDict(
+ [("actor", WolpertingerActorNetworkParameters(use_batchnorm=use_batchnorm)),
+ ("critic", WolpertingerCriticNetworkParameters(use_batchnorm=use_batchnorm))]))
+
+ @property
+ def path(self):
+ return 'rl_coach.agents.wolpertinger_agent:WolpertingerAgent'
+
+
+# Deep Reinforcement Learning in Large Discrete Action Spaces - https://arxiv.org/pdf/1512.07679.pdf
+class WolpertingerAgent(DDPGAgent):
+ def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent'] = None):
+ super().__init__(agent_parameters, parent)
+
+ def learn_from_batch(self, batch):
+        # The replay buffer holds actions in their discrete form, since the agent is expected to act with discrete
+        # actions through the BoxDiscretization output filter. DDPG, however, trains on continuous actions, so the
+        # stored actions are converted back to continuous ones here. This duplicates the filtering that is already
+        # done before applying actions on the environment, so ideally that conversion would be reused, e.g. by
+        # storing the continuous action in the transition's info dictionary.
+
+        # assumes a single action filter in the output filter (the BoxDiscretization filter set in the preset)
+        output_action_filter = list(self.output_filter.action_filters.values())[0]
+ continuous_actions = []
+ for action in batch.actions():
+ continuous_actions.append(output_action_filter.filter(action))
+ batch._actions = np.array(continuous_actions).squeeze()
+
+ return super().learn_from_batch(batch)
+
+ def train(self):
+ return super().train()
+
+ def choose_action(self, curr_state):
+ if not isinstance(self.spaces.action, DiscreteActionSpace):
+ raise ValueError("WolpertingerAgent works only for discrete control problems")
+
+ # convert to batch so we can run it through the network
+ tf_input_state = self.prepare_batch_for_inference(curr_state, 'actor')
+ actor_network = self.networks['actor'].online_network
+ critic_network = self.networks['critic'].online_network
+ proto_action = actor_network.predict(tf_input_state)
+ proto_action = np.expand_dims(self.exploration_policy.get_action(proto_action), 0)
+
+ nn_action_embeddings, indices, _, _ = self.knn_tree.query(keys=proto_action, k=self.ap.algorithm.k)
+
+        # run the k candidate actions through the critic and choose the one with the highest Q value
+ critic_inputs = copy.copy(tf_input_state)
+ critic_inputs['observation'] = np.tile(critic_inputs['observation'], (self.ap.algorithm.k, 1))
+ critic_inputs['action'] = nn_action_embeddings[0]
+ q_values = critic_network.predict(critic_inputs)[0]
+ action = int(indices[0][np.argmax(q_values)])
+ self.action_signal.add_sample(action)
+ return ActionInfo(action=action, action_value=0)
+
+ def init_environment_dependent_modules(self):
+ super().init_environment_dependent_modules()
+ self.knn_tree = self.get_initialized_knn()
+
+    # TODO: ideally the kNN dictionary should not be defined here, but be configurable by the user in the preset
+ def get_initialized_knn(self):
+ num_actions = len(self.spaces.action.actions)
+ action_max_abs_range = self.spaces.action.filtered_action_space.max_abs_range if \
+ (hasattr(self.spaces.action, 'filtered_action_space') and
+ isinstance(self.spaces.action.filtered_action_space, BoxActionSpace)) \
+ else 1.0
+ keys = np.expand_dims((np.arange(num_actions) / (num_actions - 1) - 0.5) * 2, 1) * action_max_abs_range
+ values = np.expand_dims(np.arange(num_actions), 1)
+ knn_tree = AnnoyDictionary(dict_size=num_actions, key_width=self.ap.algorithm.action_embedding_width)
+ knn_tree.add(keys, values, force_rebuild_tree=True)
+
+ return knn_tree
+
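To make choose_action above easier to follow, here is a self-contained sketch of the Wolpertinger selection step it implements: a proto-action from the actor, the k nearest discrete actions from the kNN dictionary, and a critic argmax over the candidates. The knn_query and critic_q arguments are stand-in callables, not the classes from this patch:

    import numpy as np

    def wolpertinger_select(proto_action, knn_query, critic_q, k):
        # knn_query returns the embeddings and discrete indices of the k stored actions nearest to proto_action
        nn_embeddings, nn_indices = knn_query(proto_action, k)
        # critic_q scores a candidate action embedding for the current (fixed) state;
        # keep the discrete index of the highest scoring candidate
        q_values = np.array([critic_q(embedding) for embedding in nn_embeddings])
        return int(nn_indices[np.argmax(q_values)])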
diff --git a/rl_coach/architectures/head_parameters.py b/rl_coach/architectures/head_parameters.py
index ee607dd..207ea3e 100644
--- a/rl_coach/architectures/head_parameters.py
+++ b/rl_coach/architectures/head_parameters.py
@@ -108,6 +108,17 @@ class DDPGActorHeadParameters(HeadParameters):
self.batchnorm = batchnorm
+class WolpertingerActorHeadParameters(HeadParameters):
+ def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', batchnorm: bool=True,
+ num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
+ loss_weight: float = 1.0, dense_layer=None):
+ super().__init__(parameterized_class_name="WolpertingerActorHead", activation_function=activation_function, name=name,
+ dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
+ rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
+ loss_weight=loss_weight)
+ self.batchnorm = batchnorm
+
+
class DNDQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params',
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
diff --git a/rl_coach/architectures/tensorflow_components/heads/__init__.py b/rl_coach/architectures/tensorflow_components/heads/__init__.py
index 03c237a..0a83399 100644
--- a/rl_coach/architectures/tensorflow_components/heads/__init__.py
+++ b/rl_coach/architectures/tensorflow_components/heads/__init__.py
@@ -18,6 +18,7 @@ from .classification_head import ClassificationHead
from .cil_head import RegressionHead
from .td3_v_head import TD3VHead
from .ddpg_v_head import DDPGVHead
+from .wolpertinger_actor_head import WolpertingerActorHead
__all__ = [
'CategoricalQHead',
@@ -38,6 +39,7 @@ __all__ = [
'SACQHead',
'ClassificationHead',
'RegressionHead',
- 'TD3VHead'
- 'DDPGVHead'
+ 'TD3VHead',
+ 'DDPGVHead',
+ 'WolpertingerActorHead'
]
diff --git a/rl_coach/architectures/tensorflow_components/heads/wolpertinger_actor_head.py b/rl_coach/architectures/tensorflow_components/heads/wolpertinger_actor_head.py
new file mode 100644
index 0000000..3521a95
--- /dev/null
+++ b/rl_coach/architectures/tensorflow_components/heads/wolpertinger_actor_head.py
@@ -0,0 +1,59 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import tensorflow as tf
+
+from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
+from rl_coach.architectures.tensorflow_components.heads.head import Head
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.core_types import Embedding
+from rl_coach.spaces import SpacesDefinition, BoxActionSpace
+
+
+class WolpertingerActorHead(Head):
+ def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
+ head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
+ batchnorm: bool=True, dense_layer=Dense, is_training=False):
+ super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+ dense_layer=dense_layer, is_training=is_training)
+ self.name = 'wolpertinger_actor_head'
+ self.return_type = Embedding
+ self.action_embedding_width = agent_parameters.algorithm.action_embedding_width
+ self.batchnorm = batchnorm
+ self.output_scale = self.spaces.action.filtered_action_space.max_abs_range if \
+ (hasattr(self.spaces.action, 'filtered_action_space') and
+ isinstance(self.spaces.action.filtered_action_space, BoxActionSpace)) \
+ else None
+
+ def _build_module(self, input_layer):
+        # proto-action embedding
+ pre_activation_policy_value = self.dense_layer(self.action_embedding_width)(input_layer,
+ name='actor_action_embedding')
+ self.proto_action = batchnorm_activation_dropout(input_layer=pre_activation_policy_value,
+ batchnorm=self.batchnorm,
+ activation_function=self.activation_function,
+ dropout_rate=0,
+ is_training=self.is_training,
+ name="BatchnormActivationDropout_0")[-1]
+ if self.output_scale is not None:
+ self.proto_action = tf.multiply(self.proto_action, self.output_scale, name='proto_action')
+
+ self.output = [self.proto_action]
+
+ def __str__(self):
+ result = [
+ 'Dense (num outputs = {})'.format(self.action_embedding_width)
+ ]
+ return '\n'.join(result)
diff --git a/rl_coach/exploration_policies/additive_noise.py b/rl_coach/exploration_policies/additive_noise.py
index 8194718..8b67c7d 100644
--- a/rl_coach/exploration_policies/additive_noise.py
+++ b/rl_coach/exploration_policies/additive_noise.py
@@ -62,7 +62,9 @@ class AdditiveNoise(ContinuousActionExplorationPolicy):
self.evaluation_noise = evaluation_noise
self.noise_as_percentage_from_action_space = noise_as_percentage_from_action_space
-        if not isinstance(action_space, BoxActionSpace):
+        if not isinstance(action_space, BoxActionSpace) and \
+                not (hasattr(action_space, 'filtered_action_space') and
+                     isinstance(action_space.filtered_action_space, BoxActionSpace)):
raise ValueError("Additive noise exploration works only for continuous controls."
"The given action space is of type: {}".format(action_space.__class__.__name__))
diff --git a/rl_coach/exploration_policies/exploration_policy.py b/rl_coach/exploration_policies/exploration_policy.py
index a345895..688fcce 100644
--- a/rl_coach/exploration_policies/exploration_policy.py
+++ b/rl_coach/exploration_policies/exploration_policy.py
@@ -115,5 +115,8 @@ class ContinuousActionExplorationPolicy(ExplorationPolicy):
"""
:param action_space: the action space used by the environment
"""
- assert isinstance(action_space, BoxActionSpace) or isinstance(action_space, GoalsSpace)
+ assert isinstance(action_space, BoxActionSpace) or \
+ (hasattr(action_space, 'filtered_action_space') and
+ isinstance(action_space.filtered_action_space, BoxActionSpace)) or \
+ isinstance(action_space, GoalsSpace)
super().__init__(action_space)
diff --git a/rl_coach/filters/action/partial_discrete_action_space_map.py b/rl_coach/filters/action/partial_discrete_action_space_map.py
index 2322698..ad6e105 100644
--- a/rl_coach/filters/action/partial_discrete_action_space_map.py
+++ b/rl_coach/filters/action/partial_discrete_action_space_map.py
@@ -48,7 +48,8 @@ class PartialDiscreteActionSpaceMap(ActionFilter):
def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> DiscreteActionSpace:
self.output_action_space = output_action_space
- self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions)
+ self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions,
+ filtered_action_space=output_action_space)
return self.input_action_space
def filter(self, action: ActionType) -> ActionType:
diff --git a/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py b/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py
index 3368ee8..8633118 100644
--- a/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py
+++ b/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py
@@ -57,7 +57,7 @@ class AnnoyDictionary(object):
self.built_capacity = 0
- def add(self, keys, values, additional_data=None):
+ def add(self, keys, values, additional_data=None, force_rebuild_tree=False):
if not additional_data:
additional_data = [None] * len(keys)
@@ -96,7 +96,7 @@ class AnnoyDictionary(object):
if len(self.buffered_indices) >= self.min_update_size:
self.min_update_size = max(self.initial_update_size, int(self.curr_size * 0.02))
self._rebuild_index()
- elif self.rebuild_on_every_update:
+ elif force_rebuild_tree or self.rebuild_on_every_update:
self._rebuild_index()
self.current_timestamp += 1
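The force_rebuild_tree flag matters because keys added to the AnnoyDictionary only become queryable after the Annoy index is rebuilt, and the Wolpertinger agent inserts all of its action embeddings once, up front, and queries them immediately afterwards. A small usage sketch based on the signatures in this patch (the concrete sizes are only illustrative):

    import numpy as np
    from rl_coach.memories.non_episodic.differentiable_neural_dictionary import AnnoyDictionary

    num_actions = 11
    keys = np.expand_dims(np.linspace(-1.0, 1.0, num_actions), 1)    # 1-D action embeddings in [-1, 1]
    values = np.expand_dims(np.arange(num_actions), 1)               # the matching discrete action indices
    knn = AnnoyDictionary(dict_size=num_actions, key_width=1)
    knn.add(keys, values, force_rebuild_tree=True)                   # rebuild now so queries work right away
    embeddings, indices, _, _ = knn.query(keys=keys[:1], k=3)        # the 3 actions nearest to the first key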
diff --git a/rl_coach/presets/Mujoco_Wolpertinger.py b/rl_coach/presets/Mujoco_Wolpertinger.py
new file mode 100644
index 0000000..f12e41c
--- /dev/null
+++ b/rl_coach/presets/Mujoco_Wolpertinger.py
@@ -0,0 +1,57 @@
+from collections import OrderedDict
+
+from rl_coach.architectures.layers import Dense
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, EmbedderScheme
+from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
+from rl_coach.environments.environment import SingleLevelSelection
+from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
+from rl_coach.filters.action import BoxDiscretization
+from rl_coach.filters.filter import OutputFilter
+from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
+from rl_coach.graph_managers.graph_manager import ScheduleParameters
+from rl_coach.agents.wolpertinger_agent import WolpertingerAgentParameters
+
+####################
+# Graph Scheduling #
+####################
+schedule_params = ScheduleParameters()
+schedule_params.improve_steps = EnvironmentSteps(2000000)
+schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(20)
+schedule_params.evaluation_steps = EnvironmentEpisodes(1)
+schedule_params.heatup_steps = EnvironmentSteps(3000)
+
+#########
+# Agent #
+#########
+agent_params = WolpertingerAgentParameters()
+agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense(400)]
+agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(300)]
+agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense(400)]
+agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(300)]
+agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty
+agent_params.output_filter = \
+ OutputFilter(
+ action_filters=OrderedDict([
+ ('discretization', BoxDiscretization(num_bins_per_dimension=int(1e6)))
+ ]),
+ is_a_reference_filter=False
+ )
+
+###############
+# Environment #
+###############
+env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
+
+########
+# Test #
+########
+preset_validation_params = PresetValidationParameters()
+preset_validation_params.test = True
+preset_validation_params.min_reward_threshold = 500
+preset_validation_params.max_episodes_to_achieve_reward = 1000
+preset_validation_params.reward_test_level = 'inverted_pendulum'
+preset_validation_params.trace_test_levels = ['inverted_pendulum']
+
+graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
+ schedule_params=schedule_params, vis_params=VisualizationParameters(),
+ preset_validation_params=preset_validation_params)
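The preset should be runnable like the other Mujoco presets, either through the coach command line (e.g. coach -p Mujoco_Wolpertinger -lvl inverted_pendulum) or programmatically. A hedged programmatic sketch; the exact TaskParameters fields and CLI flags may vary between Coach versions:

    from rl_coach.base_parameters import TaskParameters
    from rl_coach.presets.Mujoco_Wolpertinger import graph_manager

    graph_manager.env_params.level.select('inverted_pendulum')   # pick a concrete level from mujoco_v2
    graph_manager.create_graph(TaskParameters(experiment_path='./experiments/wolpertinger'))
    graph_manager.improve()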
diff --git a/rl_coach/spaces.py b/rl_coach/spaces.py
index 503598c..5dcaa2b 100644
--- a/rl_coach/spaces.py
+++ b/rl_coach/spaces.py
@@ -385,7 +385,8 @@ class DiscreteActionSpace(ActionSpace):
"""
A discrete action space with action indices as actions
"""
- def __init__(self, num_actions: int, descriptions: Union[None, List, Dict]=None, default_action: np.ndarray=None):
+ def __init__(self, num_actions: int, descriptions: Union[None, List, Dict]=None, default_action: np.ndarray=None,
+ filtered_action_space=None):
super().__init__(1, low=0, high=num_actions-1, descriptions=descriptions)
# the number of actions is mapped to high
@@ -395,6 +396,9 @@ class DiscreteActionSpace(ActionSpace):
else:
self.default_action = default_action
+ if filtered_action_space is not None:
+ self.filtered_action_space = filtered_action_space
+
@property
def actions(self) -> List[ActionType]:
return list(range(0, int(self.high[0]) + 1))
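Taken together, the changes to spaces.py, partial_discrete_action_space_map.py, additive_noise.py and exploration_policy.py let a discretized continuous action space advertise the BoxActionSpace it was derived from, so that continuous exploration policies can accept it. A short sketch of the resulting check (the constructor arguments are illustrative):

    from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace

    # a discretized continuous action dimension: the agent sees 11 bins,
    # but the space remembers the BoxActionSpace it came from
    box = BoxActionSpace(shape=1, low=-1.0, high=1.0)
    discrete = DiscreteActionSpace(num_actions=11, filtered_action_space=box)

    # this mirrors the test now used by ContinuousActionExplorationPolicy and AdditiveNoise
    continuous_under_the_hood = isinstance(discrete, BoxActionSpace) or \
        (hasattr(discrete, 'filtered_action_space') and
         isinstance(discrete.filtered_action_space, BoxActionSpace))
    assert continuous_under_the_hood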