Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 03:30:19 +01:00)
network_imporvements branch merge
@@ -157,6 +157,10 @@ class Agent(AgentInterface):
        if self.ap.task_parameters.seed is not None:
            random.seed(self.ap.task_parameters.seed)
            np.random.seed(self.ap.task_parameters.seed)
        else:
            # we need to seed the RNG since the different processes are initialized with the same parent seed
            random.seed()
            np.random.seed()

    @property
    def parent(self):
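The comment in this hunk names a real pitfall: worker processes forked from a common parent inherit its RNG state, so without the new else-branch every worker would draw an identical "random" stream. Below is a minimal standalone sketch of the failure mode; it is not Coach code, and it is Unix-only because it relies on the fork start method:

    import multiprocessing as mp
    import random

    def draw(reseed, out):
        # A forked child inherits the parent's RNG state, so without an
        # explicit reseed every child produces the same first value.
        if reseed:
            random.seed()  # reseed from OS entropy, as the patched else-branch does
        out.put(random.random())

    if __name__ == "__main__":
        mp.set_start_method("fork")  # children inherit parent memory, incl. RNG state
        random.seed(42)              # the shared parent seed
        for reseed in (False, True):
            queue = mp.Queue()
            procs = [mp.Process(target=draw, args=(reseed, queue)) for _ in range(4)]
            for p in procs:
                p.start()
            values = {queue.get() for _ in procs}
            for p in procs:
                p.join()
            print(len(values))  # 1 without reseeding, 4 with it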
@@ -269,6 +273,10 @@ class Agent(AgentInterface):
                                                    spaces=self.spaces,
                                                    replicated_device=self.replicated_device,
                                                    worker_device=self.worker_device)

            if self.ap.visualization.print_networks_summary:
                print(networks[network_name])

        return networks

    def init_environment_dependent_modules(self) -> None:
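For reference, the print_networks_summary flag checked in this hunk lives on Coach's visualization parameters; a hedged sketch of flipping it on in a preset follows (the import path matches the repo around this revision, but treat it as an assumption):

    from rl_coach.base_parameters import VisualizationParameters

    vis_params = VisualizationParameters()
    # Makes create_networks() print each NetworkWrapper as it is built.
    vis_params.print_networks_summary = True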
@@ -278,6 +286,14 @@ class Agent(AgentInterface):
        :return: None
        """
        # initialize exploration policy
        if isinstance(self.ap.exploration, dict):
            if self.spaces.action.__class__ in self.ap.exploration.keys():
                self.ap.exploration = self.ap.exploration[self.spaces.action.__class__]
            else:
                raise ValueError("The exploration parameters were defined as a mapping between action space types and "
                                 "exploration types, but the action space used by the environment ({}) was not part of "
                                 "the exploration parameters dictionary keys ({})"
                                 .format(self.spaces.action.__class__, list(self.ap.exploration.keys())))
        self.ap.exploration.action_space = self.spaces.action
        self.exploration_policy = dynamic_import_and_instantiate_module_from_params(self.ap.exploration)
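This hunk lets self.ap.exploration be either a single parameters object or a mapping keyed by action-space class, resolved against the environment's actual action space at init time. A sketch of the dict form follows; the class paths come from Coach and agent_params stands in for any AgentParameters instance, so treat both as assumptions:

    from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
    from rl_coach.exploration_policies.e_greedy import EGreedyParameters
    from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace

    # Keyed by action-space class; init_environment_dependent_modules()
    # picks the entry matching self.spaces.action.__class__ and raises
    # a ValueError if none matches.
    agent_params.exploration = {
        DiscreteActionSpace: EGreedyParameters(),   # discrete actions -> epsilon-greedy
        BoxActionSpace: AdditiveNoiseParameters(),  # continuous actions -> additive noise
    }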
@@ -543,6 +559,9 @@ class Agent(AgentInterface):
        """
        loss = 0
        if self._should_train():
            for network in self.networks.values():
                network.set_is_training(True)

            for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
                # TODO: this should be network dependent
                network_parameters = list(self.ap.network_wrappers.values())[0]
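The loop bound num_consecutive_training_steps is read off the algorithm parameters; a one-line preset-style sketch (agent_params is again a placeholder, and the attribute name is taken from the condition above):

    # Perform 4 consecutive gradient steps per training phase.
    agent_params.algorithm.num_consecutive_training_steps = 4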
@@ -586,9 +605,14 @@ class Agent(AgentInterface):
                if self.imitation:
                    self.log_to_screen()

            for network in self.networks.values():
                network.set_is_training(False)

            # run additional commands after the training is done
            self.post_training_commands()

        return loss

    def choose_action(self, curr_state):
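Pieced together from the two train() hunks above, the control flow of the patched method reads roughly as follows; this is a reconstruction with the diff's elided middle marked, not the verbatim method:

    def train(self):
        loss = 0
        if self._should_train():
            # flag networks as training so layers like batch-norm/dropout
            # switch to their training behavior
            for network in self.networks.values():
                network.set_is_training(True)

            for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
                ...  # sample a batch, compute targets, apply gradients (elided in the diff)
                if self.imitation:
                    self.log_to_screen()

            for network in self.networks.values():
                network.set_is_training(False)

            # run additional commands after the training is done
            self.post_training_commands()

        return loss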