1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00
Files
coach/rl_coach/filters
2018-08-27 10:54:11 +03:00
..
2018-08-13 17:11:34 +03:00
2018-08-13 17:11:34 +03:00

A custom observation filter implementation should look like this:

from coach.filters.filter import ObservationFilter

class CustomFilter(ObservationFilter):
    """Template for a user-defined observation filter.

    Replace each ``...`` placeholder with the filter's own logic.
    """

    def __init__(self):
        # Set up whatever configuration or state the filter needs.
        ...

    def filter(self, env_response: EnvResponse) -> EnvResponse:
        # Receives an environment response and must return an EnvResponse.
        ...

    def get_filtered_observation_space(
        self, input_observation_space: ObservationSpace
    ) -> ObservationSpace:
        # Receives the incoming observation space and must return an
        # ObservationSpace -- presumably the space as it looks after this
        # filter has been applied.
        ...

    def validate_input_observation_space(
        self, input_observation_space: ObservationSpace
    ):
        # Hook for checking that the filter can handle the incoming space.
        # NOTE(review): how a failed check should be reported is defined by
        # the base class, which is not shown here.
        ...

    def reset(self):
        # Hook for clearing any state the filter keeps between calls
        # (nothing to do for a stateless filter).
        ...

or for reward filters:

from coach.filters.filter import RewardFilter

class CustomFilter(RewardFilter):
    """Template for a user-defined reward filter.

    Derives from ``RewardFilter`` (the class imported for this example);
    replace each ``...`` placeholder with the filter's own logic.
    """

    def __init__(self):
        # Set up whatever configuration or state the filter needs.
        ...

    def filter(self, env_response: EnvResponse) -> EnvResponse:
        # Receives an environment response and must return an EnvResponse.
        ...

    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
        # Receives the incoming reward space and must return a RewardSpace --
        # presumably the space as it looks after this filter has been applied.
        ...

    def reset(self):
        # Hook for clearing any state the filter keeps between calls
        # (nothing to do for a stateless filter).
        ...

To create a stack of filters:

from coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType
from coach.filters.observation.observation_crop_filter import ObservationCropFilter
from coach.filters.reward.reward_clipping_filter import RewardClippingFilter
from environments.environment_interface import ObservationSpace
import numpy as np
from core_types import EnvResponse
from filters.filter import InputFilter
from collections import OrderedDict

# Build a sample environment response holding a 210x160 array of ones.
frame = np.ones((210, 160))
env_response = EnvResponse({'observation': frame}, reward=100, game_over=False)

# Observation filter 1: rescale to the given output space with bilinear
# interpolation.
rescaled_space = ObservationSpace(np.array([110, 84]))
rescale = ObservationRescaleToSizeFilter(
    output_observation_space=rescaled_space,
    rescaling_interpolation_type=RescaleInterpolationType.BILINEAR,
)

# Observation filter 2: crop between the given low and high bounds
# (one entry per observation dimension).
crop = ObservationCropFilter(crop_low=np.array([16, 0]), crop_high=np.array([100, 84]))

# Reward filter: clip with a lower bound of -1 and an upper bound of 1.
clip = RewardClippingFilter(clipping_low=-1, clipping_high=1)

# The OrderedDicts keep insertion order: rescale first, then crop.
# NOTE(review): presumably this is also the order in which InputFilter
# applies the filters -- confirm against InputFilter.
observation_stack = OrderedDict()
observation_stack['rescale'] = rescale
observation_stack['crop'] = crop

reward_stack = OrderedDict()
reward_stack['clip'] = clip

input_filter = InputFilter(
    observation_filters=observation_stack,
    reward_filters=reward_stack,
)

# Run the sample response through the whole stack.
result = input_filter.filter(env_response)