1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-04 07:45:53 +01:00

applying filters for a csv loaded dataset + some bug-fixes in data loading (#319)

This commit is contained in:
Gal Leibovich
2019-05-28 15:44:55 +03:00
committed by GitHub
parent 6319387357
commit 4c996e147e
3 changed files with 62 additions and 22 deletions

View File

@@ -21,7 +21,7 @@ import numpy as np
from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace
from rl_coach.spaces import ObservationSpace, VectorObservationSpace
class LazyStack(object):
@@ -63,6 +63,7 @@ class ObservationStackingFilter(ObservationFilter):
self.stack_size = stack_size
self.stacking_axis = stacking_axis
self.stack = []
self.input_observation_space = None
if stack_size <= 0:
raise ValueError("The stack shape must be a positive number")
@@ -86,7 +87,6 @@ class ObservationStackingFilter(ObservationFilter):
raise ValueError("The stacking axis is larger than the number of dimensions in the observation space")
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
if len(self.stack) == 0:
self.stack = deque([observation] * self.stack_size, maxlen=self.stack_size)
else:
@@ -94,14 +94,21 @@ class ObservationStackingFilter(ObservationFilter):
self.stack.append(observation)
observation = LazyStack(self.stack, self.stacking_axis)
if isinstance(self.input_observation_space, VectorObservationSpace):
# when stacking vectors, we cannot avoid copying the memory as we're flattening it all
observation = np.array(observation).flatten()
return observation
def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
if self.stacking_axis == -1:
input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
if isinstance(input_observation_space, VectorObservationSpace):
self.input_observation_space = input_observation_space = VectorObservationSpace(input_observation_space.shape * self.stack_size)
else:
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
values=[self.stack_size], axis=0)
if self.stacking_axis == -1:
input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
else:
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
values=[self.stack_size], axis=0)
return input_observation_space
def reset(self) -> None: