mirror of
https://github.com/gryf/coach.git
synced 2026-03-05 00:15:50 +01:00
applying filters for a csv loaded dataset + some bug-fixes in data loading (#319)
This commit is contained in:
@@ -25,6 +25,7 @@ import numpy as np
|
||||
import random
|
||||
|
||||
from rl_coach.core_types import Transition, Episode
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.logger import screen
|
||||
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
|
||||
from rl_coach.utils import ReaderWriterLock, ProgressBar
|
||||
@@ -408,11 +409,12 @@ class EpisodicExperienceReplay(Memory):
|
||||
self.reader_writer_lock.release_writing()
|
||||
return mean
|
||||
|
||||
def load_csv(self, csv_dataset: CsvDataset) -> None:
|
||||
def load_csv(self, csv_dataset: CsvDataset, input_filter: InputFilter) -> None:
|
||||
"""
|
||||
Restore the replay buffer contents from a csv file.
|
||||
The csv file is assumed to include a list of transitions.
|
||||
:param csv_dataset: A construct which holds the dataset parameters
|
||||
:param input_filter: A filter used to filter the CSV data before feeding it to the memory.
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
@@ -429,18 +431,30 @@ class EpisodicExperienceReplay(Memory):
|
||||
for e_id in episode_ids:
|
||||
progress_bar.update(e_id)
|
||||
df_episode_transitions = df[df['episode_id'] == e_id]
|
||||
input_filter.reset()
|
||||
|
||||
if len(df_episode_transitions) < 2:
|
||||
# we have to have at least 2 rows in each episode for creating a transition
|
||||
continue
|
||||
|
||||
episode = Episode()
|
||||
transitions = []
|
||||
for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(),
|
||||
df_episode_transitions[1:].iterrows()):
|
||||
state = np.array([current_transition[col] for col in state_columns])
|
||||
next_state = np.array([next_transition[col] for col in state_columns])
|
||||
|
||||
episode.insert(
|
||||
transitions.append(
|
||||
Transition(state={'observation': state},
|
||||
action=current_transition['action'], reward=current_transition['reward'],
|
||||
next_state={'observation': next_state}, game_over=False,
|
||||
info={'all_action_probabilities':
|
||||
ast.literal_eval(current_transition['all_action_probabilities'])}))
|
||||
ast.literal_eval(current_transition['all_action_probabilities'])}),
|
||||
)
|
||||
|
||||
transitions = input_filter.filter(transitions, deep_copy=False)
|
||||
for t in transitions:
|
||||
episode.insert(t)
|
||||
|
||||
# Set the last transition to end the episode
|
||||
if csv_dataset.is_episodic:
|
||||
|
||||
Reference in New Issue
Block a user