1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

network_improvements branch merge

This commit is contained in:
Shadi Endrawis
2018-10-02 13:41:46 +03:00
parent 72ea933384
commit 51726a5b80
110 changed files with 1639 additions and 1161 deletions

View File

@@ -18,6 +18,7 @@ import gym
import numpy as np
import scipy.ndimage
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.utils import lower_under_to_upper, short_dynamic_import
try:
@@ -40,7 +41,7 @@ except ImportError:
failed_imports.append("PyBullet")
from typing import Dict, Any, Union
from rl_coach.core_types import RunPhase
from rl_coach.core_types import RunPhase, EnvironmentSteps
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, ImageObservationSpace, VectorObservationSpace, \
StateSpace, RewardSpace
@@ -57,10 +58,9 @@ from rl_coach.logger import screen
# Parameters
class GymEnvironmentParameters(EnvironmentParameters):
def __init__(self):
super().__init__()
def __init__(self, level=None):
super().__init__(level=level)
self.random_initialization_steps = 0
self.max_over_num_frames = 1
self.additional_simulator_parameters = None
@@ -70,64 +70,32 @@ class GymEnvironmentParameters(EnvironmentParameters):
return 'rl_coach.environments.gym_environment:GymEnvironment'
"""
Roboschool Environment Components
"""
RoboSchoolInputFilters = NoInputFilter()
RoboSchoolOutputFilters = NoOutputFilter()
class Roboschool(GymEnvironmentParameters):
def __init__(self):
super().__init__()
# Generic parameters for vector environments such as mujoco, roboschool, bullet, etc.
class GymVectorEnvironment(GymEnvironmentParameters):
def __init__(self, level=None):
super().__init__(level=level)
self.frame_skip = 1
self.default_input_filter = RoboSchoolInputFilters
self.default_output_filter = RoboSchoolOutputFilters
self.default_input_filter = NoInputFilter()
self.default_output_filter = NoOutputFilter()
# Roboschool
# Short snake_case level names. roboschool_v0 maps each short name to the string
# lower_under_to_upper(name) + '-v0'.
# NOTE(review): presumably the result is the id the environment is registered under in Gym
# (lower_under_to_upper is imported from rl_coach.utils) - confirm against the Roboschool registry.
gym_roboschool_envs = ['inverted_pendulum', 'inverted_pendulum_swingup', 'inverted_double_pendulum', 'reacher',
'hopper', 'walker2d', 'half_cheetah', 'ant', 'humanoid', 'humanoid_flagrun',
'humanoid_flagrun_harder', 'pong']
roboschool_v0 = {e: "{}".format(lower_under_to_upper(e) + '-v0') for e in gym_roboschool_envs}
"""
Mujoco Environment Components
"""
MujocoInputFilter = NoInputFilter()
MujocoOutputFilter = NoOutputFilter()
class Mujoco(GymEnvironmentParameters):
def __init__(self):
super().__init__()
self.frame_skip = 1
self.default_input_filter = MujocoInputFilter
self.default_output_filter = MujocoOutputFilter
# Mujoco
# Short snake_case level names. mujoco_v2 maps each short name to the string
# lower_under_to_upper(name) + '-v2'.
gym_mujoco_envs = ['inverted_pendulum', 'inverted_double_pendulum', 'reacher', 'hopper', 'walker2d', 'half_cheetah',
'ant', 'swimmer', 'humanoid', 'humanoid_standup', 'pusher', 'thrower', 'striker']
mujoco_v2 = {e: "{}".format(lower_under_to_upper(e) + '-v2') for e in gym_mujoco_envs}
# Explicitly overrides the generated entry for 'walker2d'.
# NOTE(review): presumably the generated name does not match the registered Gym id - confirm.
mujoco_v2['walker2d'] = 'Walker2d-v2'
# Fetch
# Short snake_case task names. fetch_v1 maps each short name to the string
# 'Fetch' + lower_under_to_upper(name) + '-v1'.
gym_fetch_envs = ['reach', 'slide', 'push', 'pick_and_place']
fetch_v1 = {e: "{}".format('Fetch' + lower_under_to_upper(e) + '-v1') for e in gym_fetch_envs}
"""
Bullet Environment Components
"""
BulletInputFilter = NoInputFilter()
BulletOutputFilter = NoOutputFilter()
class Bullet(GymEnvironmentParameters):
def __init__(self):
super().__init__()
self.frame_skip = 1
self.default_input_filter = BulletInputFilter
self.default_output_filter = BulletOutputFilter
"""
Atari Environment Components
@@ -145,8 +113,8 @@ AtariOutputFilter = NoOutputFilter()
class Atari(GymEnvironmentParameters):
def __init__(self):
super().__init__()
def __init__(self, level=None):
super().__init__(level=level)
self.frame_skip = 4
self.max_over_num_frames = 2
self.random_initialization_steps = 30
@@ -167,6 +135,14 @@ atari_deterministic_v4 = {e: "{}".format(lower_under_to_upper(e) + 'Deterministi
atari_no_frameskip_v4 = {e: "{}".format(lower_under_to_upper(e) + 'NoFrameskip-v4') for e in gym_atari_envs}
# default atari schedule used in the DeepMind papers
# This is a single module-level ScheduleParameters instance; every duration below is
# wrapped in EnvironmentSteps.
# NOTE(review): because the instance is shared by every importer, mutating it in one place
# is visible everywhere - confirm whether callers are expected to copy it.
atari_schedule = ScheduleParameters()
atari_schedule.improve_steps = EnvironmentSteps(50000000)  # 50,000,000
atari_schedule.steps_between_evaluation_periods = EnvironmentSteps(250000)  # 250,000
atari_schedule.evaluation_steps = EnvironmentSteps(135000)  # 135,000
atari_schedule.heatup_steps = EnvironmentSteps(50000)  # 50,000
class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
def __init__(self, env, frameskip=4, max_over_num_frames=2):
super().__init__(env)