1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-17 14:45:50 +01:00

load and save function for non-episodic replay buffers + carla improvements + network bug fixes

This commit is contained in:
itaicaspi-intel
2018-09-06 16:46:57 +03:00
parent d59a700248
commit a9bd1047c4
8 changed files with 50 additions and 18 deletions

View File

@@ -74,12 +74,12 @@ class Agent(AgentInterface):
self.memory = self.shared_memory_scratchpad.get(self.memory_lookup_name)
else:
# modules
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
if agent_parameters.memory.load_memory_from_file_path:
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
.format(agent_parameters.memory.load_memory_from_file_path))
self.memory = read_pickle(agent_parameters.memory.load_memory_from_file_path)
else:
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
self.memory.load(agent_parameters.memory.load_memory_from_file_path)
if self.shared_memory and self.is_chief:
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
@@ -149,6 +149,7 @@ class Agent(AgentInterface):
self.unclipped_grads = self.register_signal('Grads (unclipped)')
self.reward = self.register_signal('Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
self.shaped_reward = self.register_signal('Shaped Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
self.discounted_return = self.register_signal('Discounted Return')
if isinstance(self.in_action_space, GoalsSpace):
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
@@ -427,6 +428,10 @@ class Agent(AgentInterface):
:return: None
"""
self.current_episode_buffer.is_complete = True
self.current_episode_buffer.update_returns()
for transition in self.current_episode_buffer.transitions:
self.discounted_return.add_sample(transition.total_return)
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
self.current_episode += 1
@@ -435,7 +440,6 @@ class Agent(AgentInterface):
if isinstance(self.memory, EpisodicExperienceReplay):
self.call_memory('store_episode', self.current_episode_buffer)
elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
self.current_episode_buffer.update_returns()
for transition in self.current_episode_buffer.transitions:
self.call_memory('store', transition)

View File

@@ -32,6 +32,7 @@ from rl_coach.core_types import ActionInfo
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
class HumanAlgorithmParameters(AlgorithmParameters):
@@ -57,7 +58,7 @@ class HumanAgentParameters(AgentParameters):
def __init__(self):
super().__init__(algorithm=HumanAlgorithmParameters(),
exploration=EGreedyParameters(),
memory=EpisodicExperienceReplayParameters(),
memory=ExperienceReplayParameters(),
networks={"main": BCNetworkParameters()})
@property
@@ -103,7 +104,7 @@ class HumanAgent(Agent):
def save_replay_buffer_and_exit(self):
replay_buffer_path = os.path.join(self.agent_logger.experiments_path, 'replay_buffer.p')
self.memory.tp = None
to_pickle(self.memory, replay_buffer_path)
self.memory.save(replay_buffer_path)
screen.log_title("Replay buffer was stored in {}".format(replay_buffer_path))
exit()