mirror of https://github.com/gryf/coach.git, synced 2025-12-18 03:30:19 +01:00
temp commit
@@ -69,7 +69,7 @@ class Parameters(object):
                 parameters[k] = dict(v.items())
             else:
                 parameters[k] = v
 
         return json.dumps(parameters, indent=4, default=repr)
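The hunk above is the generic Parameters serialization helper. Below is a minimal standalone sketch of the same pattern, with a hypothetical stand-in for an InputTypes member (the real constants are defined elsewhere in Coach and are not part of this diff); it shows why default=repr matters once attribute values are dicts of objects that json cannot serialize on its own.

import json


class _ObservationType(object):
    # Hypothetical stand-in for an InputTypes member; not JSON-serializable by default.
    def __repr__(self):
        return 'InputTypes.Observation'


def dump_parameters(attributes):
    # Same pattern as Parameters above: copy dict-valued attributes, pass the
    # rest through, and let json.dumps fall back to repr() for custom objects.
    parameters = {}
    for k, v in attributes.items():
        if isinstance(v, dict):
            parameters[k] = dict(v.items())
        else:
            parameters[k] = v
    return json.dumps(parameters, indent=4, default=repr)


print(dump_parameters({'input_types': {'observation': _ObservationType()},
                       'loss_weights': [1.0]}))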
@@ -77,7 +77,7 @@ class AgentParameters(Parameters):
     agent = ''
 
     # Architecture parameters
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     middleware_type = MiddlewareTypes.FC
     loss_weights = [1.0]
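Every hunk below makes the same kind of change: input_types moves from a positional list to a name-keyed dict. A minimal sketch of what that buys the consuming side, using illustrative stand-ins (InputTypes and build_embedder here are placeholders, not Coach's actual API):

class InputTypes(object):
    # placeholder constants standing in for the real InputTypes
    Observation = 'Observation'
    Action = 'Action'


def build_embedder(input_type):
    # hypothetical helper: just describes the embedder that would be created
    return 'embedder<{}>'.format(input_type)


# Old style: inputs are identified only by their position in the list.
old_input_types = [InputTypes.Observation]
obs_embedder = build_embedder(old_input_types[0])   # relies on ordering

# New style (this commit): inputs are identified by name.
new_input_types = {'observation': InputTypes.Observation}
embedders = {name: build_embedder(t) for name, t in new_input_types.items()}
obs_embedder = embedders['observation']             # order-independent lookup

The name-keyed form also makes the multi-input agents further down (DDPG, DFP) self-describing, since each tensor is tied to the stream it feeds.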
@@ -327,7 +327,7 @@ class Human(AgentParameters):
 
 class NStepQ(AgentParameters):
     type = 'NStepQAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     optimizer_type = 'Adam'
@@ -343,7 +343,7 @@ class NStepQ(AgentParameters):
 
 class DQN(AgentParameters):
     type = 'DQNAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     optimizer_type = 'Adam'
@@ -385,7 +385,7 @@ class QuantileRegressionDQN(DQN):
 class NEC(AgentParameters):
     type = 'NECAgent'
     optimizer_type = 'RMSProp'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.DNDQ]
     loss_weights = [1.0]
     dnd_size = 500000
@@ -399,7 +399,7 @@ class NEC(AgentParameters):
 
 class ActorCritic(AgentParameters):
     type = 'ActorCriticAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.V, OutputTypes.Pi]
     loss_weights = [0.5, 1.0]
     stop_gradients_from_head = [False, False]
@@ -417,7 +417,7 @@ class ActorCritic(AgentParameters):
 
 class PolicyGradient(AgentParameters):
     type = 'PolicyGradientsAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
    output_types = [OutputTypes.Pi]
     loss_weights = [1.0]
     num_episodes_in_experience_replay = 2
@@ -430,7 +430,7 @@ class PolicyGradient(AgentParameters):
 
 class DDPG(AgentParameters):
     type = 'DDPGAgent'
-    input_types = [InputTypes.Observation, InputTypes.Action]
+    input_types = {'observation': InputTypes.Observation, 'action': InputTypes.Action}
     output_types = [OutputTypes.V] # V is used because we only want a single Q value
     loss_weights = [1.0]
     hidden_layers_activation_function = 'relu'
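For the two-input agents (DDPG above and DDDPG below), the dict form also allows an explicit check that both expected streams are present, something a bare two-element list cannot express. A hedged sketch under the same placeholder assumptions as before:

class InputTypes(object):
    # placeholder constants standing in for the real InputTypes
    Observation = 'Observation'
    Action = 'Action'


input_types = {'observation': InputTypes.Observation,
               'action': InputTypes.Action}

# hypothetical validation, not part of the diff
required = {'observation', 'action'}
missing = required - set(input_types)
if missing:
    raise ValueError('missing input streams: {}'.format(sorted(missing)))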
@@ -443,7 +443,7 @@ class DDPG(AgentParameters):
 
 class DDDPG(AgentParameters):
     type = 'DDPGAgent'
-    input_types = [InputTypes.Observation, InputTypes.Action]
+    input_types = {'observation': InputTypes.Observation, 'action': InputTypes.Action}
     output_types = [OutputTypes.V] # V is used because we only want a single Q value
     loss_weights = [1.0]
     hidden_layers_activation_function = 'relu'
@@ -456,7 +456,7 @@ class DDDPG(AgentParameters):
 
 class NAF(AgentParameters):
     type = 'NAFAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.NAF]
     loss_weights = [1.0]
     hidden_layers_activation_function = 'tanh'
@@ -469,7 +469,7 @@ class NAF(AgentParameters):
 
 class PPO(AgentParameters):
     type = 'PPOAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.V]
     loss_weights = [1.0]
     hidden_layers_activation_function = 'tanh'
@@ -489,7 +489,7 @@ class PPO(AgentParameters):
 
 class ClippedPPO(AgentParameters):
     type = 'ClippedPPOAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.V, OutputTypes.PPO]
     loss_weights = [0.5, 1.0]
     stop_gradients_from_head = [False, False]
@@ -515,7 +515,11 @@ class ClippedPPO(AgentParameters):
 
 class DFP(AgentParameters):
     type = 'DFPAgent'
-    input_types = [InputTypes.Observation, InputTypes.Measurements, InputTypes.GoalVector]
+    input_types = {
+        'observation': InputTypes.Observation,
+        'measurements': InputTypes.Measurements,
+        'goal': InputTypes.GoalVector
+    }
     output_types = [OutputTypes.MeasurementsPrediction]
     loss_weights = [1.0]
     use_measurements = True
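DFP is the only agent whose input_types dict carries three entries. A general caveat, not something this diff itself addresses: before Python 3.7 a plain dict literal does not guarantee insertion order, so any downstream code that still pairs input streams with positional data would need an OrderedDict (or sorted iteration) to stay deterministic. A small sketch with the same placeholder InputTypes as earlier:

from collections import OrderedDict


class InputTypes(object):
    # placeholder constants standing in for the real InputTypes
    Observation = 'Observation'
    Measurements = 'Measurements'
    GoalVector = 'GoalVector'


# Keeps a deterministic stream order on interpreters older than Python 3.7.
input_types = OrderedDict([
    ('observation', InputTypes.Observation),
    ('measurements', InputTypes.Measurements),
    ('goal', InputTypes.GoalVector),
])

for name, input_type in input_types.items():
    print(name, '->', input_type)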
@@ -527,7 +531,7 @@ class DFP(AgentParameters):
 
 class MMC(AgentParameters):
     type = 'MixedMonteCarloAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     num_steps_between_copying_online_weights_to_target = 1000
@@ -537,7 +541,7 @@ class MMC(AgentParameters):
 
 class PAL(AgentParameters):
     type = 'PALAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     pal_alpha = 0.9
@@ -548,7 +552,7 @@ class PAL(AgentParameters):
 
 class BC(AgentParameters):
     type = 'BCAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     collect_new_data = False