mirror of
https://github.com/gryf/coach.git
synced 2026-03-03 23:35:51 +01:00
network_imporvements branch merge
This commit is contained in:
@@ -26,6 +26,7 @@ from rl_coach.base_parameters import AgentParameters, AlgorithmParameters, Netwo
|
||||
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
|
||||
|
||||
|
||||
class BCAlgorithmParameters(AlgorithmParameters):
|
||||
@@ -40,7 +41,6 @@ class BCNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
|
||||
self.heads_parameters = [PolicyHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 32
|
||||
self.replace_mse_with_huber_loss = False
|
||||
@@ -51,7 +51,7 @@ class BCAgentParameters(AgentParameters):
|
||||
def __init__(self):
|
||||
super().__init__(algorithm=BCAlgorithmParameters(),
|
||||
exploration=EGreedyParameters(),
|
||||
memory=EpisodicExperienceReplayParameters(),
|
||||
memory=ExperienceReplayParameters(),
|
||||
networks={"main": BCNetworkParameters()})
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user