diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index c7b755e..866fe8a 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -95,19 +95,6 @@ class Agent(AgentInterface):
             if self.ap.memory.memory_backend_params.run_type != 'trainer':
                 self.memory.set_memory_backend(self.memory_backend)
 
-        if agent_parameters.memory.load_memory_from_file_path:
-            if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
-                screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
-                                 .format(agent_parameters.memory.load_memory_from_file_path.filepath))
-                self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
-            elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
-                screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
-                                 .format(agent_parameters.memory.load_memory_from_file_path.filepath))
-                self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
-            else:
-                raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
-                                 .format(agent_parameters.memory.load_memory_from_file_path))
-
         if self.shared_memory and self.is_chief:
             self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
 
@@ -262,6 +249,38 @@ class Agent(AgentInterface):
         self.output_filter.set_session(sess)
         self.pre_network_filter.set_session(sess)
         [network.set_session(sess) for network in self.networks.values()]
+        self.initialize_session_dependent_components()
+
+    def initialize_session_dependent_components(self):
+        """
+        Initialize components which require a session as part of their initialization.
+
+        :return: None
+        """
+
+        # Loading a memory from a CSV file, requires an input filter to filter through the data.
+        # The filter needs a session before it can be used.
+        if self.ap.memory.load_memory_from_file_path:
+            self.load_memory_from_file()
+
+    def load_memory_from_file(self):
+        """
+        Load memory transitions from a file.
+
+        :return: None
+        """
+
+        if isinstance(self.ap.memory.load_memory_from_file_path, PickledReplayBuffer):
+            screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
+                             .format(self.ap.memory.load_memory_from_file_path.filepath))
+            self.memory.load_pickled(self.ap.memory.load_memory_from_file_path.filepath)
+        elif isinstance(self.ap.memory.load_memory_from_file_path, CsvDataset):
+            screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
+                             .format(self.ap.memory.load_memory_from_file_path.filepath))
+            self.memory.load_csv(self.ap.memory.load_memory_from_file_path, self.input_filter)
+        else:
+            raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
+                             .format(self.ap.memory.load_memory_from_file_path))
 
     def register_signal(self, signal_name: str, dump_one_value_per_episode: bool=True,
                         dump_one_value_per_step: bool=False) -> Signal:
diff --git a/rl_coach/filters/observation/observation_stacking_filter.py b/rl_coach/filters/observation/observation_stacking_filter.py
index 3b000a5..58ed45b 100644
--- a/rl_coach/filters/observation/observation_stacking_filter.py
+++ b/rl_coach/filters/observation/observation_stacking_filter.py
@@ -21,7 +21,7 @@
 import numpy as np
 
 from rl_coach.core_types import ObservationType
 from rl_coach.filters.observation.observation_filter import ObservationFilter
-from rl_coach.spaces import ObservationSpace
+from rl_coach.spaces import ObservationSpace, VectorObservationSpace
 
 
 class LazyStack(object):
@@ -63,6 +63,7 @@ class ObservationStackingFilter(ObservationFilter):
         self.stack_size = stack_size
         self.stacking_axis = stacking_axis
         self.stack = []
+        self.input_observation_space = None
 
         if stack_size <= 0:
             raise ValueError("The stack shape must be a positive number")
@@ -86,7 +87,6 @@ class ObservationStackingFilter(ObservationFilter):
             raise ValueError("The stacking axis is larger than the number of dimensions in the observation space")
 
     def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
-
         if len(self.stack) == 0:
             self.stack = deque([observation] * self.stack_size, maxlen=self.stack_size)
         else:
@@ -94,14 +94,21 @@ class ObservationStackingFilter(ObservationFilter):
             self.stack.append(observation)
 
         observation = LazyStack(self.stack, self.stacking_axis)
 
+        if isinstance(self.input_observation_space, VectorObservationSpace):
+            # when stacking vectors, we cannot avoid copying the memory as we're flattening it all
+            observation = np.array(observation).flatten()
+
         return observation
 
     def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
-        if self.stacking_axis == -1:
-            input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
+        if isinstance(input_observation_space, VectorObservationSpace):
+            self.input_observation_space = input_observation_space = VectorObservationSpace(input_observation_space.shape * self.stack_size)
         else:
-            input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
-                                                      values=[self.stack_size], axis=0)
+            if self.stacking_axis == -1:
+                input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
+            else:
+                input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
+                                                          values=[self.stack_size], axis=0)
 
         return input_observation_space
 
     def reset(self) -> None:
diff --git a/rl_coach/memories/episodic/episodic_experience_replay.py b/rl_coach/memories/episodic/episodic_experience_replay.py
index a7dddbd..8a54dad 100644
--- a/rl_coach/memories/episodic/episodic_experience_replay.py
+++ b/rl_coach/memories/episodic/episodic_experience_replay.py
@@ -25,6 +25,7 @@
 import numpy as np
 import random
 
 from rl_coach.core_types import Transition, Episode
+from rl_coach.filters.filter import InputFilter
 from rl_coach.logger import screen
 from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
 from rl_coach.utils import ReaderWriterLock, ProgressBar
@@ -408,11 +409,12 @@ class EpisodicExperienceReplay(Memory):
         self.reader_writer_lock.release_writing()
 
         return mean
 
-    def load_csv(self, csv_dataset: CsvDataset) -> None:
+    def load_csv(self, csv_dataset: CsvDataset, input_filter: InputFilter) -> None:
         """
         Restore the replay buffer contents from a csv file.
         The csv file is assumed to include a list of transitions.
 
         :param csv_dataset: A construct which holds the dataset parameters
+        :param input_filter: A filter used to filter the CSV data before feeding it to the memory.
         """
         self.assert_not_frozen()
@@ -429,18 +431,30 @@ class EpisodicExperienceReplay(Memory):
         for e_id in episode_ids:
             progress_bar.update(e_id)
             df_episode_transitions = df[df['episode_id'] == e_id]
+            input_filter.reset()
+
+            if len(df_episode_transitions) < 2:
+                # we have to have at least 2 rows in each episode for creating a transition
+                continue
+
             episode = Episode()
+            transitions = []
             for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(),
                                                                      df_episode_transitions[1:].iterrows()):
                 state = np.array([current_transition[col] for col in state_columns])
                 next_state = np.array([next_transition[col] for col in state_columns])
-                episode.insert(
+                transitions.append(
                     Transition(state={'observation': state},
                                action=current_transition['action'], reward=current_transition['reward'],
                                next_state={'observation': next_state}, game_over=False,
                                info={'all_action_probabilities':
-                                         ast.literal_eval(current_transition['all_action_probabilities'])}))
+                                         ast.literal_eval(current_transition['all_action_probabilities'])}),
+                )
+
+            transitions = input_filter.filter(transitions, deep_copy=False)
+            for t in transitions:
+                episode.insert(t)
 
             # Set the last transition to end the episode
             if csv_dataset.is_episodic: