1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00
Files
coach/rl_coach/filters
Gal Leibovich a1c56edd98 Fixes for having NumpySharedRunningStats syncing on multi-node (#139)
1. Having the standard checkpoint prefix in order for the data store to grab it, and sync it to S3.
2. Removing the reference to Redis so that it won't try to pickle that in.
3. Enable restoring a checkpoint into a single-worker run, which was saved by a single-node-multiple-worker run.
2018-11-23 16:11:47 +02:00
..
2018-08-13 17:11:34 +03:00
2018-08-13 17:11:34 +03:00

A custom observation filter implementation should look like this:

from rl_coach.filters.filter import ObservationFilter

class CustomFilter(ObservationFilter):
  def __init__(self):
    ...
  def filter(self, env_response: EnvResponse) -> EnvResponse:
    ...
  def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
    ...
  def validate_input_observation_space(self, input_observation_space: ObservationSpace):
    ...
  def reset(self):
    ...

or for reward filters:

from rl_coach.filters.filter import RewardFilter

class CustomFilter(RewardFilter):
  def __init__(self):
    ...
  def filter(self, env_response: EnvResponse) -> EnvResponse:
    ...
  def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
    ...
  def reset(self):
    ...

To create a stack of filters:

from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
from rl_coach.spaces import ObservationSpace
import numpy as np
from rl_coach.core_types import EnvResponse
from rl_coach.filters.filter import InputFilter
from collections import OrderedDict

env_response = EnvResponse({'observation': np.ones([210, 160])}, reward=100, game_over=False)

rescale = ObservationRescaleToSizeFilter(
    output_observation_space=ObservationSpace(np.array([110, 84])),
    rescaling_interpolation_type=RescaleInterpolationType.BILINEAR
)

crop = ObservationCropFilter(
    crop_low=np.array([16, 0]),
    crop_high=np.array([100, 84])
)

clip = RewardClippingFilter(
    clipping_low=-1,
    clipping_high=1
)

input_filter = InputFilter(
    observation_filters=OrderedDict([('rescale', rescale), ('crop', crop)]),
    reward_filters=OrderedDict([('clip', clip)])
)

result = input_filter.filter(env_response)