
network_imporvements branch merge

This commit is contained in:
Shadi Endrawis
2018-10-02 13:41:46 +03:00
parent 72ea933384
commit 51726a5b80
110 changed files with 1639 additions and 1161 deletions


@@ -157,6 +157,10 @@ class Agent(AgentInterface):
        if self.ap.task_parameters.seed is not None:
            random.seed(self.ap.task_parameters.seed)
            np.random.seed(self.ap.task_parameters.seed)
        else:
            # we need to seed the RNG since the different processes are initialized with the same parent seed
            random.seed()
            np.random.seed()

    @property
    def parent(self):
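The new else-branch re-seeds from OS entropy because forked worker processes otherwise inherit the parent's RNG state and draw identical "random" sequences. A minimal standalone sketch of that failure mode (not Coach code; the rollout helper and the reliance on the 'fork' start method are assumptions for illustration):

import multiprocessing as mp
import numpy as np

def rollout(reseed, queue):
    # each forked child starts with a copy of the parent's NumPy RNG state
    if reseed:
        np.random.seed()  # re-seed from OS entropy, as in the else-branch above
    queue.put(np.random.randint(0, 1000))

if __name__ == '__main__':
    np.random.seed(42)  # parent seed, inherited by every forked child
    for reseed in (False, True):
        queue = mp.Queue()
        workers = [mp.Process(target=rollout, args=(reseed, queue)) for _ in range(4)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        # with the 'fork' start method: all values equal when reseed=False,
        # (almost surely) all different when reseed=True
        print(reseed, sorted(queue.get() for _ in range(4)))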
@@ -269,6 +273,10 @@ class Agent(AgentInterface):
                                                    spaces=self.spaces,
                                                    replicated_device=self.replicated_device,
                                                    worker_device=self.worker_device)

            if self.ap.visualization.print_networks_summary:
                print(networks[network_name])

        return networks

    def init_environment_dependent_modules(self) -> None:
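The summary printout is gated on a new visualization flag. A hedged usage sketch of turning it on from a preset; the imports and graph-manager wiring below follow Coach's usual layout but are illustrative rather than taken from this commit:

from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.environments.gym_environment import GymVectorEnvironment
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import SimpleSchedule

vis_params = VisualizationParameters()
vis_params.print_networks_summary = True  # print each NetworkWrapper as init_networks() builds it

graph_manager = BasicRLGraphManager(
    agent_params=DQNAgentParameters(),
    env_params=GymVectorEnvironment(level='CartPole-v0'),
    schedule_params=SimpleSchedule(),
    vis_params=vis_params)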
@@ -278,6 +286,14 @@ class Agent(AgentInterface):
        :return: None
        """
        # initialize exploration policy
        if isinstance(self.ap.exploration, dict):
            if self.spaces.action.__class__ in self.ap.exploration.keys():
                self.ap.exploration = self.ap.exploration[self.spaces.action.__class__]
            else:
                raise ValueError("The exploration parameters were defined as a mapping between action space types and "
                                 "exploration types, but the action space used by the environment ({}) was not part of "
                                 "the exploration parameters dictionary keys ({})"
                                 .format(self.spaces.action.__class__, list(self.ap.exploration.keys())))
        self.ap.exploration.action_space = self.spaces.action
        self.exploration_policy = dynamic_import_and_instantiate_module_from_params(self.ap.exploration)
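This lets a preset declare exploration as a mapping from action-space type to exploration parameters and have the agent pick the matching entry (by exact class) once the environment's action space is known. A hedged sketch, assuming agent_params is an existing AgentParameters instance; class and module names follow Coach's layout but are written from memory:

from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace

# agent_params: an AgentParameters instance defined elsewhere in the preset
agent_params.exploration = {
    DiscreteActionSpace: EGreedyParameters(),    # used when the env exposes a discrete action space
    BoxActionSpace: AdditiveNoiseParameters(),   # used when the env exposes a continuous (box) action space
}
# init_environment_dependent_modules() replaces the dict with the entry whose key
# equals self.spaces.action.__class__, and raises ValueError if no key matches.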
@@ -543,6 +559,9 @@ class Agent(AgentInterface):
"""
loss = 0
if self._should_train():
for network in self.networks.values():
network.set_is_training(True)
for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
# TODO: this should be network dependent
network_parameters = list(self.ap.network_wrappers.values())[0]
@@ -586,9 +605,14 @@ class Agent(AgentInterface):
            if self.imitation:
                self.log_to_screen()

            for network in self.networks.values():
                network.set_is_training(False)

        # run additional commands after the training is done
        self.post_training_commands()

        return loss

    def choose_action(self, curr_state):
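Together, the last two hunks bracket the update loop with set_is_training(True)/set_is_training(False) so the networks only behave in training mode while gradients are actually being computed. A sketch of the same pairing expressed as a context manager; the wrapper itself and the learn_from_batch call in the usage comment are illustrative, not part of Coach:

from contextlib import contextmanager

@contextmanager
def training_mode(networks):
    # flip every NetworkWrapper into training mode, and guarantee the flag is
    # restored even if a training step raises
    for network in networks.values():
        network.set_is_training(True)
    try:
        yield
    finally:
        for network in networks.values():
            network.set_is_training(False)

# possible usage inside an agent (illustrative):
# with training_mode(self.networks):
#     for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
#         loss += self.learn_from_batch(batch)[0]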