mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
@@ -36,6 +36,7 @@ class DQNAlgorithmParameters(AlgorithmParameters):
|
|||||||
self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000)
|
self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000)
|
||||||
self.num_consecutive_playing_steps = EnvironmentSteps(4)
|
self.num_consecutive_playing_steps = EnvironmentSteps(4)
|
||||||
self.discount = 0.99
|
self.discount = 0.99
|
||||||
|
self.supports_parameter_noise = True
|
||||||
|
|
||||||
|
|
||||||
class DQNNetworkParameters(NetworkParameters):
|
class DQNNetworkParameters(NetworkParameters):
|
||||||
|
|||||||
@@ -211,6 +211,9 @@ class AlgorithmParameters(Parameters):
|
|||||||
# Should the workers wait for full episode
|
# Should the workers wait for full episode
|
||||||
self.act_for_full_episodes = False
|
self.act_for_full_episodes = False
|
||||||
|
|
||||||
|
# Support for parameter noise
|
||||||
|
self.supports_parameter_noise = False
|
||||||
|
|
||||||
|
|
||||||
class PresetValidationParameters(Parameters):
|
class PresetValidationParameters(Parameters):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from typing import List, Dict
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
|
||||||
from rl_coach.architectures.layers import NoisyNetDense
|
from rl_coach.architectures.layers import NoisyNetDense
|
||||||
from rl_coach.base_parameters import AgentParameters, NetworkParameters
|
from rl_coach.base_parameters import AgentParameters, NetworkParameters
|
||||||
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
|
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
|
||||||
@@ -30,7 +29,8 @@ from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy,
|
|||||||
class ParameterNoiseParameters(ExplorationParameters):
|
class ParameterNoiseParameters(ExplorationParameters):
|
||||||
def __init__(self, agent_params: AgentParameters):
|
def __init__(self, agent_params: AgentParameters):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not isinstance(agent_params, DQNAgentParameters):
|
|
||||||
|
if not agent_params.algorithm.supports_parameter_noise:
|
||||||
raise ValueError("Currently only DQN variants are supported for using an exploration type of "
|
raise ValueError("Currently only DQN variants are supported for using an exploration type of "
|
||||||
"ParameterNoise.")
|
"ParameterNoise.")
|
||||||
|
|
||||||
|
|||||||
@@ -87,3 +87,4 @@ class ObservationNormalizationFilter(ObservationFilter):
|
|||||||
|
|
||||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||||
self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||||
|
|
||||||
Reference in New Issue
Block a user