1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Fix for issue #128 - circular DQN import (#130)

This commit is contained in:
Gal Leibovich
2018-12-16 16:06:44 +02:00
committed by GitHub
parent e08accdc22
commit f9ee526536
4 changed files with 7 additions and 2 deletions

View File

@@ -36,6 +36,7 @@ class DQNAlgorithmParameters(AlgorithmParameters):
self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000)
self.num_consecutive_playing_steps = EnvironmentSteps(4)
self.discount = 0.99
self.supports_parameter_noise = True
class DQNNetworkParameters(NetworkParameters):

View File

@@ -211,6 +211,9 @@ class AlgorithmParameters(Parameters):
# Should the workers wait for full episode
self.act_for_full_episodes = False
# Support for parameter noise
self.supports_parameter_noise = False
class PresetValidationParameters(Parameters):
def __init__(self,

View File

@@ -18,7 +18,6 @@ from typing import List, Dict
import numpy as np
from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.architectures.layers import NoisyNetDense
from rl_coach.base_parameters import AgentParameters, NetworkParameters
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
@@ -30,7 +29,8 @@ from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy,
class ParameterNoiseParameters(ExplorationParameters):
def __init__(self, agent_params: AgentParameters):
super().__init__()
if not isinstance(agent_params, DQNAgentParameters):
if not agent_params.algorithm.supports_parameter_noise:
raise ValueError("Currently only DQN variants are supported for using an exploration type of "
"ParameterNoise.")

View File

@@ -87,3 +87,4 @@ class ObservationNormalizationFilter(ObservationFilter):
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)