1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-03 23:35:51 +01:00

network_imporvements branch merge

This commit is contained in:
Shadi Endrawis
2018-10-02 13:41:46 +03:00
parent 72ea933384
commit 51726a5b80
110 changed files with 1639 additions and 1161 deletions

View File

@@ -26,6 +26,7 @@ from rl_coach.base_parameters import AgentParameters, AlgorithmParameters, Netwo
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
class BCAlgorithmParameters(AlgorithmParameters):
@@ -40,7 +41,6 @@ class BCNetworkParameters(NetworkParameters):
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
self.heads_parameters = [PolicyHeadParameters()]
self.loss_weights = [1.0]
self.optimizer_type = 'Adam'
self.batch_size = 32
self.replace_mse_with_huber_loss = False
@@ -51,7 +51,7 @@ class BCAgentParameters(AgentParameters):
def __init__(self):
super().__init__(algorithm=BCAlgorithmParameters(),
exploration=EGreedyParameters(),
memory=EpisodicExperienceReplayParameters(),
memory=ExperienceReplayParameters(),
networks={"main": BCNetworkParameters()})
@property