mirror of
https://github.com/gryf/coach.git
synced 2026-03-13 13:15:50 +01:00
imitation related bug fixes
This commit is contained in:
@@ -17,6 +17,8 @@ import copy
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
from rl_coach.spaces import ObservationSpace, VectorObservationSpace
|
||||
@@ -45,6 +47,8 @@ class ObservationReductionBySubPartsNameFilter(ObservationFilter):
|
||||
self.indices_to_keep = None
|
||||
|
||||
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
||||
if not isinstance(observation, np.ndarray):
|
||||
raise ValueError("All the state values are expected to be numpy arrays")
|
||||
if self.indices_to_keep is None:
|
||||
raise ValueError("To use ObservationReductionBySubPartsNameFilter, the get_filtered_observation_space "
|
||||
"function should be called before filtering an observation")
|
||||
|
||||
Reference in New Issue
Block a user