1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-13 13:15:50 +01:00

imitation related bug fixes

This commit is contained in:
itaicaspi-intel
2018-09-12 14:54:33 +03:00
parent a9bd1047c4
commit 171fe97a3a
7 changed files with 21 additions and 22 deletions

View File

@@ -17,6 +17,8 @@ import copy
from enum import Enum
from typing import List
import numpy as np
from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace, VectorObservationSpace
@@ -45,6 +47,8 @@ class ObservationReductionBySubPartsNameFilter(ObservationFilter):
self.indices_to_keep = None
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
if not isinstance(observation, np.ndarray):
raise ValueError("All the state values are expected to be numpy arrays")
if self.indices_to_keep is None:
raise ValueError("To use ObservationReductionBySubPartsNameFilter, the get_filtered_observation_space "
"function should be called before filtering an observation")