1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Enabling-more-agents-for-Batch-RL-and-cleanup (#258)

allowing for the last training batch drawn to be smaller than batch_size + adding support for more agents in BatchRL by adding softmax with temperature to the corresponding heads + adding a CartPole_QR_DQN preset with a golden test + cleanups
This commit is contained in:
Gal Leibovich
2019-03-21 16:10:29 +02:00
committed by GitHub
parent abec59f367
commit 6e08c55ad5
24 changed files with 152 additions and 69 deletions

View File

@@ -18,11 +18,9 @@ from typing import Union
import numpy as np
from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters
from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAgentParameters
from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
from rl_coach.base_parameters import AgentParameters
from rl_coach.exploration_policies.bootstrapped import BootstrappedParameters
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
@@ -32,12 +30,11 @@ class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
self.heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/self.heads_parameters[0].num_output_head_copies
class BootstrappedDQNAgentParameters(AgentParameters):
class BootstrappedDQNAgentParameters(DQNAgentParameters):
def __init__(self):
super().__init__(algorithm=DQNAlgorithmParameters(),
exploration=BootstrappedParameters(),
memory=ExperienceReplayParameters(),
networks={"main": BootstrappedDQNNetworkParameters()})
super().__init__()
self.exploration = BootstrappedParameters()
self.network_wrappers = {"main": BootstrappedDQNNetworkParameters()}
@property
def path(self):
@@ -65,13 +62,14 @@ class BootstrappedDQNAgent(ValueOptimizationAgent):
TD_targets = result[self.ap.exploration.architecture_num_q_heads:]
# add Q value samples for logging
self.q_values.add_sample(TD_targets)
# initialize with the current prediction so that we will
# only update the action that we have actually done in this transition
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
mask = batch[i].info['mask']
for head_idx in range(self.ap.exploration.architecture_num_q_heads):
self.q_values.add_sample(TD_targets[head_idx])
if mask[head_idx] == 1:
selected_action = np.argmax(next_states_online_values[head_idx][i], 0)
TD_targets[head_idx][i, batch.actions()[i]] = \

View File

@@ -84,8 +84,8 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
# prediction's format is (batch,actions,atoms)
def get_all_q_values_for_states(self, states: StateType):
if self.exploration_policy.requires_action_values():
prediction = self.get_prediction(states)
q_values = self.distribution_prediction_to_q_values(prediction)
q_values = self.get_prediction(states,
outputs=self.networks['main'].online_network.output_heads[0].q_values)
else:
q_values = None
return q_values
@@ -105,9 +105,9 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
# select the optimal actions for the next state
target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
m = np.zeros((batch.size, self.z_values.size))
batches = np.arange(self.ap.network_wrappers['main'].batch_size)
batches = np.arange(batch.size)
# an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.
# only 10% speedup overall - leaving commented out as the code is not as clear.
@@ -120,7 +120,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
# bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])
# u_ = (np.ceil(bj_)).astype(int)
# l_ = (np.floor(bj_)).astype(int)
# m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
# m_ = np.zeros((batch.size, self.z_values.size))
# np.add.at(m_, [batches, l_],
# np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))
# np.add.at(m_, [batches, u_],

View File

@@ -207,6 +207,8 @@ class ClippedPPOAgent(ActorCriticAgent):
self.networks['main'].online_network.output_heads[1].likelihood_ratio,
self.networks['main'].online_network.output_heads[1].clipped_likelihood_ratio]
# TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer, we do not train on
# some of the data
for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size)):
start = i * self.ap.network_wrappers['main'].batch_size
end = (i + 1) * self.ap.network_wrappers['main'].batch_size

View File

@@ -56,7 +56,7 @@ class DDQNAgent(ValueOptimizationAgent):
# initialize with the current prediction so that we will
# only update the action that we have actually done in this transition
TD_errors = []
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
new_target = batch.rewards()[i] + \
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * q_st_plus_1[i][selected_actions[i]]
TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))

View File

@@ -146,13 +146,13 @@ class DFPAgent(Agent):
network_inputs = batch.states(network_keys)
network_inputs['goal'] = np.repeat(np.expand_dims(self.current_goal, 0),
self.ap.network_wrappers['main'].batch_size, axis=0)
batch.size, axis=0)
# get the current outputs of the network
targets = self.networks['main'].online_network.predict(network_inputs)
# change the targets for the taken actions
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
targets[i, batch.actions()[i]] = batch[i].info['future_measurements'].flatten()
result = self.networks['main'].train_and_sync_networks(network_inputs, targets)

View File

@@ -86,7 +86,7 @@ class DQNAgent(ValueOptimizationAgent):
# only update the action that we have actually done in this transition
TD_errors = []
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
new_target = batch.rewards()[i] +\
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * np.max(q_st_plus_1[i], 0)
TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))

View File

@@ -65,7 +65,7 @@ class MixedMonteCarloAgent(ValueOptimizationAgent):
total_returns = batch.n_step_discounted_rewards()
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
one_step_target = batch.rewards()[i] + \
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
q_st_plus_1[i][selected_actions[i]]

View File

@@ -133,7 +133,7 @@ class NECAgent(ValueOptimizationAgent):
TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys))
bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards()
# only update the action that we have actually done in this transition
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i]
# set the gradients to fetch for the DND update

View File

@@ -84,7 +84,7 @@ class PALAgent(ValueOptimizationAgent):
# calculate TD error
TD_targets = np.copy(q_st_online)
total_returns = batch.n_step_discounted_rewards()
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
q_st_plus_1_target[i][selected_actions[i]]

View File

@@ -95,7 +95,7 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)
# calculate the Bellman update
batch_idx = list(range(self.ap.network_wrappers['main'].batch_size))
batch_idx = list(range(batch.size))
TD_targets = batch.rewards(True) + (1.0 - batch.game_overs(True)) * self.ap.algorithm.discount \
* next_state_quantiles[batch_idx, target_actions]
@@ -106,9 +106,9 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
# calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order
cumulative_probabilities = np.array(range(self.ap.algorithm.atoms + 1)) / float(self.ap.algorithm.atoms) # tau_i
quantile_midpoints = 0.5*(cumulative_probabilities[1:] + cumulative_probabilities[:-1]) # tau^hat_i
quantile_midpoints = np.tile(quantile_midpoints, (self.ap.network_wrappers['main'].batch_size, 1))
quantile_midpoints = np.tile(quantile_midpoints, (batch.size, 1))
sorted_quantiles = np.argsort(current_quantiles[batch_idx, batch.actions()])
for idx in range(self.ap.network_wrappers['main'].batch_size):
for idx in range(batch.size):
quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]]
# train

View File

@@ -103,9 +103,9 @@ class RainbowDQNAgent(CategoricalDQNAgent):
# only update the action that we have actually done in this transition (using the Double-DQN selected actions)
target_actions = ddqn_selected_actions
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
m = np.zeros((batch.size, self.z_values.size))
batches = np.arange(self.ap.network_wrappers['main'].batch_size)
batches = np.arange(batch.size)
for j in range(self.z_values.size):
# we use batch.info('should_bootstrap_next_state') instead of (1 - batch.game_overs()) since with n-step,
# we will not bootstrap for the last n-step transitions in the episode

View File

@@ -50,8 +50,9 @@ class ValueOptimizationAgent(Agent):
actions_q_values = None
return actions_q_values
def get_prediction(self, states):
return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'))
def get_prediction(self, states, outputs=None):
return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'),
outputs=outputs)
def update_transition_priorities_and_get_weights(self, TD_errors, batch):
# update errors in prioritized replay buffer
@@ -151,17 +152,18 @@ class ValueOptimizationAgent(Agent):
# this is fitted from the training dataset
for epoch in range(epochs):
loss = 0
total_transitions_processed = 0
for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
batch = Batch(batch)
current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
loss += self.networks['reward_model'].train_and_sync_networks(
batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
# print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])
total_transitions_processed += batch.size
log = OrderedDict()
log['Epoch'] = epoch
log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
log['loss'] = loss / total_transitions_processed
screen.log_dict(log, prefix='Training Reward Model')

View File

@@ -1,3 +1,4 @@
from .q_head import QHead
from .categorical_q_head import CategoricalQHead
from .ddpg_actor_head import DDPGActor
from .dnd_q_head import DNDQHead
@@ -7,7 +8,6 @@ from .naf_head import NAFHead
from .policy_head import PolicyHead
from .ppo_head import PPOHead
from .ppo_v_head import PPOVHead
from .q_head import QHead
from .quantile_regression_q_head import QuantileRegressionQHead
from .rainbow_q_head import RainbowQHead
from .v_head import VHead

View File

@@ -15,16 +15,15 @@
#
import tensorflow as tf
import numpy as np
from rl_coach.architectures.tensorflow_components.heads import QHead
from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
from rl_coach.spaces import SpacesDefinition
class CategoricalQHead(Head):
class CategoricalQHead(QHead):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu',
dense_layer=Dense):
@@ -33,7 +32,9 @@ class CategoricalQHead(Head):
self.name = 'categorical_dqn_head'
self.num_actions = len(self.spaces.action.actions)
self.num_atoms = agent_parameters.algorithm.atoms
self.return_type = QActionStateValue
self.z_values = tf.cast(tf.constant(np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max,
self.ap.algorithm.atoms), dtype=tf.float32), dtype=tf.float64)
self.loss_type = []
def _build_module(self, input_layer):
values_distribution = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='output')
@@ -49,6 +50,11 @@ class CategoricalQHead(Head):
self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
tf.losses.add_loss(self.loss)
self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.z_values, 1)
# used in batch-rl to estimate a probablity distribution over actions
self.softmax = self.add_softmax_with_temperature()
def __str__(self):
result = [
"Dense (num outputs = {})".format(self.num_actions * self.num_atoms),

View File

@@ -56,11 +56,14 @@ class DNDQHead(QHead):
# Retrieve info from DND dictionary
# We assume that all actions have enough entries in the DND
self.output = tf.transpose([
self.q_values = self.output = tf.transpose([
self._q_value(input_layer, action)
for action in range(self.num_actions)
])
# used in batch-rl to estimate a probablity distribution over actions
self.softmax = self.add_softmax_with_temperature()
def _q_value(self, input_layer, action):
result = tf.py_func(self.DND.query,
[input_layer, action, self.number_of_nn],

View File

@@ -44,7 +44,10 @@ class DuelingQHead(QHead):
self.action_advantage = self.action_advantage - self.action_mean
# merge to state-action value function Q
self.output = tf.add(self.state_value, self.action_advantage, name='output')
self.q_values = self.output = tf.add(self.state_value, self.action_advantage, name='output')
# used in batch-rl to estimate a probablity distribution over actions
self.softmax = self.add_softmax_with_temperature()
def __str__(self):
result = [

View File

@@ -48,15 +48,19 @@ class QHead(Head):
def _build_module(self, input_layer):
# Standard Q Network
self.output = self.dense_layer(self.num_actions)(input_layer, name='output')
self.q_values = self.output = self.dense_layer(self.num_actions)(input_layer, name='output')
# TODO add this to other Q heads. e.g. dueling.
temperature = self.ap.network_wrappers[self.network_name].softmax_temperature
temperature_scaled_outputs = self.output / temperature
self.softmax = tf.nn.softmax(temperature_scaled_outputs, name="softmax")
# used in batch-rl to estimate a probablity distribution over actions
self.softmax = self.add_softmax_with_temperature()
def __str__(self):
result = [
"Dense (num outputs = {})".format(self.num_actions)
]
return '\n'.join(result)
def add_softmax_with_temperature(self):
temperature = self.ap.network_wrappers[self.network_name].softmax_temperature
temperature_scaled_outputs = self.q_values / temperature
return tf.nn.softmax(temperature_scaled_outputs, name="softmax")

View File

@@ -15,15 +15,14 @@
#
import tensorflow as tf
import numpy as np
from rl_coach.architectures.tensorflow_components.heads import QHead
from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
from rl_coach.spaces import SpacesDefinition
class QuantileRegressionQHead(Head):
class QuantileRegressionQHead(QHead):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense):
@@ -33,7 +32,10 @@ class QuantileRegressionQHead(Head):
self.num_actions = len(self.spaces.action.actions)
self.num_atoms = agent_parameters.algorithm.atoms # we use atom / quantile interchangeably
self.huber_loss_interval = agent_parameters.algorithm.huber_loss_interval # k
self.return_type = QActionStateValue
self.quantile_probabilities = tf.cast(
tf.constant(np.ones(self.ap.algorithm.atoms) / float(self.ap.algorithm.atoms), dtype=tf.float32),
dtype=tf.float64)
self.loss_type = []
def _build_module(self, input_layer):
self.actions = tf.placeholder(tf.int32, [None, 2], name="actions")
@@ -72,6 +74,11 @@ class QuantileRegressionQHead(Head):
self.loss = quantile_regression_loss
tf.losses.add_loss(self.loss)
self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.quantile_probabilities, 1)
# used in batch-rl to estimate a probablity distribution over actions
self.softmax = self.add_softmax_with_temperature()
def __str__(self):
result = [
"Dense (num outputs = {})".format(self.num_actions * self.num_atoms),

View File

@@ -15,15 +15,14 @@
#
import tensorflow as tf
import numpy as np
from rl_coach.architectures.tensorflow_components.heads import QHead
from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
from rl_coach.spaces import SpacesDefinition
class RainbowQHead(Head):
class RainbowQHead(QHead):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense):
@@ -31,8 +30,10 @@ class RainbowQHead(Head):
dense_layer=dense_layer)
self.num_actions = len(self.spaces.action.actions)
self.num_atoms = agent_parameters.algorithm.atoms
self.return_type = QActionStateValue
self.name = 'rainbow_q_values_head'
self.z_values = tf.cast(tf.constant(np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max,
self.ap.algorithm.atoms), dtype=tf.float32), dtype=tf.float64)
self.loss_type = []
def _build_module(self, input_layer):
# state value tower - V
@@ -63,6 +64,11 @@ class RainbowQHead(Head):
self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
tf.losses.add_loss(self.loss)
self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.z_values, 1)
# used in batch-rl to estimate a probablity distribution over actions
self.softmax = self.add_softmax_with_temperature()
def __str__(self):
result = [
"State Value Stream - V",

View File

@@ -17,6 +17,7 @@ from copy import deepcopy
from typing import Tuple, List, Union
from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.agents.nec_agent import NECAgentParameters
from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \
PresetValidationParameters
from rl_coach.core_types import RunPhase
@@ -65,8 +66,8 @@ class BatchRLGraphManager(BasicRLGraphManager):
else:
env = None
# Only DQN variants are supported at this point.
assert(isinstance(self.agent_params, DQNAgentParameters))
# Only DQN variants and NEC are supported at this point.
assert(isinstance(self.agent_params, DQNAgentParameters) or isinstance(self.agent_params, NECAgentParameters))
# Only Episodic memories are supported,
# for evaluating the sequential doubly robust estimator
assert(isinstance(self.agent_params.memory, EpisodicExperienceReplayParameters))

View File

@@ -15,6 +15,7 @@
# limitations under the License.
#
import ast
import math
import pandas as pd
from typing import List, Tuple, Union
@@ -163,9 +164,8 @@ class EpisodicExperienceReplay(Memory):
shuffled_transition_indices = list(range(self.last_training_set_transition_id))
random.shuffle(shuffled_transition_indices)
# we deliberately drop some of the ending data which is left after dividing to batches of size `size`
# for i in range(math.ceil(len(shuffled_transition_indices) / size)):
for i in range(int(len(shuffled_transition_indices) / size)):
# The last batch drawn will usually be < batch_size (=the size variable)
for i in range(math.ceil(len(shuffled_transition_indices) / size)):
sample_data = [self.transitions[j] for j in shuffled_transition_indices[i * size: (i + 1) * size]]
self.reader_writer_lock.release_writing()

View File

@@ -113,10 +113,6 @@ class ExperienceReplay(Memory):
yield sample_data
## usage example
# for o in random_seq_generator(list(range(10)), 4):
# print(o)
def _enforce_max_length(self) -> None:
"""
Make sure that the size of the replay buffer does not pass the maximum size allowed.

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
from collections import namedtuple
import numpy as np
@@ -55,19 +56,18 @@ class OpeManager(object):
all_reward_model_rewards, all_policy_probs, all_old_policy_probs = [], [], []
all_v_values_reward_model_based, all_v_values_q_model_based, all_rewards, all_actions = [], [], [], []
for i in range(int(len(dataset_as_transitions) / batch_size) + 1):
for i in range(math.ceil(len(dataset_as_transitions) / batch_size)):
batch = dataset_as_transitions[i * batch_size: (i + 1) * batch_size]
batch_for_inference = Batch(batch)
all_reward_model_rewards.append(reward_model.predict(
batch_for_inference.states(network_keys)))
# TODO can we get rid of the 'output_heads[0]', and have some way of a cleaner API?
# we always use the first Q head to calculate OPEs. might want to change this in the future.
# for instance, this means that for bootstrapped we always use the first QHead to calculate the OPEs.
q_values, sm_values = q_network.predict(batch_for_inference.states(network_keys),
outputs=[q_network.output_heads[0].output,
outputs=[q_network.output_heads[0].q_values,
q_network.output_heads[0].softmax])
# TODO why is this needed?
q_values = q_values[0]
all_policy_probs.append(sm_values)
all_v_values_reward_model_based.append(np.sum(all_policy_probs[-1] * all_reward_model_rewards[-1], axis=1))

View File

@@ -0,0 +1,55 @@
from rl_coach.agents.qr_dqn_agent import QuantileRegressionDQNAgentParameters
from rl_coach.agents.rainbow_dqn_agent import RainbowDQNAgentParameters
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.gym_environment import GymVectorEnvironment
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.schedules import LinearSchedule
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(10000000000)
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(1000)
#########
# Agent #
#########
agent_params = QuantileRegressionDQNAgentParameters()
# DQN params
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
agent_params.algorithm.discount = 0.99
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
agent_params.algorithm.atoms = 50
# NN configuration
agent_params.network_wrappers['main'].learning_rate = 0.0005
# ER size
agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)
# E-Greedy schedule
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
################
# Environment #
################
env_params = GymVectorEnvironment(level='CartPole-v0')
########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 250
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=VisualizationParameters(),
preset_validation_params=preset_validation_params)