mirror of https://github.com/gryf/coach.git (synced 2025-12-18 03:30:19 +01:00)
network_imporvements branch merge
@@ -25,14 +25,38 @@ from rl_coach.architectures.tensorflow_components.heads.v_head import VHeadParam
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
    AgentParameters
+from rl_coach.exploration_policies.categorical import CategoricalParameters
+from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.single_episode_buffer import SingleEpisodeBufferParameters
-from rl_coach.spaces import DiscreteActionSpace
+from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
from rl_coach.utils import last_sample
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters


class ActorCriticAlgorithmParameters(AlgorithmParameters):
+    """
+    :param policy_gradient_rescaler: (PolicyGradientRescaler)
+        The value that will be used to rescale the policy gradient.
+
+    :param apply_gradients_every_x_episodes: (int)
+        The number of episodes to wait before applying the accumulated gradients to the network.
+        The training iterations only accumulate gradients without actually applying them.
+
+    :param beta_entropy: (float)
+        The weight that will be given to the entropy regularization, which is used in order to improve exploration.
+
+    :param num_steps_between_gradient_updates: (int)
+        Every num_steps_between_gradient_updates transitions will be considered as a single batch and used for
+        accumulating gradients. This is also the number of steps used for bootstrapping according to the n-step formulation.
+
+    :param gae_lambda: (float)
+        If the policy gradient rescaler was defined as PolicyGradientRescaler.GAE, the generalized advantage estimation
+        scheme will be used, in which case the lambda value controls the decay for the different n-step lengths.
+
+    :param estimate_state_value_using_gae: (bool)
+        If set to True, the state value targets for the V head will be estimated using the GAE scheme.
+    """
    def __init__(self):
        super().__init__()
        self.policy_gradient_rescaler = PolicyGradientRescaler.A_VALUE
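For orientation, the parameters documented in the new docstring are the knobs a preset would normally override. Below is a minimal sketch of doing so, assuming the usual rl_coach pattern where the agent parameters object exposes this block as .algorithm; the numeric values are illustrative only and are not taken from this commit.

    # Illustrative sketch only -- values are examples, not recommendations from this commit.
    agent_params = ActorCriticAgentParameters()
    agent_params.algorithm.policy_gradient_rescaler = PolicyGradientRescaler.GAE
    agent_params.algorithm.gae_lambda = 0.96                         # decay across the n-step estimators
    agent_params.algorithm.beta_entropy = 0.01                       # weight of the entropy regularization
    agent_params.algorithm.num_steps_between_gradient_updates = 20   # n-step / gradient accumulation batch length
    agent_params.algorithm.apply_gradients_every_x_episodes = 5      # episodes to accumulate before applying
    agent_params.algorithm.estimate_state_value_using_gae = True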
@@ -48,9 +72,7 @@ class ActorCriticNetworkParameters(NetworkParameters):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
-        self.heads_parameters = [VHeadParameters(), PolicyHeadParameters()]
-        self.loss_weights = [0.5, 1.0]
-        self.rescale_gradient_from_head_by_factor = [1, 1]
+        self.heads_parameters = [VHeadParameters(loss_weight=0.5), PolicyHeadParameters(loss_weight=1.0)]
        self.optimizer_type = 'Adam'
        self.clip_gradients = 40.0
        self.async_training = True
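The change above folds the per-head weights into the head parameters themselves instead of keeping them in separate, index-aligned lists (loss_weights, rescale_gradient_from_head_by_factor) on the network. A minimal sketch of the two styles, using only names that appear in the hunk:

    # new style: the weight travels with the head it scales
    heads = [VHeadParameters(loss_weight=0.5), PolicyHeadParameters(loss_weight=1.0)]

    # old style: a parallel list that had to stay aligned with heads_parameters by index
    # heads = [VHeadParameters(), PolicyHeadParameters()]
    # loss_weights = [0.5, 1.0]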
@@ -59,8 +81,8 @@ class ActorCriticNetworkParameters(NetworkParameters):
class ActorCriticAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=ActorCriticAlgorithmParameters(),
-                         exploration=None, #TODO this should be different for continuous (ContinuousEntropyExploration)
-                         # and discrete (CategoricalExploration) action spaces.
+                         exploration={DiscreteActionSpace: CategoricalParameters(),
+                                      BoxActionSpace: ContinuousEntropyParameters()},
                         memory=SingleEpisodeBufferParameters(),
                         networks={"main": ActorCriticNetworkParameters()})


@@ -157,6 +157,10 @@ class Agent(AgentInterface):
        if self.ap.task_parameters.seed is not None:
            random.seed(self.ap.task_parameters.seed)
            np.random.seed(self.ap.task_parameters.seed)
+        else:
+            # we need to seed the RNG since the different processes are initialized with the same parent seed
+            random.seed()
+            np.random.seed()

    @property
    def parent(self):
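The new else branch reseeds explicitly because forked worker processes inherit the parent's RNG state and would otherwise all draw the same sequence. A standalone sketch of the effect (plain Python/NumPy, not rl_coach code; it assumes a fork-based multiprocessing start method, as on Linux):

    import multiprocessing as mp
    import numpy as np

    def worker(reseed):
        if reseed:
            np.random.seed()                  # reseed from OS entropy -> differs per process
        print(np.random.randint(0, 1000))     # without reseeding, every forked worker prints the same value

    if __name__ == '__main__':
        np.random.seed(123)                   # parent seed, copied into each forked child
        for reseed in (False, False, True, True):
            mp.Process(target=worker, args=(reseed,)).start()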
@@ -269,6 +273,10 @@ class Agent(AgentInterface):
                                            spaces=self.spaces,
                                            replicated_device=self.replicated_device,
                                            worker_device=self.worker_device)

+            if self.ap.visualization.print_networks_summary:
+                print(networks[network_name])

        return networks

    def init_environment_dependent_modules(self) -> None:

@@ -278,6 +286,14 @@ class Agent(AgentInterface):
        :return: None
        """
        # initialize exploration policy
+        if isinstance(self.ap.exploration, dict):
+            if self.spaces.action.__class__ in self.ap.exploration.keys():
+                self.ap.exploration = self.ap.exploration[self.spaces.action.__class__]
+            else:
+                raise ValueError("The exploration parameters were defined as a mapping between action space types and "
+                                 "exploration types, but the action space used by the environment ({}) was not part of "
+                                 "the exploration parameters dictionary keys ({})"
+                                 .format(self.spaces.action.__class__, list(self.ap.exploration.keys())))
        self.ap.exploration.action_space = self.spaces.action
        self.exploration_policy = dynamic_import_and_instantiate_module_from_params(self.ap.exploration)
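Note that the lookup above matches on the exact class (self.spaces.action.__class__), so a subclass of BoxActionSpace would not be matched by a BoxActionSpace key. A standalone sketch of the same resolution pattern, with names reused from the diff purely for illustration:

    exploration_map = {DiscreteActionSpace: CategoricalParameters(),
                       BoxActionSpace: ContinuousEntropyParameters()}
    key = type(action_space)                          # exact class, not an isinstance() check
    exploration_params = exploration_map.get(key)
    if exploration_params is None:
        raise ValueError("no exploration policy configured for action space type {}".format(key))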
@@ -543,6 +559,9 @@ class Agent(AgentInterface):
        """
        loss = 0
        if self._should_train():
+            for network in self.networks.values():
+                network.set_is_training(True)

            for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
                # TODO: this should be network dependent
                network_parameters = list(self.ap.network_wrappers.values())[0]

@@ -586,9 +605,14 @@ class Agent(AgentInterface):
            if self.imitation:
                self.log_to_screen()

+            for network in self.networks.values():
+                network.set_is_training(False)

            # run additional commands after the training is done
            self.post_training_commands()


        return loss

    def choose_action(self, curr_state):
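The training path now brackets every update between network.set_is_training(True) and network.set_is_training(False), presumably so that layers that behave differently at training time (such as the batch normalization enabled elsewhere in this diff for DDPG) can tell training apart from acting. As an illustrative sketch only, not how the commit structures it, the same bracket can be written with try/finally so the flag is restored even if an update raises:

    for network in networks:                  # networks: any iterable of objects exposing set_is_training()
        network.set_is_training(True)
    try:
        loss = run_training_steps()           # hypothetical stand-in for the per-algorithm update loop
    finally:
        for network in networks:
            network.set_is_training(False)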
@@ -26,6 +26,7 @@ from rl_coach.base_parameters import AgentParameters, AlgorithmParameters, Netwo
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
+from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters


class BCAlgorithmParameters(AlgorithmParameters):

@@ -40,7 +41,6 @@ class BCNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
        self.heads_parameters = [PolicyHeadParameters()]
-        self.loss_weights = [1.0]
        self.optimizer_type = 'Adam'
        self.batch_size = 32
        self.replace_mse_with_huber_loss = False

@@ -51,7 +51,7 @@ class BCAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=BCAlgorithmParameters(),
                         exploration=EGreedyParameters(),
-                         memory=EpisodicExperienceReplayParameters(),
+                         memory=ExperienceReplayParameters(),
                         networks={"main": BCNetworkParameters()})

    @property

@@ -28,8 +28,8 @@ from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayPar
class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
    def __init__(self):
        super().__init__()
-        self.num_output_head_copies = 10
-        self.rescale_gradient_from_head_by_factor = [1.0/self.num_output_head_copies]*self.num_output_head_copies
+        self.heads_parameters[0].num_output_head_copies = 10
+        self.heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/self.heads_parameters[0].num_output_head_copies


class BootstrappedDQNAgentParameters(AgentParameters):
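In the Bootstrapped DQN change above, the head now carries its own rescale factor of 1.0 divided by the number of head copies, so with 10 copies each copy's gradient into the shared layers is scaled by 0.1 and the shared trunk effectively receives the average of the head gradients rather than their sum. A trivial sketch of the arithmetic:

    num_output_head_copies = 10
    rescale_gradient_from_head_by_factor = 1.0 / num_output_head_copies   # 0.1 per head copy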
@@ -38,7 +38,6 @@ class CILNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
        self.heads_parameters = [RegressionHeadParameters()]
-        self.loss_weights = [1.0]
        self.optimizer_type = 'Adam'
        self.batch_size = 32
        self.replace_mse_with_huber_loss = False

@@ -31,10 +31,11 @@ from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.core_types import EnvironmentSteps, Batch, EnvResponse, StateType
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
+from rl_coach.exploration_policies.categorical import CategoricalParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.schedules import ConstantSchedule
-from rl_coach.spaces import DiscreteActionSpace
+from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace


class ClippedPPONetworkParameters(NetworkParameters):

@@ -43,8 +44,6 @@ class ClippedPPONetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='tanh')}
        self.middleware_parameters = FCMiddlewareParameters(activation_function='tanh')
        self.heads_parameters = [VHeadParameters(), PPOHeadParameters()]
-        self.loss_weights = [1.0, 1.0]
-        self.rescale_gradient_from_head_by_factor = [1, 1]
        self.batch_size = 64
        self.optimizer_type = 'Adam'
        self.clip_gradients = None

@@ -79,7 +78,8 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
class ClippedPPOAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=ClippedPPOAlgorithmParameters(),
-                         exploration=AdditiveNoiseParameters(),
+                         exploration={DiscreteActionSpace: CategoricalParameters(),
+                                      BoxActionSpace: AdditiveNoiseParameters()},
                         memory=EpisodicExperienceReplayParameters(),
                         networks={"main": ClippedPPONetworkParameters()})
@@ -253,6 +253,9 @@ class ClippedPPOAgent(ActorCriticAgent):

    def train(self):
        if self._should_train(wait_for_full_episode=True):
+            for network in self.networks.values():
+                network.set_is_training(True)

            dataset = self.memory.transitions
            dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
            batch = Batch(dataset)

@@ -269,6 +272,9 @@ class ClippedPPOAgent(ActorCriticAgent):

            self.train_network(batch, self.ap.algorithm.optimization_epochs)

+            for network in self.networks.values():
+                network.set_is_training(False)

            self.post_training_commands()
            self.training_iteration += 1
            # should be done in order to update the data that has been accumulated * while not playing *

@@ -41,8 +41,6 @@ class DDPGCriticNetworkParameters(NetworkParameters):
                                           'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [VHeadParameters()]
-        self.loss_weights = [1.0]
-        self.rescale_gradient_from_head_by_factor = [1]
        self.optimizer_type = 'Adam'
        self.batch_size = 64
        self.async_training = False

@@ -58,8 +56,6 @@ class DDPGActorNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
        self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
        self.heads_parameters = [DDPGActorHeadParameters()]
-        self.loss_weights = [1.0]
-        self.rescale_gradient_from_head_by_factor = [1]
        self.optimizer_type = 'Adam'
        self.batch_size = 64
        self.async_training = False

@@ -21,7 +21,7 @@ from typing import Union
import numpy as np

from rl_coach.agents.agent import Agent
-from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
+from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
from rl_coach.architectures.tensorflow_components.heads.measurements_prediction_head import \
    MeasurementsPredictionHeadParameters
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
@@ -48,28 +48,27 @@ class DFPNetworkParameters(NetworkParameters):
                                           'goal': InputEmbedderParameters(activation_function='leaky_relu')}

        self.input_embedders_parameters['observation'].scheme = [
-            Conv2d([32, 8, 4]),
-            Conv2d([64, 4, 2]),
-            Conv2d([64, 3, 1]),
-            Dense([512]),
+            Conv2d(32, 8, 4),
+            Conv2d(64, 4, 2),
+            Conv2d(64, 3, 1),
+            Dense(512),
        ]

        self.input_embedders_parameters['measurements'].scheme = [
-            Dense([128]),
-            Dense([128]),
-            Dense([128]),
+            Dense(128),
+            Dense(128),
+            Dense(128),
        ]

        self.input_embedders_parameters['goal'].scheme = [
-            Dense([128]),
-            Dense([128]),
-            Dense([128]),
+            Dense(128),
+            Dense(128),
+            Dense(128),
        ]

        self.middleware_parameters = FCMiddlewareParameters(activation_function='leaky_relu',
                                                            scheme=MiddlewareScheme.Empty)
        self.heads_parameters = [MeasurementsPredictionHeadParameters(activation_function='leaky_relu')]
-        self.loss_weights = [1.0]
        self.async_training = False
        self.batch_size = 64
        self.adam_optimizer_beta1 = 0.95
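The scheme entries above drop the list-style arguments (Conv2d([32, 8, 4])) in favour of plain positional arguments (Conv2d(32, 8, 4)), matching the switch from the architecture module to the new layers module earlier in the diff. Purely as a hypothetical illustration of such positional descriptors (this is not the actual rl_coach layers API; the argument order filters/kernel/stride is inferred from the values in the hunk):

    class Conv2d:
        # hypothetical descriptor: (num_filters, kernel_size, stride)
        def __init__(self, num_filters, kernel_size, stride):
            self.num_filters, self.kernel_size, self.stride = num_filters, kernel_size, stride

    class Dense:
        # hypothetical descriptor: (units,)
        def __init__(self, units):
            self.units = units

    observation_scheme = [Conv2d(32, 8, 4), Conv2d(64, 4, 2), Conv2d(64, 3, 1), Dense(512)]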
@@ -44,7 +44,6 @@ class DQNNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
        self.heads_parameters = [QHeadParameters()]
-        self.loss_weights = [1.0]
        self.optimizer_type = 'Adam'
        self.batch_size = 32
        self.replace_mse_with_huber_loss = True

@@ -46,8 +46,6 @@ class HumanNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.input_embedders_parameters['observation'].scheme = EmbedderScheme.Medium
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [PolicyHeadParameters()]
-        self.loss_weights = [1.0]
        self.optimizer_type = 'Adam'
        self.batch_size = 32
        self.replace_mse_with_huber_loss = False

@@ -37,7 +37,6 @@ class NStepQNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [QHeadParameters()]
-        self.loss_weights = [1.0]
        self.optimizer_type = 'Adam'
        self.async_training = True
        self.shared_optimizer = True

@@ -37,7 +37,6 @@ class NAFNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [NAFHeadParameters()]
-        self.loss_weights = [1.0]
        self.optimizer_type = 'Adam'
        self.learning_rate = 0.001
        self.async_training = True

@@ -39,8 +39,6 @@ class NECNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [DNDQHeadParameters()]
-        self.loss_weights = [1.0]
-        self.rescale_gradient_from_head_by_factor = [1]
        self.optimizer_type = 'Adam'
@@ -26,9 +26,10 @@ from rl_coach.base_parameters import NetworkParameters, AlgorithmParameters, \
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters

from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
+from rl_coach.exploration_policies.categorical import CategoricalParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.single_episode_buffer import SingleEpisodeBufferParameters
-from rl_coach.spaces import DiscreteActionSpace
+from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace


class PolicyGradientNetworkParameters(NetworkParameters):

@@ -37,7 +38,6 @@ class PolicyGradientNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [PolicyHeadParameters()]
-        self.loss_weights = [1.0]
        self.async_training = True


@@ -53,7 +53,8 @@ class PolicyGradientAlgorithmParameters(AlgorithmParameters):
class PolicyGradientsAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=PolicyGradientAlgorithmParameters(),
-                         exploration=AdditiveNoiseParameters(),
+                         exploration={DiscreteActionSpace: CategoricalParameters(),
+                                      BoxActionSpace: AdditiveNoiseParameters()},
                         memory=SingleEpisodeBufferParameters(),
                         networks={"main": PolicyGradientNetworkParameters()})


@@ -93,6 +93,8 @@ class PolicyOptimizationAgent(Agent):

        total_loss = 0
        if num_steps_passed_since_last_update > 0:
+            for network in self.networks.values():
+                network.set_is_training(True)

            # we need to update the returns of the episode until now
            episode.update_returns()

@@ -124,6 +126,9 @@ class PolicyOptimizationAgent(Agent):
                network.apply_gradients_and_sync_networks()
                self.training_iteration += 1

+            for network in self.networks.values():
+                network.set_is_training(False)

            # run additional commands after the training is done
            self.post_training_commands()
@@ -31,9 +31,10 @@ from rl_coach.architectures.tensorflow_components.embedders.embedder import Inpu

from rl_coach.core_types import EnvironmentSteps, Batch
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
+from rl_coach.exploration_policies.categorical import CategoricalParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
-from rl_coach.spaces import DiscreteActionSpace
+from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
from rl_coach.utils import force_list


@@ -43,7 +44,6 @@ class PPOCriticNetworkParameters(NetworkParameters):
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='tanh')}
        self.middleware_parameters = FCMiddlewareParameters(activation_function='tanh')
        self.heads_parameters = [VHeadParameters()]
-        self.loss_weights = [1.0]
        self.async_training = True
        self.l2_regularization = 0
        self.create_target_network = True

@@ -57,7 +57,6 @@ class PPOActorNetworkParameters(NetworkParameters):
        self.middleware_parameters = FCMiddlewareParameters(activation_function='tanh')
        self.heads_parameters = [PPOHeadParameters()]
        self.optimizer_type = 'Adam'
-        self.loss_weights = [1.0]
        self.async_training = True
        self.l2_regularization = 0
        self.create_target_network = True

@@ -84,7 +83,8 @@ class PPOAlgorithmParameters(AlgorithmParameters):
class PPOAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=PPOAlgorithmParameters(),
-                         exploration=AdditiveNoiseParameters(),
+                         exploration={DiscreteActionSpace: CategoricalParameters(),
+                                      BoxActionSpace: AdditiveNoiseParameters()},
                         memory=EpisodicExperienceReplayParameters(),
                         networks={"critic": PPOCriticNetworkParameters(), "actor": PPOActorNetworkParameters()})
@@ -313,6 +313,9 @@ class PPOAgent(ActorCriticAgent):
    def train(self):
        loss = 0
        if self._should_train(wait_for_full_episode=True):
+            for network in self.networks.values():
+                network.set_is_training(True)

            for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
                self.networks['actor'].sync()
                self.networks['critic'].sync()

@@ -330,6 +333,9 @@ class PPOAgent(ActorCriticAgent):
                self.value_loss.add_sample(value_loss)
                self.policy_loss.add_sample(policy_loss)

+            for network in self.networks.values():
+                network.set_is_training(False)

            self.post_training_commands()
            self.training_iteration += 1
            self.update_log() # should be done in order to update the data that has been accumulated * while not playing *