mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
network_imporvements branch merge
This commit is contained in:
@@ -18,6 +18,7 @@ import gym
|
||||
import numpy as np
|
||||
import scipy.ndimage
|
||||
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.utils import lower_under_to_upper, short_dynamic_import
|
||||
|
||||
try:
|
||||
@@ -40,7 +41,7 @@ except ImportError:
|
||||
failed_imports.append("PyBullet")
|
||||
|
||||
from typing import Dict, Any, Union
|
||||
from rl_coach.core_types import RunPhase
|
||||
from rl_coach.core_types import RunPhase, EnvironmentSteps
|
||||
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, ImageObservationSpace, VectorObservationSpace, \
|
||||
StateSpace, RewardSpace
|
||||
@@ -57,10 +58,9 @@ from rl_coach.logger import screen
|
||||
|
||||
|
||||
# Parameters
|
||||
|
||||
class GymEnvironmentParameters(EnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, level=None):
|
||||
super().__init__(level=level)
|
||||
self.random_initialization_steps = 0
|
||||
self.max_over_num_frames = 1
|
||||
self.additional_simulator_parameters = None
|
||||
@@ -70,64 +70,32 @@ class GymEnvironmentParameters(EnvironmentParameters):
|
||||
return 'rl_coach.environments.gym_environment:GymEnvironment'
|
||||
|
||||
|
||||
"""
|
||||
Roboschool Environment Components
|
||||
"""
|
||||
RoboSchoolInputFilters = NoInputFilter()
|
||||
RoboSchoolOutputFilters = NoOutputFilter()
|
||||
|
||||
|
||||
class Roboschool(GymEnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Generic parameters for vector environments such as mujoco, roboschool, bullet, etc.
|
||||
class GymVectorEnvironment(GymEnvironmentParameters):
|
||||
def __init__(self, level=None):
|
||||
super().__init__(level=level)
|
||||
self.frame_skip = 1
|
||||
self.default_input_filter = RoboSchoolInputFilters
|
||||
self.default_output_filter = RoboSchoolOutputFilters
|
||||
self.default_input_filter = NoInputFilter()
|
||||
self.default_output_filter = NoOutputFilter()
|
||||
|
||||
|
||||
# Roboschool
|
||||
gym_roboschool_envs = ['inverted_pendulum', 'inverted_pendulum_swingup', 'inverted_double_pendulum', 'reacher',
|
||||
'hopper', 'walker2d', 'half_cheetah', 'ant', 'humanoid', 'humanoid_flagrun',
|
||||
'humanoid_flagrun_harder', 'pong']
|
||||
roboschool_v0 = {e: "{}".format(lower_under_to_upper(e) + '-v0') for e in gym_roboschool_envs}
|
||||
|
||||
"""
|
||||
Mujoco Environment Components
|
||||
"""
|
||||
MujocoInputFilter = NoInputFilter()
|
||||
MujocoOutputFilter = NoOutputFilter()
|
||||
|
||||
|
||||
class Mujoco(GymEnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.frame_skip = 1
|
||||
self.default_input_filter = MujocoInputFilter
|
||||
self.default_output_filter = MujocoOutputFilter
|
||||
|
||||
|
||||
# Mujoco
|
||||
gym_mujoco_envs = ['inverted_pendulum', 'inverted_double_pendulum', 'reacher', 'hopper', 'walker2d', 'half_cheetah',
|
||||
'ant', 'swimmer', 'humanoid', 'humanoid_standup', 'pusher', 'thrower', 'striker']
|
||||
|
||||
mujoco_v2 = {e: "{}".format(lower_under_to_upper(e) + '-v2') for e in gym_mujoco_envs}
|
||||
mujoco_v2['walker2d'] = 'Walker2d-v2'
|
||||
|
||||
# Fetch
|
||||
gym_fetch_envs = ['reach', 'slide', 'push', 'pick_and_place']
|
||||
fetch_v1 = {e: "{}".format('Fetch' + lower_under_to_upper(e) + '-v1') for e in gym_fetch_envs}
|
||||
|
||||
"""
|
||||
Bullet Environment Components
|
||||
"""
|
||||
BulletInputFilter = NoInputFilter()
|
||||
BulletOutputFilter = NoOutputFilter()
|
||||
|
||||
|
||||
class Bullet(GymEnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.frame_skip = 1
|
||||
self.default_input_filter = BulletInputFilter
|
||||
self.default_output_filter = BulletOutputFilter
|
||||
|
||||
|
||||
"""
|
||||
Atari Environment Components
|
||||
@@ -145,8 +113,8 @@ AtariOutputFilter = NoOutputFilter()
|
||||
|
||||
|
||||
class Atari(GymEnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, level=None):
|
||||
super().__init__(level=level)
|
||||
self.frame_skip = 4
|
||||
self.max_over_num_frames = 2
|
||||
self.random_initialization_steps = 30
|
||||
@@ -167,6 +135,14 @@ atari_deterministic_v4 = {e: "{}".format(lower_under_to_upper(e) + 'Deterministi
|
||||
atari_no_frameskip_v4 = {e: "{}".format(lower_under_to_upper(e) + 'NoFrameskip-v4') for e in gym_atari_envs}
|
||||
|
||||
|
||||
# default atari schedule used in the DeepMind papers
|
||||
atari_schedule = ScheduleParameters()
|
||||
atari_schedule.improve_steps = EnvironmentSteps(50000000)
|
||||
atari_schedule.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
atari_schedule.evaluation_steps = EnvironmentSteps(135000)
|
||||
atari_schedule.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
|
||||
class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
|
||||
def __init__(self, env, frameskip=4, max_over_num_frames=2):
|
||||
super().__init__(env)
|
||||
|
||||
Reference in New Issue
Block a user