mirror of
https://github.com/gryf/coach.git
synced 2026-03-04 07:45:53 +01:00
applying filters for a csv loaded dataset + some bug-fixes in data loading (#319)
This commit is contained in:
@@ -21,7 +21,7 @@ import numpy as np
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
from rl_coach.spaces import ObservationSpace, VectorObservationSpace
|
||||
|
||||
|
||||
class LazyStack(object):
|
||||
@@ -63,6 +63,7 @@ class ObservationStackingFilter(ObservationFilter):
|
||||
self.stack_size = stack_size
|
||||
self.stacking_axis = stacking_axis
|
||||
self.stack = []
|
||||
self.input_observation_space = None
|
||||
|
||||
if stack_size <= 0:
|
||||
raise ValueError("The stack shape must be a positive number")
|
||||
@@ -86,7 +87,6 @@ class ObservationStackingFilter(ObservationFilter):
|
||||
raise ValueError("The stacking axis is larger than the number of dimensions in the observation space")
|
||||
|
||||
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
||||
|
||||
if len(self.stack) == 0:
|
||||
self.stack = deque([observation] * self.stack_size, maxlen=self.stack_size)
|
||||
else:
|
||||
@@ -94,14 +94,21 @@ class ObservationStackingFilter(ObservationFilter):
|
||||
self.stack.append(observation)
|
||||
observation = LazyStack(self.stack, self.stacking_axis)
|
||||
|
||||
if isinstance(self.input_observation_space, VectorObservationSpace):
|
||||
# when stacking vectors, we cannot avoid copying the memory as we're flattening it all
|
||||
observation = np.array(observation).flatten()
|
||||
|
||||
return observation
|
||||
|
||||
def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
|
||||
if self.stacking_axis == -1:
|
||||
input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
|
||||
if isinstance(input_observation_space, VectorObservationSpace):
|
||||
self.input_observation_space = input_observation_space = VectorObservationSpace(input_observation_space.shape * self.stack_size)
|
||||
else:
|
||||
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
|
||||
values=[self.stack_size], axis=0)
|
||||
if self.stacking_axis == -1:
|
||||
input_observation_space.shape = np.append(input_observation_space.shape, values=[self.stack_size], axis=0)
|
||||
else:
|
||||
input_observation_space.shape = np.insert(input_observation_space.shape, obj=self.stacking_axis,
|
||||
values=[self.stack_size], axis=0)
|
||||
return input_observation_space
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user