1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-05 00:15:50 +01:00

Apply input filters to CSV-loaded datasets, plus some bug fixes in data loading (#319)

This commit is contained in:
Gal Leibovich
2019-05-28 15:44:55 +03:00
committed by GitHub
parent 6319387357
commit 4c996e147e
3 changed files with 62 additions and 22 deletions

View File

@@ -25,6 +25,7 @@ import numpy as np
import random
from rl_coach.core_types import Transition, Episode
from rl_coach.filters.filter import InputFilter
from rl_coach.logger import screen
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
from rl_coach.utils import ReaderWriterLock, ProgressBar
@@ -408,11 +409,12 @@ class EpisodicExperienceReplay(Memory):
self.reader_writer_lock.release_writing()
return mean
def load_csv(self, csv_dataset: CsvDataset) -> None:
def load_csv(self, csv_dataset: CsvDataset, input_filter: InputFilter) -> None:
"""
Restore the replay buffer contents from a csv file.
The csv file is assumed to include a list of transitions.
:param csv_dataset: A construct which holds the dataset parameters
:param input_filter: A filter used to filter the CSV data before feeding it to the memory.
"""
self.assert_not_frozen()
@@ -429,18 +431,30 @@ class EpisodicExperienceReplay(Memory):
for e_id in episode_ids:
progress_bar.update(e_id)
df_episode_transitions = df[df['episode_id'] == e_id]
input_filter.reset()
if len(df_episode_transitions) < 2:
# we have to have at least 2 rows in each episode for creating a transition
continue
episode = Episode()
transitions = []
for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(),
df_episode_transitions[1:].iterrows()):
state = np.array([current_transition[col] for col in state_columns])
next_state = np.array([next_transition[col] for col in state_columns])
episode.insert(
transitions.append(
Transition(state={'observation': state},
action=current_transition['action'], reward=current_transition['reward'],
next_state={'observation': next_state}, game_over=False,
info={'all_action_probabilities':
ast.literal_eval(current_transition['all_action_probabilities'])}))
ast.literal_eval(current_transition['all_action_probabilities'])}),
)
transitions = input_filter.filter(transitions, deep_copy=False)
for t in transitions:
episode.insert(t)
# Set the last transition to end the episode
if csv_dataset.is_episodic: