1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

load and save function for non-episodic replay buffers + carla improvements + network bug fixes

This commit is contained in:
itaicaspi-intel
2018-09-06 16:46:57 +03:00
parent d59a700248
commit a9bd1047c4
8 changed files with 50 additions and 18 deletions

View File

@@ -74,12 +74,12 @@ class Agent(AgentInterface):
self.memory = self.shared_memory_scratchpad.get(self.memory_lookup_name)
else:
# modules
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
if agent_parameters.memory.load_memory_from_file_path:
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
.format(agent_parameters.memory.load_memory_from_file_path))
self.memory = read_pickle(agent_parameters.memory.load_memory_from_file_path)
else:
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
self.memory.load(agent_parameters.memory.load_memory_from_file_path)
if self.shared_memory and self.is_chief:
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
@@ -149,6 +149,7 @@ class Agent(AgentInterface):
self.unclipped_grads = self.register_signal('Grads (unclipped)')
self.reward = self.register_signal('Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
self.shaped_reward = self.register_signal('Shaped Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
self.discounted_return = self.register_signal('Discounted Return')
if isinstance(self.in_action_space, GoalsSpace):
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
@@ -427,6 +428,10 @@ class Agent(AgentInterface):
:return: None
"""
self.current_episode_buffer.is_complete = True
self.current_episode_buffer.update_returns()
for transition in self.current_episode_buffer.transitions:
self.discounted_return.add_sample(transition.total_return)
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
self.current_episode += 1
@@ -435,7 +440,6 @@ class Agent(AgentInterface):
if isinstance(self.memory, EpisodicExperienceReplay):
self.call_memory('store_episode', self.current_episode_buffer)
elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
self.current_episode_buffer.update_returns()
for transition in self.current_episode_buffer.transitions:
self.call_memory('store', transition)