mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
applying filters for a csv loaded dataset + some bug-fixes in data loading (#319)
This commit is contained in:
@@ -95,19 +95,6 @@ class Agent(AgentInterface):
|
||||
if self.ap.memory.memory_backend_params.run_type != 'trainer':
|
||||
self.memory.set_memory_backend(self.memory_backend)
|
||||
|
||||
if agent_parameters.memory.load_memory_from_file_path:
|
||||
if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
|
||||
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
|
||||
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
|
||||
self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
|
||||
elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
|
||||
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
|
||||
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
|
||||
self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
|
||||
else:
|
||||
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
|
||||
.format(agent_parameters.memory.load_memory_from_file_path))
|
||||
|
||||
if self.shared_memory and self.is_chief:
|
||||
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
|
||||
|
||||
@@ -262,6 +249,38 @@ class Agent(AgentInterface):
|
||||
self.output_filter.set_session(sess)
|
||||
self.pre_network_filter.set_session(sess)
|
||||
[network.set_session(sess) for network in self.networks.values()]
|
||||
self.initialize_session_dependent_components()
|
||||
|
||||
def initialize_session_dependent_components(self):
|
||||
"""
|
||||
Initialize components which require a session as part of their initialization.
|
||||
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# Loading a memory from a CSV file, requires an input filter to filter through the data.
|
||||
# The filter needs a session before it can be used.
|
||||
if self.ap.memory.load_memory_from_file_path:
|
||||
self.load_memory_from_file()
|
||||
|
||||
def load_memory_from_file(self):
|
||||
"""
|
||||
Load memory transitions from a file.
|
||||
|
||||
:return: None
|
||||
"""
|
||||
|
||||
if isinstance(self.ap.memory.load_memory_from_file_path, PickledReplayBuffer):
|
||||
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
|
||||
.format(self.ap.memory.load_memory_from_file_path.filepath))
|
||||
self.memory.load_pickled(self.ap.memory.load_memory_from_file_path.filepath)
|
||||
elif isinstance(self.ap.memory.load_memory_from_file_path, CsvDataset):
|
||||
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
|
||||
.format(self.ap.memory.load_memory_from_file_path.filepath))
|
||||
self.memory.load_csv(self.ap.memory.load_memory_from_file_path, self.input_filter)
|
||||
else:
|
||||
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
|
||||
.format(self.ap.memory.load_memory_from_file_path))
|
||||
|
||||
def register_signal(self, signal_name: str, dump_one_value_per_episode: bool=True,
|
||||
dump_one_value_per_step: bool=False) -> Signal:
|
||||
|
||||
Reference in New Issue
Block a user