1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

applying filters for a csv loaded dataset + some bug-fixes in data loading (#319)

This commit is contained in:
Gal Leibovich
2019-05-28 15:44:55 +03:00
committed by GitHub
parent 6319387357
commit 4c996e147e
3 changed files with 62 additions and 22 deletions

View File

@@ -95,19 +95,6 @@ class Agent(AgentInterface):
if self.ap.memory.memory_backend_params.run_type != 'trainer': if self.ap.memory.memory_backend_params.run_type != 'trainer':
self.memory.set_memory_backend(self.memory_backend) self.memory.set_memory_backend(self.memory_backend)
if agent_parameters.memory.load_memory_from_file_path:
if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
else:
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
.format(agent_parameters.memory.load_memory_from_file_path))
if self.shared_memory and self.is_chief: if self.shared_memory and self.is_chief:
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory) self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
@@ -262,6 +249,38 @@ class Agent(AgentInterface):
self.output_filter.set_session(sess) self.output_filter.set_session(sess)
self.pre_network_filter.set_session(sess) self.pre_network_filter.set_session(sess)
[network.set_session(sess) for network in self.networks.values()] [network.set_session(sess) for network in self.networks.values()]
self.initialize_session_dependent_components()
def initialize_session_dependent_components(self):
    """
    Set up any components whose construction depends on an active session.

    Restoring a replay buffer from a CSV dataset runs the data through the
    agent's input filter, and that filter only works once a session has been
    attached — so the memory load is deferred until this point.

    :return: None
    """
    memory_source = self.ap.memory.load_memory_from_file_path
    if not memory_source:
        return
    self.load_memory_from_file()
def load_memory_from_file(self):
    """
    Load memory transitions from a file.

    Dispatches on the configured dataset type: a pickled replay buffer is
    restored directly, while a CSV dataset is parsed and passed through the
    agent's input filter. Any other configuration is rejected.

    :return: None
    """
    memory_source = self.ap.memory.load_memory_from_file_path
    if isinstance(memory_source, PickledReplayBuffer):
        screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
                         .format(memory_source.filepath))
        self.memory.load_pickled(memory_source.filepath)
    elif isinstance(memory_source, CsvDataset):
        screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
                         .format(memory_source.filepath))
        self.memory.load_csv(memory_source, self.input_filter)
    else:
        raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
                         .format(memory_source))
def register_signal(self, signal_name: str, dump_one_value_per_episode: bool=True, def register_signal(self, signal_name: str, dump_one_value_per_episode: bool=True,
dump_one_value_per_step: bool=False) -> Signal: dump_one_value_per_step: bool=False) -> Signal:

View File

@@ -21,7 +21,7 @@ import numpy as np
from rl_coach.core_types import ObservationType from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace from rl_coach.spaces import ObservationSpace, VectorObservationSpace
class LazyStack(object): class LazyStack(object):
@@ -63,6 +63,7 @@ class ObservationStackingFilter(ObservationFilter):
self.stack_size = stack_size self.stack_size = stack_size
self.stacking_axis = stacking_axis self.stacking_axis = stacking_axis
self.stack = [] self.stack = []
self.input_observation_space = None
if stack_size <= 0: if stack_size <= 0:
raise ValueError("The stack shape must be a positive number") raise ValueError("The stack shape must be a positive number")
@@ -86,7 +87,6 @@ class ObservationStackingFilter(ObservationFilter):
raise ValueError("The stacking axis is larger than the number of dimensions in the observation space") raise ValueError("The stacking axis is larger than the number of dimensions in the observation space")
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType: def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
if len(self.stack) == 0: if len(self.stack) == 0:
self.stack = deque([observation] * self.stack_size, maxlen=self.stack_size) self.stack = deque([observation] * self.stack_size, maxlen=self.stack_size)
else: else:
@@ -94,14 +94,21 @@ class ObservationStackingFilter(ObservationFilter):
self.stack.append(observation) self.stack.append(observation)
observation = LazyStack(self.stack, self.stacking_axis) observation = LazyStack(self.stack, self.stacking_axis)
if isinstance(self.input_observation_space, VectorObservationSpace):
# when stacking vectors, we cannot avoid copying the memory as we're flattening it all
observation = np.array(observation).flatten()
return observation return observation
def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace: def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
if self.stacking_axis == -1: if isinstance(input_observation_space, VectorObservationSpace):
input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0) self.input_observation_space = input_observation_space = VectorObservationSpace(input_observation_space.shape * self.stack_size)
else: else:
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis, if self.stacking_axis == -1:
values=[self.stack_size], axis=0) input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
else:
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
values=[self.stack_size], axis=0)
return input_observation_space return input_observation_space
def reset(self) -> None: def reset(self) -> None:

View File

@@ -25,6 +25,7 @@ import numpy as np
import random import random
from rl_coach.core_types import Transition, Episode from rl_coach.core_types import Transition, Episode
from rl_coach.filters.filter import InputFilter
from rl_coach.logger import screen from rl_coach.logger import screen
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
from rl_coach.utils import ReaderWriterLock, ProgressBar from rl_coach.utils import ReaderWriterLock, ProgressBar
@@ -408,11 +409,12 @@ class EpisodicExperienceReplay(Memory):
self.reader_writer_lock.release_writing() self.reader_writer_lock.release_writing()
return mean return mean
def load_csv(self, csv_dataset: CsvDataset) -> None: def load_csv(self, csv_dataset: CsvDataset, input_filter: InputFilter) -> None:
""" """
Restore the replay buffer contents from a csv file. Restore the replay buffer contents from a csv file.
The csv file is assumed to include a list of transitions. The csv file is assumed to include a list of transitions.
:param csv_dataset: A construct which holds the dataset parameters :param csv_dataset: A construct which holds the dataset parameters
:param input_filter: A filter used to filter the CSV data before feeding it to the memory.
""" """
self.assert_not_frozen() self.assert_not_frozen()
@@ -429,18 +431,30 @@ class EpisodicExperienceReplay(Memory):
for e_id in episode_ids: for e_id in episode_ids:
progress_bar.update(e_id) progress_bar.update(e_id)
df_episode_transitions = df[df['episode_id'] == e_id] df_episode_transitions = df[df['episode_id'] == e_id]
input_filter.reset()
if len(df_episode_transitions) < 2:
# we have to have at least 2 rows in each episode for creating a transition
continue
episode = Episode() episode = Episode()
transitions = []
for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(), for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(),
df_episode_transitions[1:].iterrows()): df_episode_transitions[1:].iterrows()):
state = np.array([current_transition[col] for col in state_columns]) state = np.array([current_transition[col] for col in state_columns])
next_state = np.array([next_transition[col] for col in state_columns]) next_state = np.array([next_transition[col] for col in state_columns])
episode.insert( transitions.append(
Transition(state={'observation': state}, Transition(state={'observation': state},
action=current_transition['action'], reward=current_transition['reward'], action=current_transition['action'], reward=current_transition['reward'],
next_state={'observation': next_state}, game_over=False, next_state={'observation': next_state}, game_over=False,
info={'all_action_probabilities': info={'all_action_probabilities':
ast.literal_eval(current_transition['all_action_probabilities'])})) ast.literal_eval(current_transition['all_action_probabilities'])}),
)
transitions = input_filter.filter(transitions, deep_copy=False)
for t in transitions:
episode.insert(t)
# Set the last transition to end the episode # Set the last transition to end the episode
if csv_dataset.is_episodic: if csv_dataset.is_episodic: