Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 11:10:20 +01:00)
Create a dataset using an agent (#306)
Generate a dataset using an agent (allowing selection between an agent-generated and a random dataset)
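For orientation, here is a condensed sketch of how the new option is wired in a preset. The parameter values are illustrative only; the full CartPole preset appears at the end of this diff, and omitting the two experience_generating_* arguments keeps the previous random-dataset behaviour.

from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps
from rl_coach.environments.gym_environment import GymVectorEnvironment
from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters

# the agent that will be trained from the dataset, and the agent that will collect it
agent_params = DQNAgentParameters()
experience_generating_agent_params = DQNAgentParameters()

# schedule for training the final agent from the (frozen) dataset
schedule_params = ScheduleParameters()
schedule_params.heatup_steps = EnvironmentSteps(0)
schedule_params.improve_steps = TrainingSteps(10000)
schedule_params.steps_between_evaluation_periods = TrainingSteps(100)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)

# schedule for the dataset-collection phase
experience_generating_schedule_params = ScheduleParameters()
experience_generating_schedule_params.heatup_steps = EnvironmentSteps(1000)
experience_generating_schedule_params.improve_steps = TrainingSteps(4000)
experience_generating_schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
experience_generating_schedule_params.evaluation_steps = EnvironmentEpisodes(1)

# passing experience_generating_agent_params switches the dataset source from random heatup
# to experience collected while training that agent against the environment
graph_manager = BatchRLGraphManager(agent_params=agent_params,
                                    experience_generating_agent_params=experience_generating_agent_params,
                                    experience_generating_schedule_params=experience_generating_schedule_params,
                                    env_params=GymVectorEnvironment(level='CartPole-v0'),
                                    schedule_params=schedule_params,
                                    vis_params=VisualizationParameters(),
                                    reward_model_num_epochs=30,
                                    train_to_eval_ratio=0.4)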
@@ -685,7 +685,10 @@ class Agent(AgentInterface):
"""
loss = 0
if self._should_train():
self.training_epoch += 1
if self.ap.is_batch_rl_training:
# when training an agent for generating a dataset in batch-rl, we don't want it to be counted as part of
# the training epochs. we only care for training epochs in batch-rl anyway.
self.training_epoch += 1
for network in self.networks.values():
network.set_is_training(True)

@@ -1047,3 +1050,11 @@ class Agent(AgentInterface):
TimeTypes.EnvironmentSteps: self.total_steps_counter,
TimeTypes.WallClockTime: self.agent_logger.get_current_wall_clock_time(),
TimeTypes.Epoch: self.training_epoch}[self.parent_level_manager.parent_graph_manager.time_metric]

def freeze_memory(self):
"""
Shuffle episodes in the memory and freeze it to make sure that no extra data is being pushed anymore.
:return: None
"""
self.call_memory('shuffle_episodes')
self.call_memory('freeze')

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Union

import numpy as np
@@ -83,13 +82,22 @@ class CategoricalDQNAgent(ValueOptimizationAgent):

# prediction's format is (batch,actions,atoms)
def get_all_q_values_for_states(self, states: StateType):
q_values = None
if self.exploration_policy.requires_action_values():
q_values = self.get_prediction(states,
outputs=[self.networks['main'].online_network.output_heads[0].q_values])
else:
q_values = None

return q_values

def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType):
actions_q_values, softmax_probabilities = None, None
if self.exploration_policy.requires_action_values():
outputs = [self.networks['main'].online_network.output_heads[0].q_values,
self.networks['main'].online_network.output_heads[0].softmax]
actions_q_values, softmax_probabilities = self.get_prediction(states, outputs=outputs)

return actions_q_values, softmax_probabilities

def learn_from_batch(self, batch):
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

@@ -182,7 +182,7 @@ class DFPAgent(Agent):
action_values = None

# choose action according to the exploration policy and the current phase (evaluating or training the agent)
action = self.exploration_policy.get_action(action_values)
action, _ = self.exploration_policy.get_action(action_values)

if action_values is not None:
action_values = action_values.squeeze()

@@ -49,6 +49,7 @@ class DQNNetworkParameters(NetworkParameters):
self.batch_size = 32
self.replace_mse_with_huber_loss = True
self.create_target_network = True
self.should_get_softmax_probabilities = False


class DQNAgentParameters(AgentParameters):

@@ -16,7 +16,7 @@

import os
import pickle
from typing import Union
from typing import Union, List

import numpy as np

@@ -40,6 +40,7 @@ class NECNetworkParameters(NetworkParameters):
self.middleware_parameters = FCMiddlewareParameters()
self.heads_parameters = [DNDQHeadParameters()]
self.optimizer_type = 'Adam'
self.should_get_softmax_probabilities = False


class NECAlgorithmParameters(AlgorithmParameters):
@@ -166,11 +167,25 @@ class NECAgent(ValueOptimizationAgent):

return super().act()

def get_all_q_values_for_states(self, states: StateType):
def get_all_q_values_for_states(self, states: StateType, additional_outputs: List = None):
# we need to store the state embeddings regardless if the action is random or not
return self.get_prediction(states)
return self.get_prediction_and_update_embeddings(states)

def get_prediction(self, states):
def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType):
# get the actions q values and the state embedding
embedding, actions_q_values, softmax_probabilities = self.networks['main'].online_network.predict(
self.prepare_batch_for_inference(states, 'main'),
outputs=[self.networks['main'].online_network.state_embedding,
self.networks['main'].online_network.output_heads[0].output,
self.networks['main'].online_network.output_heads[0].softmax]
)
if self.phase != RunPhase.TEST:
# store the state embedding for inserting it to the DND later
self.current_episode_state_embeddings.append(embedding.squeeze())
actions_q_values = actions_q_values[0][0]
return actions_q_values, softmax_probabilities

def get_prediction_and_update_embeddings(self, states):
# get the actions q values and the state embedding
embedding, actions_q_values = self.networks['main'].online_network.predict(
self.prepare_batch_for_inference(states, 'main'),

@@ -147,7 +147,7 @@ class PolicyOptimizationAgent(Agent):
if isinstance(self.spaces.action, DiscreteActionSpace):
# DISCRETE
action_probabilities = np.array(action_values).squeeze()
action = self.exploration_policy.get_action(action_probabilities)
action, _ = self.exploration_policy.get_action(action_probabilities)
action_info = ActionInfo(action=action,
all_action_probabilities=action_probabilities)

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from copy import copy
from typing import Union

import numpy as np
@@ -79,6 +79,17 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
actions_q_values = None
return actions_q_values

# prediction's format is (batch,actions,atoms)
def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType):
actions_q_values, softmax_probabilities = None, None
if self.exploration_policy.requires_action_values():
outputs = copy(self.networks['main'].online_network.outputs)
outputs.append(self.networks['main'].online_network.output_heads[0].softmax)
quantile_values, softmax_probabilities = self.get_prediction(states, outputs)
actions_q_values = self.get_q_values(quantile_values)

return actions_q_values, softmax_probabilities

def learn_from_batch(self, batch):
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

@@ -14,7 +14,7 @@
# limitations under the License.
#
from collections import OrderedDict
from typing import Union
from typing import Union, List

import numpy as np

@@ -24,7 +24,8 @@ from rl_coach.filters.filter import NoInputFilter
from rl_coach.logger import screen
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
from rl_coach.spaces import DiscreteActionSpace
from copy import deepcopy
from copy import deepcopy, copy


## This is an abstract agent - there is no learn_from_batch method ##

@@ -35,6 +36,12 @@ class ValueOptimizationAgent(Agent):
self.q_values = self.register_signal("Q")
self.q_value_for_action = {}

# currently we use softmax action probabilities only in batch-rl,
# but we might want to extend this later at some point.
self.should_get_softmax_probabilities = \
hasattr(self.ap.network_wrappers['main'], 'should_get_softmax_probabilities') and \
self.ap.network_wrappers['main'].should_get_softmax_probabilities

def init_environment_dependent_modules(self):
super().init_environment_dependent_modules()
if isinstance(self.spaces.action, DiscreteActionSpace):
@@ -45,12 +52,21 @@ class ValueOptimizationAgent(Agent):

# Algorithms for which q_values are calculated from predictions will override this function
def get_all_q_values_for_states(self, states: StateType):
actions_q_values = None
if self.exploration_policy.requires_action_values():
actions_q_values = self.get_prediction(states)
else:
actions_q_values = None

return actions_q_values

def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType):
actions_q_values, softmax_probabilities = None, None
if self.exploration_policy.requires_action_values():
outputs = copy(self.networks['main'].online_network.outputs)
outputs.append(self.networks['main'].online_network.output_heads[0].softmax)

actions_q_values, softmax_probabilities = self.get_prediction(states, outputs=outputs)
return actions_q_values, softmax_probabilities

def get_prediction(self, states, outputs=None):
return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'),
outputs=outputs)
@@ -72,10 +88,19 @@ class ValueOptimizationAgent(Agent):
).format(policy.__class__.__name__))

def choose_action(self, curr_state):
actions_q_values = self.get_all_q_values_for_states(curr_state)
if self.should_get_softmax_probabilities:
actions_q_values, softmax_probabilities = \
self.get_all_q_values_for_states_and_softmax_probabilities(curr_state)
else:
actions_q_values = self.get_all_q_values_for_states(curr_state)

# choose action according to the exploration policy and the current phase (evaluating or training the agent)
action = self.exploration_policy.get_action(actions_q_values)
action, action_probabilities = self.exploration_policy.get_action(actions_q_values)
if self.should_get_softmax_probabilities and softmax_probabilities is not None:
# override the exploration policy's generated probabilities when an action was taken
# with the agent's actual policy
action_probabilities = softmax_probabilities

self._validate_action(self.exploration_policy, action)

if actions_q_values is not None:
@@ -87,15 +112,18 @@ class ValueOptimizationAgent(Agent):
self.q_values.add_sample(actions_q_values)

actions_q_values = actions_q_values.squeeze()
action_probabilities = action_probabilities.squeeze()

for i, q_value in enumerate(actions_q_values):
self.q_value_for_action[i].add_sample(q_value)

action_info = ActionInfo(action=action,
action_value=actions_q_values[action],
max_action_value=np.max(actions_q_values))
max_action_value=np.max(actions_q_values),
all_action_probabilities=action_probabilities)

else:
action_info = ActionInfo(action=action)
action_info = ActionInfo(action=action, all_action_probabilities=action_probabilities)

return action_info

@@ -17,14 +17,17 @@
from typing import List

import numpy as np
import scipy.stats

from rl_coach.core_types import RunPhase, ActionType
from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters
from rl_coach.exploration_policies.exploration_policy import ContinuousActionExplorationPolicy, ExplorationParameters
from rl_coach.schedules import Schedule, LinearSchedule
from rl_coach.spaces import ActionSpace, BoxActionSpace


# TODO: consider renaming to gaussian sampling


class AdditiveNoiseParameters(ExplorationParameters):
def __init__(self):
super().__init__()
@@ -36,7 +39,7 @@ class AdditiveNoiseParameters(ExplorationParameters):
return 'rl_coach.exploration_policies.additive_noise:AdditiveNoise'


class AdditiveNoise(ExplorationPolicy):
class AdditiveNoise(ContinuousActionExplorationPolicy):
"""
AdditiveNoise is an exploration policy intended for continuous action spaces. It takes the action from the agent
and adds a Gaussian distributed noise to it. The amount of noise added to the action follows the noise amount that

@@ -19,7 +19,7 @@ from typing import List
import numpy as np

from rl_coach.core_types import RunPhase, ActionType
from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters
from rl_coach.exploration_policies.exploration_policy import DiscreteActionExplorationPolicy, ExplorationParameters
from rl_coach.schedules import Schedule
from rl_coach.spaces import ActionSpace

@@ -34,8 +34,7 @@ class BoltzmannParameters(ExplorationParameters):
return 'rl_coach.exploration_policies.boltzmann:Boltzmann'



class Boltzmann(ExplorationPolicy):
class Boltzmann(DiscreteActionExplorationPolicy):
"""
The Boltzmann exploration policy is intended for discrete action spaces. It assumes that each of the possible
actions has some value assigned to it (such as the Q value), and uses a softmax function to convert these values
@@ -50,7 +49,7 @@ class Boltzmann(ExplorationPolicy):
super().__init__(action_space)
self.temperature_schedule = temperature_schedule

def get_action(self, action_values: List[ActionType]) -> ActionType:
def get_action(self, action_values: List[ActionType]) -> (ActionType, List[float]):
if self.phase == RunPhase.TRAIN:
self.temperature_schedule.step()
# softmax calculation
@@ -59,7 +58,8 @@ class Boltzmann(ExplorationPolicy):
# make sure probs sum to 1
probabilities[-1] = 1 - np.sum(probabilities[:-1])
# choose actions according to the probabilities
return np.random.choice(range(self.action_space.shape), p=probabilities)
action = np.random.choice(range(self.action_space.shape), p=probabilities)
return action, probabilities

def get_control_param(self):
return self.temperature_schedule.current_value
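For reference, a standalone sketch of the softmax sampling described in the docstring above and mirrored by the new (action, probabilities) return value (illustrative only, not the class's exact code):

import numpy as np

def boltzmann_probabilities(action_values, temperature):
    # softmax over the action values; lower temperature means greedier behaviour
    shifted = (action_values - np.max(action_values)) / temperature  # shift for numerical stability
    exp_values = np.exp(shifted)
    return exp_values / np.sum(exp_values)

probabilities = boltzmann_probabilities(np.array([1.0, 2.0, 3.0]), temperature=0.5)
action = np.random.choice(len(probabilities), p=probabilities)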
@@ -19,7 +19,7 @@ from typing import List
import numpy as np

from rl_coach.core_types import RunPhase, ActionType
from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters
from rl_coach.exploration_policies.exploration_policy import DiscreteActionExplorationPolicy, ExplorationParameters
from rl_coach.spaces import ActionSpace


@@ -29,7 +29,7 @@ class CategoricalParameters(ExplorationParameters):
return 'rl_coach.exploration_policies.categorical:Categorical'


class Categorical(ExplorationPolicy):
class Categorical(DiscreteActionExplorationPolicy):
"""
Categorical exploration policy is intended for discrete action spaces. It expects the action values to
represent a probability distribution over the action, from which a single action will be sampled.
@@ -42,13 +42,18 @@ class Categorical(ExplorationPolicy):
"""
super().__init__(action_space)

def get_action(self, action_values: List[ActionType]) -> ActionType:
def get_action(self, action_values: List[ActionType]) -> (ActionType, List[float]):
if self.phase == RunPhase.TRAIN:
# choose actions according to the probabilities
return np.random.choice(self.action_space.actions, p=action_values)
action = np.random.choice(self.action_space.actions, p=action_values)
return action, action_values
else:
# take the action with the highest probability
return np.argmax(action_values)
action = np.argmax(action_values)
one_hot_action_probabilities = np.zeros(len(self.action_space.actions))
one_hot_action_probabilities[action] = 1

return action, one_hot_action_probabilities

def get_control_param(self):
return 0
@@ -20,8 +20,7 @@ import numpy as np

from rl_coach.core_types import RunPhase, ActionType
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.exploration_policies.exploration_policy import ExplorationParameters
from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy
from rl_coach.exploration_policies.exploration_policy import ExplorationParameters, ExplorationPolicy
from rl_coach.schedules import Schedule, LinearSchedule
from rl_coach.spaces import ActionSpace, DiscreteActionSpace, BoxActionSpace
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
@@ -82,26 +81,32 @@ class EGreedy(ExplorationPolicy):
epsilon = self.evaluation_epsilon if self.phase == RunPhase.TEST else self.epsilon_schedule.current_value
return self.current_random_value >= epsilon

def get_action(self, action_values: List[ActionType]) -> ActionType:
def get_action(self, action_values: List[ActionType]) -> (ActionType, List[float]):
epsilon = self.evaluation_epsilon if self.phase == RunPhase.TEST else self.epsilon_schedule.current_value

if isinstance(self.action_space, DiscreteActionSpace):
top_action = np.argmax(action_values)
if self.current_random_value < epsilon:
chosen_action = self.action_space.sample()
probabilities = np.full(len(self.action_space.actions),
1. / (self.action_space.high[0] - self.action_space.low[0] + 1))
else:
chosen_action = top_action
chosen_action = np.argmax(action_values)

# one-hot probabilities vector
probabilities = np.zeros(len(self.action_space.actions))
probabilities[chosen_action] = 1

self.step_epsilon()
return chosen_action, probabilities

else:
if self.current_random_value < epsilon and self.phase == RunPhase.TRAIN:
chosen_action = self.action_space.sample()
else:
chosen_action = self.continuous_exploration_policy.get_action(action_values)

# step the epsilon schedule and generate a new random value for next time
if self.phase == RunPhase.TRAIN:
self.epsilon_schedule.step()
self.current_random_value = np.random.rand()
return chosen_action
self.step_epsilon()
return chosen_action

def get_control_param(self):
if isinstance(self.action_space, DiscreteActionSpace):
@@ -113,3 +118,9 @@ class EGreedy(ExplorationPolicy):
super().change_phase(phase)
if isinstance(self.action_space, BoxActionSpace):
self.continuous_exploration_policy.change_phase(phase)

def step_epsilon(self):
# step the epsilon schedule and generate a new random value for next time
if self.phase == RunPhase.TRAIN:
self.epsilon_schedule.step()
self.current_random_value = np.random.rand()
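Since the probabilities returned by get_action presumably feed the off-policy evaluation added for batch RL as the behaviour policy, it may help to recall the full distribution that ε-greedy induces over a discrete action set. A small reference helper, independent of the class above:

import numpy as np

def egreedy_distribution(action_values, epsilon):
    # uniform mass of epsilon / N on every action, plus the remaining (1 - epsilon) on the greedy one
    num_actions = len(action_values)
    probabilities = np.full(num_actions, epsilon / num_actions)
    probabilities[np.argmax(action_values)] += 1.0 - epsilon
    return probabilities

print(egreedy_distribution(np.array([10, 20, 30]), epsilon=0.1))  # [0.0333..., 0.0333..., 0.9333...]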
@@ -18,7 +18,7 @@ from typing import List

from rl_coach.base_parameters import Parameters
from rl_coach.core_types import RunPhase, ActionType
from rl_coach.spaces import ActionSpace
from rl_coach.spaces import ActionSpace, DiscreteActionSpace, BoxActionSpace, GoalsSpace


class ExplorationParameters(Parameters):
@@ -54,14 +54,10 @@ class ExplorationPolicy(object):
Given a list of values corresponding to each action,
choose one action according to the exploration policy
:param action_values: A list of action values
:return: The chosen action
:return: The chosen action,
The probability of the action (if available, otherwise 1 for absolute certainty in the action)
"""
if self.__class__ == ExplorationPolicy:
raise ValueError("The ExplorationPolicy class is an abstract class and should not be used directly. "
"Please set the exploration parameters to point to an inheriting class like EGreedy or "
"AdditiveNoise")
else:
raise ValueError("The get_action function should be overridden in the inheriting exploration class")
raise NotImplementedError()

def change_phase(self, phase):
"""
@@ -82,3 +78,42 @@ class ExplorationPolicy(object):

def get_control_param(self):
return 0


class DiscreteActionExplorationPolicy(ExplorationPolicy):
"""
A discrete action exploration policy.
"""
def __init__(self, action_space: ActionSpace):
"""
:param action_space: the action space used by the environment
"""
assert isinstance(action_space, DiscreteActionSpace)
super().__init__(action_space)

def get_action(self, action_values: List[ActionType]) -> (ActionType, List):
"""
Given a list of values corresponding to each action,
choose one action according to the exploration policy
:param action_values: A list of action values
:return: The chosen action,
The probabilities of actions to select from (if not available a one-hot vector)
"""
if self.__class__ == ExplorationPolicy:
raise ValueError("The ExplorationPolicy class is an abstract class and should not be used directly. "
"Please set the exploration parameters to point to an inheriting class like EGreedy or "
"AdditiveNoise")
else:
raise ValueError("The get_action function should be overridden in the inheriting exploration class")


class ContinuousActionExplorationPolicy(ExplorationPolicy):
"""
A continuous action exploration policy.
"""
def __init__(self, action_space: ActionSpace):
"""
:param action_space: the action space used by the environment
"""
assert isinstance(action_space, BoxActionSpace) or isinstance(action_space, GoalsSpace)
super().__init__(action_space)

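A minimal sketch of a policy written against the new base class and its (action, probabilities) contract (a hypothetical UniformRandom policy, not part of this commit):

from typing import List

import numpy as np

from rl_coach.core_types import ActionType
from rl_coach.exploration_policies.exploration_policy import DiscreteActionExplorationPolicy
from rl_coach.spaces import ActionSpace


class UniformRandom(DiscreteActionExplorationPolicy):
    """Hypothetical policy: ignores the action values and samples uniformly."""
    def __init__(self, action_space: ActionSpace):
        super().__init__(action_space)

    def get_action(self, action_values: List[ActionType]) -> (ActionType, List):
        num_actions = len(self.action_space.actions)
        probabilities = np.full(num_actions, 1.0 / num_actions)
        action = np.random.choice(num_actions, p=probabilities)
        # every policy now returns the chosen action together with the action distribution
        return action, probabilities

    def get_control_param(self):
        return 0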
@@ -19,7 +19,7 @@ from typing import List
import numpy as np

from rl_coach.core_types import ActionType
from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters
from rl_coach.exploration_policies.exploration_policy import ExplorationParameters, ExplorationPolicy
from rl_coach.spaces import ActionSpace, DiscreteActionSpace, BoxActionSpace


@@ -41,9 +41,12 @@ class Greedy(ExplorationPolicy):
"""
super().__init__(action_space)

def get_action(self, action_values: List[ActionType]) -> ActionType:
def get_action(self, action_values: List[ActionType]):
if type(self.action_space) == DiscreteActionSpace:
return np.argmax(action_values)
action = np.argmax(action_values)
one_hot_action_probabilities = np.zeros(len(self.action_space.actions))
one_hot_action_probabilities[action] = 1
return action, one_hot_action_probabilities
if type(self.action_space) == BoxActionSpace:
return action_values

@@ -19,12 +19,13 @@ from typing import List
import numpy as np

from rl_coach.core_types import RunPhase, ActionType
from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters
from rl_coach.exploration_policies.exploration_policy import ContinuousActionExplorationPolicy, ExplorationParameters
from rl_coach.spaces import ActionSpace, BoxActionSpace, GoalsSpace


# Based on the description in:
# https://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab

class OUProcessParameters(ExplorationParameters):
def __init__(self):
super().__init__()
@@ -39,7 +40,7 @@ class OUProcessParameters(ExplorationParameters):


# Ornstein-Uhlenbeck process
class OUProcess(ExplorationPolicy):
class OUProcess(ContinuousActionExplorationPolicy):
"""
OUProcess exploration policy is intended for continuous action spaces, and selects the action according to
an Ornstein-Uhlenbeck process. The Ornstein-Uhlenbeck process implements the action as a Gaussian process, where
@@ -56,10 +57,6 @@ class OUProcess(ExplorationPolicy):
self.state = np.zeros(self.action_space.shape)
self.dt = dt

if not (isinstance(action_space, BoxActionSpace) or isinstance(action_space, GoalsSpace)):
raise ValueError("OU process exploration works only for continuous controls."
"The given action space is of type: {}".format(action_space.__class__.__name__))

def reset(self):
self.state = np.zeros(self.action_space.shape)
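For reference, the discretized Ornstein-Uhlenbeck update that the docstring and the linked description refer to, as a standalone sketch (parameter names follow the class; the class's internals may differ slightly):

import numpy as np

def ou_step(state, mu=0.0, theta=0.15, sigma=0.2, dt=0.01):
    # x_{t+1} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * N(0, 1)
    noise = np.random.normal(size=np.shape(state))
    return state + theta * (mu - state) * dt + sigma * np.sqrt(dt) * noise

state = np.zeros(2)
for _ in range(3):
    state = ou_step(state)  # temporally correlated noise to add to the agent's action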
@@ -59,9 +59,13 @@ class ParameterNoise(ExplorationPolicy):
self.network_params = network_params
self._replace_network_dense_layers()

def get_action(self, action_values: List[ActionType]) -> ActionType:
def get_action(self, action_values: List[ActionType]):
if type(self.action_space) == DiscreteActionSpace:
return np.argmax(action_values)
action = np.argmax(action_values)
one_hot_action_probabilities = np.zeros(len(self.action_space.actions))
one_hot_action_probabilities[action] = 1

return action, one_hot_action_probabilities
elif type(self.action_space) == BoxActionSpace:
action_values_mean = action_values[0].squeeze()
action_values_std = action_values[1].squeeze()

@@ -20,7 +20,7 @@ import numpy as np
from scipy.stats import truncnorm

from rl_coach.core_types import RunPhase, ActionType
from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters
from rl_coach.exploration_policies.exploration_policy import ExplorationParameters, ContinuousActionExplorationPolicy
from rl_coach.schedules import Schedule, LinearSchedule
from rl_coach.spaces import ActionSpace, BoxActionSpace

@@ -38,7 +38,7 @@ class TruncatedNormalParameters(ExplorationParameters):
return 'rl_coach.exploration_policies.truncated_normal:TruncatedNormal'


class TruncatedNormal(ExplorationPolicy):
class TruncatedNormal(ContinuousActionExplorationPolicy):
"""
The TruncatedNormal exploration policy is intended for continuous action spaces. It samples the action from a
normal distribution, where the mean action is given by the agent, and the standard deviation can be given in t

@@ -18,15 +18,17 @@ from typing import Tuple, List, Union

from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.agents.nec_agent import NECAgentParameters
from rl_coach.architectures.network_wrapper import NetworkWrapper
from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \
PresetValidationParameters
from rl_coach.core_types import RunPhase
from rl_coach.core_types import RunPhase, TotalStepsCounter, TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.environment import EnvironmentParameters, Environment
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager

from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen
from rl_coach.schedules import LinearSchedule
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import short_dynamic_import

@@ -35,26 +37,62 @@ from rl_coach.memories.episodic import EpisodicExperienceReplayParameters
from rl_coach.core_types import TimeTypes


# TODO build a tutorial for batch RL
class BatchRLGraphManager(BasicRLGraphManager):
"""
A batch RL graph manager creates scenario of learning from a dataset without a simulator.
A batch RL graph manager creates a scenario of learning from a dataset without a simulator.

If an environment is given (useful either for research purposes, or for experimenting with a toy problem before
actually working with a real dataset), we can use it in order to collect a dataset to later be used to train the
actual agent. The collected dataset, in this case, can be collected either by randomly acting in the environment
(only running in heatup), or alternatively by training a different agent in the environment and using its collected
data as a dataset. If experience generating agent parameters are given, we will instantiate this agent and use it
in order to train on the environment and then use this dataset to actually train an agent. Otherwise, we will
collect a random dataset.
:param agent_params: the parameters of the agent to train using batch RL
:param env_params: [optional] environment parameters, for cases where we want to first collect a dataset
:param vis_params: visualization parameters
:param preset_validation_params: preset validation parameters, to be used for testing purposes
:param name: graph name
:param spaces_definition: when working with a dataset, we need to get a description of the actual state and action
spaces of the problem
:param reward_model_num_epochs: the number of epochs to go over the dataset for training a reward model for the
'direct method' and 'doubly robust' OPE methods.
:param train_to_eval_ratio: percentage of the data transitions to be used for training vs. evaluation. i.e. a value
of 0.8 means ~80% of the transitions will be used for training and ~20% will be used for
evaluation using OPE.
:param experience_generating_agent_params: [optional] parameters of an agent to be trained against an environment,
whose collected experience will be used to train the actual (other) agent
:param experience_generating_schedule_params: [optional] graph scheduling parameters for training the experience
generating agent
"""
def __init__(self, agent_params: AgentParameters, env_params: Union[EnvironmentParameters, None],
def __init__(self, agent_params: AgentParameters,
env_params: Union[EnvironmentParameters, None],
schedule_params: ScheduleParameters,
vis_params: VisualizationParameters = VisualizationParameters(),
preset_validation_params: PresetValidationParameters = PresetValidationParameters(),
name='batch_rl_graph', spaces_definition: SpacesDefinition = None, reward_model_num_epochs: int = 100,
train_to_eval_ratio: float = 0.8):
train_to_eval_ratio: float = 0.8, experience_generating_agent_params: AgentParameters = None,
experience_generating_schedule_params: ScheduleParameters = None):

super().__init__(agent_params, env_params, schedule_params, vis_params, preset_validation_params, name)
self.is_batch_rl = True
self.time_metric = TimeTypes.Epoch
self.reward_model_num_epochs = reward_model_num_epochs
self.spaces_definition = spaces_definition
self.is_collecting_random_dataset = experience_generating_agent_params is None

# setting this here to make sure that, by default, train_to_eval_ratio gets a value < 1
# (its default value in the memory is 1)
self.agent_params.memory.train_to_eval_ratio = train_to_eval_ratio
# (its default value in the memory is 1, so as not to affect other non-batch-rl scenarios)
if self.is_collecting_random_dataset:
self.agent_params.memory.train_to_eval_ratio = train_to_eval_ratio
else:
experience_generating_agent_params.memory.train_to_eval_ratio = train_to_eval_ratio
self.experience_generating_agent_params = experience_generating_agent_params
self.experience_generating_agent = None

self.set_schedule_params(experience_generating_schedule_params)
self.schedule_params = schedule_params

def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]:
if self.env_params:
@@ -76,22 +114,41 @@ class BatchRLGraphManager(BasicRLGraphManager):
self.agent_params.task_parameters = task_parameters  # TODO: this should probably be passed in a different way
self.agent_params.name = "agent"
self.agent_params.is_batch_rl_training = True
self.agent_params.network_wrappers['main'].should_get_softmax_probabilities = True

if 'reward_model' not in self.agent_params.network_wrappers:
# user hasn't defined params for the reward model. we will use the same params as used for the 'main'
# network.
self.agent_params.network_wrappers['reward_model'] = deepcopy(self.agent_params.network_wrappers['main'])

agent = short_dynamic_import(self.agent_params.path)(self.agent_params)
self.agent = short_dynamic_import(self.agent_params.path)(self.agent_params)
agents = {'agent': self.agent}

if not self.is_collecting_random_dataset:
self.experience_generating_agent_params.visualization.dump_csv = False
self.experience_generating_agent_params.task_parameters = task_parameters
self.experience_generating_agent_params.name = "experience_gen_agent"
self.experience_generating_agent_params.network_wrappers['main'].should_get_softmax_probabilities = True

# we need to set these manually as these are usually being set for us only for the default agent
self.experience_generating_agent_params.input_filter = self.agent_params.input_filter
self.experience_generating_agent_params.output_filter = self.agent_params.output_filter

self.experience_generating_agent = short_dynamic_import(
self.experience_generating_agent_params.path)(self.experience_generating_agent_params)

agents['experience_generating_agent'] = self.experience_generating_agent

if not env and not self.agent_params.memory.load_memory_from_file_path:
screen.warning("A BatchRLGraph requires setting a dataset to load into the agent's memory or alternatively "
"using an environment to create a (random) dataset from. This agent should only be used for "
"inference. ")
# set level manager
level_manager = LevelManager(agents=agent, environment=env, name="main_level",
# - although we will be using each agent separately, we have to have both agents initialized together with the
# LevelManager, so that they are both properly initialized
level_manager = LevelManager(agents=agents,
environment=env, name="main_level",
spaces_definition=self.spaces_definition)

if env:
return [level_manager], [env]
else:
@@ -123,12 +180,34 @@ class BatchRLGraphManager(BasicRLGraphManager):
# an environment and a dataset to load from, we will use the environment only for evaluating the policy,
# and will not run heatup.

# heatup
if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
self.heatup(self.heatup_steps)
screen.log_title("Starting to improve an agent collecting experience to use for training the actual agent in a "
"Batch RL fashion")

if self.is_collecting_random_dataset:
# heatup
if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
self.heatup(self.heatup_steps)
else:
# set the experience generating agent to train
self.level_managers[0].agents = {'experience_generating_agent': self.experience_generating_agent}

# collect a dataset using the experience generating agent
super().improve()

# set the acquired experience to the actual agent that we're going to train
self.agent.memory = self.experience_generating_agent.memory

# switch the graph scheduling parameters
self.set_schedule_params(self.schedule_params)

# set the actual agent to train
self.level_managers[0].agents = {'agent': self.agent}

# this agent never actually plays
self.level_managers[0].agents['agent'].ap.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)

# from this point onwards, the dataset cannot be changed anymore. Allows for performance improvements.
self.level_managers[0].agents['agent'].memory.freeze()
self.level_managers[0].agents['agent'].freeze_memory()

self.initialize_ope_models_and_stats()

@@ -141,15 +220,13 @@ class BatchRLGraphManager(BasicRLGraphManager):
# the outer most training loop
improve_steps_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
while self.total_steps_counters[RunPhase.TRAIN] < improve_steps_end:
# TODO if we have an environment, do we want to use it to have the agent train against, and use the
# collected replay buffer as a dataset? (as opposed to what we currently have, where the dataset is built
# during heatup, and is composed of random actions)
# perform several steps of training
if self.steps_between_evaluation_periods.num_steps > 0:
with self.phase_context(RunPhase.TRAIN):
self.reset_internal_state(force_environment_reset=True)

steps_between_evaluation_periods_end = self.current_step_counter + self.steps_between_evaluation_periods
steps_between_evaluation_periods_end = self.current_step_counter + \
self.steps_between_evaluation_periods
while self.current_step_counter < steps_between_evaluation_periods_end:
self.train()
@@ -168,8 +245,8 @@ class BatchRLGraphManager(BasicRLGraphManager):

def initialize_ope_models_and_stats(self):
"""

:return:
Improve a reward model of the MDP, to be used for some of the off-policy evaluation (OPE) methods.
e.g. 'direct method' and 'doubly robust'.
"""
agent = self.level_managers[0].agents['agent']
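For context, a one-step (bandit-style) sketch of the two estimators named in the docstring; array shapes and names are hypothetical, and Coach's actual OPE code may work over transitions or episodes differently. The behaviour probabilities used here are the kind of all_action_probabilities this commit starts recording:

import numpy as np

def direct_method(policy_probs, q_model):
    # DM: model-based value of the evaluated policy, V_hat(s) = sum_a pi(a|s) * Q_hat(s, a)
    return np.mean(np.sum(policy_probs * q_model, axis=1))

def doubly_robust(rewards, actions, behavior_probs, policy_probs, q_model):
    # DR: DM estimate plus an importance-weighted correction using the observed rewards
    idx = np.arange(len(rewards))
    rho = policy_probs[idx, actions] / behavior_probs[idx, actions]
    v_dm = np.sum(policy_probs * q_model, axis=1)
    return np.mean(v_dm + rho * (rewards - q_model[idx, actions]))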
@@ -193,6 +270,3 @@ class BatchRLGraphManager(BasicRLGraphManager):
:return:
"""
self.level_managers[0].agents['agent'].run_off_policy_evaluation()

@@ -95,10 +95,7 @@ class GraphManager(object):
self.level_managers = []  # type: List[LevelManager]
self.top_level_manager = None
self.environments = []
self.heatup_steps = schedule_params.heatup_steps
self.evaluation_steps = schedule_params.evaluation_steps
self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods
self.improve_steps = schedule_params.improve_steps
self.set_schedule_params(schedule_params)
self.visualization_parameters = vis_params
self.name = name
self.task_parameters = None
@@ -759,3 +756,14 @@ class GraphManager(object):
if hasattr(self, 'data_store_params'):
data_store = self.get_data_store(self.data_store_params)
data_store.save_to_store()

def set_schedule_params(self, schedule_params: ScheduleParameters):
"""
Set schedule parameters for the graph.

:param schedule_params: the schedule params to set.
"""
self.heatup_steps = schedule_params.heatup_steps
self.evaluation_steps = schedule_params.evaluation_steps
self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods
self.improve_steps = schedule_params.improve_steps
@@ -452,8 +452,6 @@ class EpisodicExperienceReplay(Memory):
progress_bar.update(len(episode_ids))
progress_bar.close()

self.shuffle_episodes()

def freeze(self):
"""
Freezing the replay buffer does not allow any new transitions to be added to the memory.
@@ -20,7 +20,7 @@ from rl_coach.core_types import Episode


class WeightedImportanceSampling(object):
# TODO rename and add PDIS
# TODO add PDIS
@staticmethod
def evaluate(evaluation_dataset_as_episodes: List[Episode]) -> float:
"""
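For reference, the weighted importance sampling estimate in its simplest per-episode form (illustrative names and shapes, not the evaluate() implementation):

import numpy as np

def weighted_importance_sampling(episode_returns, per_step_likelihood_ratios):
    # w_i = product over episode i of pi(a_t | s_t) / mu(a_t | s_t)
    weights = np.array([np.prod(ratios) for ratios in per_step_likelihood_ratios])
    # weight-normalized average of the observed episode returns
    return np.sum(weights * np.asarray(episode_returns)) / np.sum(weights)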
@@ -1,4 +1,5 @@
from copy import deepcopy
from rl_coach.agents.dqn_agent import DQNAgentParameters

from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
@@ -45,10 +46,10 @@ agent_params.network_wrappers['main'].batch_size = 128
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(
DATASET_SIZE / agent_params.network_wrappers['main'].batch_size)
#
# agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(
# 3)
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(
100)
# agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)

agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)
agent_params.algorithm.discount = 0.98

# can use either a kNN or a NN based model for predicting which actions not to max over in the bellman equation
@@ -79,17 +80,49 @@ agent_params.network_wrappers['imitation_model'].middleware_parameters.scheme =

# ER size
agent_params.memory = EpisodicExperienceReplayParameters()
agent_params.memory.max_size = (MemoryGranularity.Transitions, DATASET_SIZE)


# E-Greedy schedule
agent_params.exploration.epsilon_schedule = LinearSchedule(0, 0, 10000)
agent_params.exploration.evaluation_epsilon = 0


# Input filtering
agent_params.input_filter = InputFilter()
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))



# Experience Generating Agent parameters
experience_generating_agent_params = DQNAgentParameters()

# schedule parameters
experience_generating_schedule_params = ScheduleParameters()
experience_generating_schedule_params.heatup_steps = EnvironmentSteps(1000)
experience_generating_schedule_params.improve_steps = TrainingSteps(
DATASET_SIZE - experience_generating_schedule_params.heatup_steps.num_steps)
experience_generating_schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
experience_generating_schedule_params.evaluation_steps = EnvironmentEpisodes(1)

# DQN params
experience_generating_agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
experience_generating_agent_params.algorithm.discount = 0.99
experience_generating_agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)

# NN configuration
experience_generating_agent_params.network_wrappers['main'].learning_rate = 0.00025
experience_generating_agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False

# ER size
experience_generating_agent_params.memory = EpisodicExperienceReplayParameters()
experience_generating_agent_params.memory.max_size = \
(MemoryGranularity.Transitions,
experience_generating_schedule_params.heatup_steps.num_steps +
experience_generating_schedule_params.improve_steps.num_steps)

# E-Greedy schedule
experience_generating_agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)


################
#  Environment #
################
@@ -101,11 +134,14 @@ env_params = GymVectorEnvironment(level='CartPole-v0')
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 2000
preset_validation_params.max_episodes_to_achieve_reward = 50

graph_manager = BatchRLGraphManager(agent_params=agent_params, env_params=env_params,
graph_manager = BatchRLGraphManager(agent_params=agent_params,
experience_generating_agent_params=experience_generating_agent_params,
experience_generating_schedule_params=experience_generating_schedule_params,
env_params=env_params,
schedule_params=schedule_params,
vis_params=VisualizationParameters(dump_signals_to_csv_every_x_episodes=1),
preset_validation_params=preset_validation_params,
reward_model_num_epochs=30,
train_to_eval_ratio=0.8)
train_to_eval_ratio=0.4)

@@ -16,10 +16,6 @@ def test_init():
action_space = DiscreteActionSpace(3)
noise_schedule = LinearSchedule(1.0, 1.0, 1000)

# additive noise doesn't work for discrete controls
with pytest.raises(ValueError):
policy = AdditiveNoise(action_space, noise_schedule, 0)

# additive noise requires a bounded range for the actions
action_space = BoxActionSpace(np.array([10]))
with pytest.raises(ValueError):

@@ -21,14 +21,14 @@ def test_get_action():
# verify that test phase gives greedy actions (evaluation_epsilon = 0)
policy.change_phase(RunPhase.TEST)
for i in range(100):
best_action = policy.get_action(np.array([10, 20, 30]))
best_action, _ = policy.get_action(np.array([10, 20, 30]))
assert best_action == 2

# verify that train phase gives uniform actions (exploration = 1)
policy.change_phase(RunPhase.TRAIN)
counters = np.array([0, 0, 0])
for i in range(30000):
best_action = policy.get_action(np.array([10, 20, 30]))
best_action, _ = policy.get_action(np.array([10, 20, 30]))
counters[best_action] += 1
assert np.all(counters > 9500)  # this is noisy so we allow 5% error


@@ -15,7 +15,7 @@ def test_get_action():
action_space = DiscreteActionSpace(3)
policy = Greedy(action_space)

best_action = policy.get_action(np.array([10, 20, 30]))
best_action, _ = policy.get_action(np.array([10, 20, 30]))
assert best_action == 2

# continuous control

@@ -16,10 +16,6 @@ def test_init():
# discrete control
action_space = DiscreteActionSpace(3)

# OU process doesn't work for discrete controls
with pytest.raises(ValueError):
policy = OUProcess(action_space, mu=0, theta=0.1, sigma=0.2, dt=0.01)


@pytest.mark.unit_test
def test_get_action():