mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
applying filters for a csv loaded dataset + some bug-fixes in data loading (#319)
This commit is contained in:
@@ -95,19 +95,6 @@ class Agent(AgentInterface):
|
|||||||
if self.ap.memory.memory_backend_params.run_type != 'trainer':
|
if self.ap.memory.memory_backend_params.run_type != 'trainer':
|
||||||
self.memory.set_memory_backend(self.memory_backend)
|
self.memory.set_memory_backend(self.memory_backend)
|
||||||
|
|
||||||
if agent_parameters.memory.load_memory_from_file_path:
|
|
||||||
if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
|
|
||||||
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
|
|
||||||
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
|
|
||||||
self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
|
|
||||||
elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
|
|
||||||
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
|
|
||||||
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
|
|
||||||
self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
|
|
||||||
else:
|
|
||||||
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
|
|
||||||
.format(agent_parameters.memory.load_memory_from_file_path))
|
|
||||||
|
|
||||||
if self.shared_memory and self.is_chief:
|
if self.shared_memory and self.is_chief:
|
||||||
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
|
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
|
||||||
|
|
||||||
@@ -262,6 +249,38 @@ class Agent(AgentInterface):
|
|||||||
self.output_filter.set_session(sess)
|
self.output_filter.set_session(sess)
|
||||||
self.pre_network_filter.set_session(sess)
|
self.pre_network_filter.set_session(sess)
|
||||||
[network.set_session(sess) for network in self.networks.values()]
|
[network.set_session(sess) for network in self.networks.values()]
|
||||||
|
self.initialize_session_dependent_components()
|
||||||
|
|
||||||
|
def initialize_session_dependent_components(self):
|
||||||
|
"""
|
||||||
|
Initialize components which require a session as part of their initialization.
|
||||||
|
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Loading a memory from a CSV file, requires an input filter to filter through the data.
|
||||||
|
# The filter needs a session before it can be used.
|
||||||
|
if self.ap.memory.load_memory_from_file_path:
|
||||||
|
self.load_memory_from_file()
|
||||||
|
|
||||||
|
def load_memory_from_file(self):
|
||||||
|
"""
|
||||||
|
Load memory transitions from a file.
|
||||||
|
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(self.ap.memory.load_memory_from_file_path, PickledReplayBuffer):
|
||||||
|
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
|
||||||
|
.format(self.ap.memory.load_memory_from_file_path.filepath))
|
||||||
|
self.memory.load_pickled(self.ap.memory.load_memory_from_file_path.filepath)
|
||||||
|
elif isinstance(self.ap.memory.load_memory_from_file_path, CsvDataset):
|
||||||
|
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
|
||||||
|
.format(self.ap.memory.load_memory_from_file_path.filepath))
|
||||||
|
self.memory.load_csv(self.ap.memory.load_memory_from_file_path, self.input_filter)
|
||||||
|
else:
|
||||||
|
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
|
||||||
|
.format(self.ap.memory.load_memory_from_file_path))
|
||||||
|
|
||||||
def register_signal(self, signal_name: str, dump_one_value_per_episode: bool=True,
|
def register_signal(self, signal_name: str, dump_one_value_per_episode: bool=True,
|
||||||
dump_one_value_per_step: bool=False) -> Signal:
|
dump_one_value_per_step: bool=False) -> Signal:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import numpy as np
|
|||||||
|
|
||||||
from rl_coach.core_types import ObservationType
|
from rl_coach.core_types import ObservationType
|
||||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||||
from rl_coach.spaces import ObservationSpace
|
from rl_coach.spaces import ObservationSpace, VectorObservationSpace
|
||||||
|
|
||||||
|
|
||||||
class LazyStack(object):
|
class LazyStack(object):
|
||||||
@@ -63,6 +63,7 @@ class ObservationStackingFilter(ObservationFilter):
|
|||||||
self.stack_size = stack_size
|
self.stack_size = stack_size
|
||||||
self.stacking_axis = stacking_axis
|
self.stacking_axis = stacking_axis
|
||||||
self.stack = []
|
self.stack = []
|
||||||
|
self.input_observation_space = None
|
||||||
|
|
||||||
if stack_size <= 0:
|
if stack_size <= 0:
|
||||||
raise ValueError("The stack shape must be a positive number")
|
raise ValueError("The stack shape must be a positive number")
|
||||||
@@ -86,7 +87,6 @@ class ObservationStackingFilter(ObservationFilter):
|
|||||||
raise ValueError("The stacking axis is larger than the number of dimensions in the observation space")
|
raise ValueError("The stacking axis is larger than the number of dimensions in the observation space")
|
||||||
|
|
||||||
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
||||||
|
|
||||||
if len(self.stack) == 0:
|
if len(self.stack) == 0:
|
||||||
self.stack = deque([observation] * self.stack_size, maxlen=self.stack_size)
|
self.stack = deque([observation] * self.stack_size, maxlen=self.stack_size)
|
||||||
else:
|
else:
|
||||||
@@ -94,14 +94,21 @@ class ObservationStackingFilter(ObservationFilter):
|
|||||||
self.stack.append(observation)
|
self.stack.append(observation)
|
||||||
observation = LazyStack(self.stack, self.stacking_axis)
|
observation = LazyStack(self.stack, self.stacking_axis)
|
||||||
|
|
||||||
|
if isinstance(self.input_observation_space, VectorObservationSpace):
|
||||||
|
# when stacking vectors, we cannot avoid copying the memory as we're flattening it all
|
||||||
|
observation = np.array(observation).flatten()
|
||||||
|
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
|
def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
|
||||||
if self.stacking_axis == -1:
|
if isinstance(input_observation_space, VectorObservationSpace):
|
||||||
input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
|
self.input_observation_space = input_observation_space = VectorObservationSpace(input_observation_space.shape * self.stack_size)
|
||||||
else:
|
else:
|
||||||
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
|
if self.stacking_axis == -1:
|
||||||
values=[self.stack_size], axis=0)
|
input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
|
||||||
|
else:
|
||||||
|
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
|
||||||
|
values=[self.stack_size], axis=0)
|
||||||
return input_observation_space
|
return input_observation_space
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from rl_coach.core_types import Transition, Episode
|
from rl_coach.core_types import Transition, Episode
|
||||||
|
from rl_coach.filters.filter import InputFilter
|
||||||
from rl_coach.logger import screen
|
from rl_coach.logger import screen
|
||||||
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
|
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
|
||||||
from rl_coach.utils import ReaderWriterLock, ProgressBar
|
from rl_coach.utils import ReaderWriterLock, ProgressBar
|
||||||
@@ -408,11 +409,12 @@ class EpisodicExperienceReplay(Memory):
|
|||||||
self.reader_writer_lock.release_writing()
|
self.reader_writer_lock.release_writing()
|
||||||
return mean
|
return mean
|
||||||
|
|
||||||
def load_csv(self, csv_dataset: CsvDataset) -> None:
|
def load_csv(self, csv_dataset: CsvDataset, input_filter: InputFilter) -> None:
|
||||||
"""
|
"""
|
||||||
Restore the replay buffer contents from a csv file.
|
Restore the replay buffer contents from a csv file.
|
||||||
The csv file is assumed to include a list of transitions.
|
The csv file is assumed to include a list of transitions.
|
||||||
:param csv_dataset: A construct which holds the dataset parameters
|
:param csv_dataset: A construct which holds the dataset parameters
|
||||||
|
:param input_filter: A filter used to filter the CSV data before feeding it to the memory.
|
||||||
"""
|
"""
|
||||||
self.assert_not_frozen()
|
self.assert_not_frozen()
|
||||||
|
|
||||||
@@ -429,18 +431,30 @@ class EpisodicExperienceReplay(Memory):
|
|||||||
for e_id in episode_ids:
|
for e_id in episode_ids:
|
||||||
progress_bar.update(e_id)
|
progress_bar.update(e_id)
|
||||||
df_episode_transitions = df[df['episode_id'] == e_id]
|
df_episode_transitions = df[df['episode_id'] == e_id]
|
||||||
|
input_filter.reset()
|
||||||
|
|
||||||
|
if len(df_episode_transitions) < 2:
|
||||||
|
# we have to have at least 2 rows in each episode for creating a transition
|
||||||
|
continue
|
||||||
|
|
||||||
episode = Episode()
|
episode = Episode()
|
||||||
|
transitions = []
|
||||||
for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(),
|
for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(),
|
||||||
df_episode_transitions[1:].iterrows()):
|
df_episode_transitions[1:].iterrows()):
|
||||||
state = np.array([current_transition[col] for col in state_columns])
|
state = np.array([current_transition[col] for col in state_columns])
|
||||||
next_state = np.array([next_transition[col] for col in state_columns])
|
next_state = np.array([next_transition[col] for col in state_columns])
|
||||||
|
|
||||||
episode.insert(
|
transitions.append(
|
||||||
Transition(state={'observation': state},
|
Transition(state={'observation': state},
|
||||||
action=current_transition['action'], reward=current_transition['reward'],
|
action=current_transition['action'], reward=current_transition['reward'],
|
||||||
next_state={'observation': next_state}, game_over=False,
|
next_state={'observation': next_state}, game_over=False,
|
||||||
info={'all_action_probabilities':
|
info={'all_action_probabilities':
|
||||||
ast.literal_eval(current_transition['all_action_probabilities'])}))
|
ast.literal_eval(current_transition['all_action_probabilities'])}),
|
||||||
|
)
|
||||||
|
|
||||||
|
transitions = input_filter.filter(transitions, deep_copy=False)
|
||||||
|
for t in transitions:
|
||||||
|
episode.insert(t)
|
||||||
|
|
||||||
# Set the last transition to end the episode
|
# Set the last transition to end the episode
|
||||||
if csv_dataset.is_episodic:
|
if csv_dataset.is_episodic:
|
||||||
|
|||||||
Reference in New Issue
Block a user