mirror of https://github.com/gryf/coach.git

commit 85afb86893
parent 16c5032735
Author: Zach Dwiel
Date:   2018-02-16 09:35:58 -05:00

    temp commit

14 changed files with 244 additions and 127 deletions


@@ -69,7 +69,7 @@ class Parameters(object):
                 parameters[k] = dict(v.items())
             else:
                 parameters[k] = v
-        return json.dumps(parameters, indent=4)
+        return json.dumps(parameters, indent=4, default=repr)
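
The new default=repr argument matters because Parameters objects hold values
json cannot serialize natively (enum members such as InputTypes.Observation,
class references, and so on); default= gives json.dumps a fallback callable
for those. A minimal standalone sketch of the behavior, using a stand-in
InputTypes enum rather than the repo's own:

    import json
    from enum import Enum

    class InputTypes(Enum):
        Observation = 1

    parameters = {'agent': 'DQNAgent', 'input_types': InputTypes.Observation}

    # Without default=, json.dumps raises TypeError on the enum member;
    # with default=repr, any non-JSON-native value falls back to its repr string.
    print(json.dumps(parameters, indent=4, default=repr))
    # {
    #     "agent": "DQNAgent",
    #     "input_types": "<InputTypes.Observation: 1>"
    # }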
@@ -77,7 +77,7 @@ class AgentParameters(Parameters):
     agent = ''
     # Architecture parameters
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     middleware_type = MiddlewareTypes.FC
     loss_weights = [1.0]
@@ -327,7 +327,7 @@ class Human(AgentParameters):
 class NStepQ(AgentParameters):
     type = 'NStepQAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     optimizer_type = 'Adam'
@@ -343,7 +343,7 @@ class NStepQ(AgentParameters):
 class DQN(AgentParameters):
     type = 'DQNAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     optimizer_type = 'Adam'
@@ -385,7 +385,7 @@ class QuantileRegressionDQN(DQN):
 class NEC(AgentParameters):
     type = 'NECAgent'
     optimizer_type = 'RMSProp'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.DNDQ]
     loss_weights = [1.0]
     dnd_size = 500000
@@ -399,7 +399,7 @@ class NEC(AgentParameters):
 class ActorCritic(AgentParameters):
     type = 'ActorCriticAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.V, OutputTypes.Pi]
     loss_weights = [0.5, 1.0]
     stop_gradients_from_head = [False, False]
@@ -417,7 +417,7 @@ class ActorCritic(AgentParameters):
 class PolicyGradient(AgentParameters):
     type = 'PolicyGradientsAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Pi]
     loss_weights = [1.0]
     num_episodes_in_experience_replay = 2
@@ -430,7 +430,7 @@ class PolicyGradient(AgentParameters):
 class DDPG(AgentParameters):
     type = 'DDPGAgent'
-    input_types = [InputTypes.Observation, InputTypes.Action]
+    input_types = {'observation': InputTypes.Observation, 'action': InputTypes.Action}
     output_types = [OutputTypes.V] # V is used because we only want a single Q value
     loss_weights = [1.0]
     hidden_layers_activation_function = 'relu'
@@ -443,7 +443,7 @@ class DDPG(AgentParameters):
 class DDDPG(AgentParameters):
     type = 'DDPGAgent'
-    input_types = [InputTypes.Observation, InputTypes.Action]
+    input_types = {'observation': InputTypes.Observation, 'action': InputTypes.Action}
     output_types = [OutputTypes.V] # V is used because we only want a single Q value
     loss_weights = [1.0]
     hidden_layers_activation_function = 'relu'
@@ -456,7 +456,7 @@ class DDDPG(AgentParameters):
 class NAF(AgentParameters):
     type = 'NAFAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.NAF]
     loss_weights = [1.0]
     hidden_layers_activation_function = 'tanh'
@@ -469,7 +469,7 @@ class NAF(AgentParameters):
 class PPO(AgentParameters):
     type = 'PPOAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.V]
     loss_weights = [1.0]
     hidden_layers_activation_function = 'tanh'
@@ -489,7 +489,7 @@ class PPO(AgentParameters):
 class ClippedPPO(AgentParameters):
     type = 'ClippedPPOAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.V, OutputTypes.PPO]
     loss_weights = [0.5, 1.0]
     stop_gradients_from_head = [False, False]
@@ -515,7 +515,11 @@ class ClippedPPO(AgentParameters):
 class DFP(AgentParameters):
     type = 'DFPAgent'
-    input_types = [InputTypes.Observation, InputTypes.Measurements, InputTypes.GoalVector]
+    input_types = {
+        'observation': InputTypes.Observation,
+        'measurements': InputTypes.Measurements,
+        'goal': InputTypes.GoalVector
+    }
     output_types = [OutputTypes.MeasurementsPrediction]
     loss_weights = [1.0]
     use_measurements = True
@@ -527,7 +531,7 @@ class DFP(AgentParameters):
 class MMC(AgentParameters):
     type = 'MixedMonteCarloAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     num_steps_between_copying_online_weights_to_target = 1000
@@ -537,7 +541,7 @@ class MMC(AgentParameters):
 class PAL(AgentParameters):
     type = 'PALAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     pal_alpha = 0.9
@@ -548,7 +552,7 @@ class PAL(AgentParameters):
 class BC(AgentParameters):
     type = 'BCAgent'
-    input_types = [InputTypes.Observation]
+    input_types = {'observation': InputTypes.Observation}
     output_types = [OutputTypes.Q]
     loss_weights = [1.0]
     collect_new_data = False
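
Across every agent class, input_types moves from a positional list to a dict
keyed by input name, so downstream code can match each input to its embedder
by key instead of by list position. A minimal sketch of the payoff, assuming
a stand-in InputTypes enum and a hypothetical build_embedders helper (neither
is part of this commit):

    from enum import Enum

    class InputTypes(Enum):
        Observation = 1
        Measurements = 2
        GoalVector = 3

    # The named-input declaration introduced by this commit (DFP's inputs):
    input_types = {
        'observation': InputTypes.Observation,
        'measurements': InputTypes.Measurements,
        'goal': InputTypes.GoalVector,
    }

    def build_embedders(input_types):
        # One embedder per named input: consumers look inputs up by key
        # ('observation', 'goal', ...) rather than by index, which breaks
        # silently when an agent reorders or omits an input.
        return {name: '{}_embedder'.format(t.name.lower())
                for name, t in input_types.items()}

    print(sorted(build_embedders(input_types)))
    # ['goal', 'measurements', 'observation']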