mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
load and save function for non-episodic replay buffers + carla improvements + network bug fixes
This commit is contained in:
@@ -74,12 +74,12 @@ class Agent(AgentInterface):
|
||||
self.memory = self.shared_memory_scratchpad.get(self.memory_lookup_name)
|
||||
else:
|
||||
# modules
|
||||
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
|
||||
|
||||
if agent_parameters.memory.load_memory_from_file_path:
|
||||
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
|
||||
.format(agent_parameters.memory.load_memory_from_file_path))
|
||||
self.memory = read_pickle(agent_parameters.memory.load_memory_from_file_path)
|
||||
else:
|
||||
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
|
||||
self.memory.load(agent_parameters.memory.load_memory_from_file_path)
|
||||
|
||||
if self.shared_memory and self.is_chief:
|
||||
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
|
||||
@@ -149,6 +149,7 @@ class Agent(AgentInterface):
|
||||
self.unclipped_grads = self.register_signal('Grads (unclipped)')
|
||||
self.reward = self.register_signal('Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
|
||||
self.shaped_reward = self.register_signal('Shaped Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
|
||||
self.discounted_return = self.register_signal('Discounted Return')
|
||||
if isinstance(self.in_action_space, GoalsSpace):
|
||||
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
|
||||
|
||||
@@ -427,6 +428,10 @@ class Agent(AgentInterface):
|
||||
:return: None
|
||||
"""
|
||||
self.current_episode_buffer.is_complete = True
|
||||
self.current_episode_buffer.update_returns()
|
||||
|
||||
for transition in self.current_episode_buffer.transitions:
|
||||
self.discounted_return.add_sample(transition.total_return)
|
||||
|
||||
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
|
||||
self.current_episode += 1
|
||||
@@ -435,7 +440,6 @@ class Agent(AgentInterface):
|
||||
if isinstance(self.memory, EpisodicExperienceReplay):
|
||||
self.call_memory('store_episode', self.current_episode_buffer)
|
||||
elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
|
||||
self.current_episode_buffer.update_returns()
|
||||
for transition in self.current_episode_buffer.transitions:
|
||||
self.call_memory('store', transition)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user