mirror of https://github.com/gryf/coach.git
synced 2026-02-17 14:45:50 +01:00
Load and save functions for non-episodic replay buffers + CARLA improvements + network bug fixes
@@ -74,12 +74,12 @@ class Agent(AgentInterface):
             self.memory = self.shared_memory_scratchpad.get(self.memory_lookup_name)
         else:
             # modules
+            self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
+
             if agent_parameters.memory.load_memory_from_file_path:
                 screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
                                  .format(agent_parameters.memory.load_memory_from_file_path))
-                self.memory = read_pickle(agent_parameters.memory.load_memory_from_file_path)
-            else:
-                self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
+                self.memory.load(agent_parameters.memory.load_memory_from_file_path)

             if self.shared_memory and self.is_chief:
                 self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
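This hunk reworks Agent's memory setup: the buffer is always instantiated from its parameters first, and a previously saved buffer is then restored through a load() method on the memory object, instead of unpickling a whole replacement object. The diff only shows the call sites; a minimal sketch of what such a pickle-backed save/load pair might look like on a non-episodic buffer (the internals below are assumptions for illustration, not the repo's actual implementation):

import pickle

class ExperienceReplay:
    """Illustrative non-episodic replay buffer with the save/load API used above."""

    def __init__(self, max_size=1000000):
        self.max_size = max_size
        self.transitions = []

    def store(self, transition):
        # Drop the oldest transition once the buffer is full.
        self.transitions.append(transition)
        if len(self.transitions) > self.max_size:
            self.transitions.pop(0)

    def save(self, file_path):
        # Persist only the transitions, not unpicklable members (e.g. loggers).
        with open(file_path, 'wb') as f:
            pickle.dump(self.transitions, f)

    def load(self, file_path):
        # Restore transitions into an already-constructed buffer, which is why
        # the new code instantiates the module first and then calls load().
        with open(file_path, 'rb') as f:
            self.transitions = pickle.load(f)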
@@ -149,6 +149,7 @@ class Agent(AgentInterface):
         self.unclipped_grads = self.register_signal('Grads (unclipped)')
         self.reward = self.register_signal('Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
         self.shaped_reward = self.register_signal('Shaped Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
+        self.discounted_return = self.register_signal('Discounted Return')
         if isinstance(self.in_action_space, GoalsSpace):
             self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)

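register_signal wires a named statistic into the agent's logger; the new 'Discounted Return' signal uses the defaults (one value per episode), unlike the per-step reward signals beside it. A rough stand-in for the pattern, with simplified internals rather than Coach's real Signal class:

class Signal:
    """Simplified stand-in for Coach's signal objects: collects samples
    that a logger later aggregates and dumps."""

    def __init__(self, name, dump_one_value_per_episode=True, dump_one_value_per_step=False):
        self.name = name
        self.dump_one_value_per_episode = dump_one_value_per_episode
        self.dump_one_value_per_step = dump_one_value_per_step
        self.samples = []

    def add_sample(self, value):
        self.samples.append(value)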
@@ -427,6 +428,10 @@ class Agent(AgentInterface):
         :return: None
         """
         self.current_episode_buffer.is_complete = True
+        self.current_episode_buffer.update_returns()
+
+        for transition in self.current_episode_buffer.transitions:
+            self.discounted_return.add_sample(transition.total_return)

         if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
             self.current_episode += 1
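update_returns fills in each transition's total_return, the discounted sum of rewards from that step to the end of the episode, and the new loop feeds those values to the 'Discounted Return' signal. A minimal sketch of the backward recursion G_t = r_t + gamma * G_{t+1}, assuming the attribute names shown in the diff:

class Transition:
    def __init__(self, reward):
        self.reward = reward
        self.total_return = None

def update_returns(transitions, discount=0.99):
    # Walk the episode backwards: the return at step t is the immediate
    # reward plus the discounted return of the following step.
    total_return = 0.0
    for transition in reversed(transitions):
        total_return = transition.reward + discount * total_return
        transition.total_return = total_return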
@@ -435,7 +440,6 @@ class Agent(AgentInterface):
         if isinstance(self.memory, EpisodicExperienceReplay):
             self.call_memory('store_episode', self.current_episode_buffer)
         elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
-            self.current_episode_buffer.update_returns()
             for transition in self.current_episode_buffer.transitions:
                 self.call_memory('store', transition)

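With update_returns moved into the common end-of-episode path (previous hunk), this branch only decides where transitions go: episodic memories receive the whole episode at once, flat memories one transition at a time. An illustrative counterpart to the ExperienceReplay sketch above, again with assumed internals:

class EpisodicExperienceReplay:
    """Illustrative episodic buffer: stores whole episodes, preserving
    per-episode structure that flat buffers discard."""

    def __init__(self):
        self.episodes = []

    def store_episode(self, episode):
        # The agent hands over the complete episode buffer in one call.
        self.episodes.append(episode)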
@@ -32,6 +32,7 @@ from rl_coach.core_types import ActionInfo
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.logger import screen
 from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
+from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters


 class HumanAlgorithmParameters(AlgorithmParameters):
@@ -57,7 +58,7 @@ class HumanAgentParameters(AgentParameters):
     def __init__(self):
         super().__init__(algorithm=HumanAlgorithmParameters(),
                          exploration=EGreedyParameters(),
-                         memory=EpisodicExperienceReplayParameters(),
+                         memory=ExperienceReplayParameters(),
                          networks={"main": BCNetworkParameters()})

     @property
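HumanAgentParameters now pairs the human demonstration recorder with the flat (non-episodic) buffer, so recorded transitions flow through the new save/load path. Parameters objects like these are what dynamic_import_and_instantiate_module_from_params in the first hunk turns into live memory instances; a hypothetical helper sketching that mechanism (the 'module:Class' path format is an assumption):

import importlib

def instantiate_from_params(params):
    # Hypothetical stand-in for dynamic_import_and_instantiate_module_from_params:
    # the parameters object names its concrete class, which is imported and built.
    module_name, class_name = params.path.split(':')
    memory_class = getattr(importlib.import_module(module_name), class_name)
    return memory_class(params)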
@@ -103,7 +104,7 @@ class HumanAgent(Agent):
     def save_replay_buffer_and_exit(self):
         replay_buffer_path = os.path.join(self.agent_logger.experiments_path, 'replay_buffer.p')
         self.memory.tp = None
-        to_pickle(self.memory, replay_buffer_path)
+        self.memory.save(replay_buffer_path)
         screen.log_title("Replay buffer was stored in {}".format(replay_buffer_path))
         exit()

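Taken together, the two sides of the change form a round trip: HumanAgent.save_replay_buffer_and_exit() now writes the buffer through memory.save() instead of pickling the memory object wholesale, and Agent.__init__ restores it via load_memory_from_file_path. A usage sketch against the illustrative ExperienceReplay above (the path and transition contents are hypothetical):

# Record and persist a buffer, then reload it the way Agent.__init__
# does after this commit: instantiate first, then load().
buffer = ExperienceReplay()
buffer.store({'state': 0, 'action': 1, 'reward': 1.0})
buffer.save('/tmp/replay_buffer.p')

restored = ExperienceReplay()
restored.load('/tmp/replay_buffer.p')
assert len(restored.transitions) == 1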