
Robosuite exploration (#478)

* Add Robosuite parameters for all env types + initialize env flow

* Init flow done

* Rest of Environment API complete for RobosuiteEnvironment

* RobosuiteEnvironment changes

* Observation stacking filter
* Add proper frame_skip in addition to control_freq
* Hardcode Coach rendering to 'frontview' camera

* Robosuite_Lift_DDPG preset + Robosuite env updates

* Move observation stacking filter from env to preset
* Pre-process observation - concatenate depth map (if it exists)
  to image and object state (if it exists) to robot state (sketched below)
* Preset parameters based on Surreal DDPG parameters, taken from:
  https://github.com/SurrealAI/surreal/blob/master/surreal/main/ddpg_configs.py
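
A rough numpy sketch of the pre-processing described above (the actual implementation lands in RobosuiteEnvironment._process_observation later in this diff; the observation key names follow Robosuite conventions and are illustrative):

    import numpy as np

    def preprocess_observation(raw_obs, camera='agentview'):
        # Stack the depth map onto the RGB image (if present) and append the
        # object state to the robot (proprioceptive) state.
        obs = {}
        image = raw_obs.get(camera + '_image')
        if image is not None:
            depth = raw_obs.get(camera + '_depth')
            if depth is not None:
                image = np.concatenate([image, np.expand_dims(depth, axis=2)], axis=2)
            obs['camera'] = image
        measurements = raw_obs['robot0_proprio-state']
        object_state = raw_obs.get('object-state')
        if object_state is not None:
            measurements = np.concatenate([measurements, object_state])
        obs['measurements'] = measurements
        return obs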

* RobosuiteEnvironment fixes - working now with PyGame rendering

* Preset minor modifications

* ObservationStackingFilter - option to concat non-vector observations

* Consider frame skip when setting horizon in robosuite env

* Robosuite lift preset - update heatup length and training interval

* Robosuite env - change control_freq to 10 to match Surreal usage

* Robosuite clipped PPO preset

* Distribute multiple workers (-n #) over multiple GPUs

* Clipped PPO memory optimization from @shadiendrawis
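
The optimization replaces a single value prediction over the entire rollout with mini-batched predictions (visible in the ClippedPPOAgent.fill_advantages hunk below). A generic sketch of the idea, not Coach's exact code:

    import numpy as np

    def predict_values_in_batches(predict_fn, states, batch_size):
        # Run the critic over the rollout in mini-batches and concatenate the
        # results, instead of one huge predict() call that spikes memory usage.
        values = []
        for start in range(0, len(states), batch_size):
            values.append(predict_fn(states[start:start + batch_size]))
        return np.concatenate(values)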

* Fixes to evaluation only workers

* RoboSuite_ClippedPPO: Update training interval

* Undo last commit (update training interval)

* Fix "doube-negative" if conditions

* multi-agent single-trainer clipped ppo training with cartpole

* cleanups (not done yet) + roughly tuned hyper-params for MAST

* Switch to Robosuite v1 APIs

* Change presets to IK controller

* more cleanups + enabling evaluation worker + better logging

* RoboSuite_Lift_ClippedPPO updates

* Fix major bug in obs normalization filter setup

* Reduce coupling between Robosuite API and Coach environment

* Now only non task-specific parameters are explicitly defined
  in Coach
* Removed a bunch of enums of Robosuite elements, using simple
  strings instead
* With this change new environments/robots/controllers in Robosuite
  can be used immediately in Coach

* MAST: better logging of actor-trainer interaction + bug fixes + performance improvements.

Still missing: fixed pubsub for obs normalization running stats + logging for trainer signals

* lstm support for ppo

* setting JOINT VELOCITY action space by default + fix for EveryNEpisodes video dump filter + new TaskIDDumpFilter + allowing OR between video dump filters

* Separate Robosuite clipped PPO preset for the non-MAST case

* Add flatten layer to architectures and use it in Robosuite presets

This is required for embedders that mix conv and dense

TODO: Add MXNet implementation
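
For context, a hedged sketch of what a mixed conv/dense embedder scheme with an explicit flatten might look like in a preset. The Flatten layer name and import path are assumptions based on this commit's description, not a verified API:

    # Sketch only: assumes the new flatten layer is exposed as Flatten alongside
    # Conv2d/Dense in rl_coach.architectures.layers.
    from rl_coach.architectures.layers import Conv2d, Dense, Flatten
    from rl_coach.architectures.embedder_parameters import InputEmbedderParameters

    # A camera embedder mixing conv and dense layers; Flatten bridges the spatial
    # conv output and the fully-connected part.
    camera_embedder = InputEmbedderParameters(scheme=[
        Conv2d(32, 8, 4),
        Conv2d(64, 4, 2),
        Conv2d(64, 3, 1),
        Flatten(),
        Dense(256),
    ])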

* publishing running_stats together with the published policy + hyper-param for when to publish a policy + cleanups

* bug-fix for memory leak in MAST

* Bugfix: Return value in TF BatchnormActivationDropout.to_tf_instance

* Explicit activations in embedder scheme so there's no ReLU after flatten

* Add clipped PPO heads with configurable dense layers at the beginning

* This is a workaround needed to mimic Surreal-PPO, where the CNN and
  LSTM are shared between actor and critic but the FC layers are not
  shared
* Added a "SchemeBuilder" class, currently only used for the new heads
  but we can change Middleware and Embedder implementations to use it
  as well

* Video dump setting fix in basic preset

* logging screen output to file

* coach to start the redis-server for a MAST run

* trainer drops off-policy data + old policy in ClippedPPO updates only after the policy was published + logging free memory stats + actors check for a new policy only at the beginning of a new episode + fixed a bug where the trainer was logging "Training Reward = 0", causing the dashboard to display the signal incorrectly

* Add missing set_internal_state function in TFSharedRunningStats

* Robosuite preset - use SingleLevelSelect instead of hard-coded level

* policy ID published directly on Redis

* Small fix when writing to log file

* Major bugfix in Robosuite presets - pass dense sizes to heads

* RoboSuite_Lift_ClippedPPO hyper-params update

* add horizon and value bootstrap to GAE calculation, fix A3C with LSTM
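
As a reference for the "value bootstrap" part, a minimal generic sketch of GAE where a rollout truncated at the horizon is bootstrapped with the critic's estimate of the final state's value instead of being treated as terminal (this is the standard formulation, not necessarily Coach's exact implementation):

    import numpy as np

    def gae_with_bootstrap(rewards, values, bootstrap_value, discount=0.99, lam=0.95):
        # values has one entry per state in the rollout; bootstrap_value is the
        # critic's estimate for the state reached after the last transition.
        rewards = np.asarray(rewards, dtype=np.float64)
        values = np.append(np.asarray(values, dtype=np.float64), bootstrap_value)
        deltas = rewards + discount * values[1:] - values[:-1]
        advantages = np.zeros(len(rewards))
        gae = 0.0
        for t in reversed(range(len(rewards))):
            gae = deltas[t] + discount * lam * gae
            advantages[t] = gae
        return advantages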

* adam hyper-params from mujoco

* updated MAST preset with IK_POSE_POS controller

* configurable initialization for policy stdev + custom extra noise per actor + logging of policy stdev to dashboard

* values loss weighting of 0.5

* minor fixes + presets

* bug-fix for MAST where the old policy in the trainer kept updating on every training iteration, while it should only update after every policy publish

* bug-fix: reset_internal_state was not called by the trainer

* bug-fixes in the LSTM flow + some hyper-param adjustments for CartPole_ClippedPPO_LSTM -> now trains and sometimes reaches 200

* adding back the horizon hyper-param - a messy commit

* another bug-fix missing from prev commit

* set control_freq=2 to match action_scale 0.125

* ClippedPPO with MAST cleanups and some preps for TD3 with MAST

* TD3 presets. RoboSuite_Lift_TD3 seems to work well with multi-process runs (-n 8)

* setting termination on collision to be on by default

* bug-fix following prev-prev commit

* initial cube exploration environment with TD3 commit

* bug fix + minor refactoring

* several parameter changes and RND debugging

* Robosuite Gym wrapper + Rename TD3_Random* -> Random*

* algorithm update

* Add RoboSuite v1 env + presets (to eventually replace non-v1 ones)

* Remove grasping presets, keep only V1 exp. presets (w/o V1 tag)

* Keep just robosuite V1 env as the 'robosuite_environment' module

* Exclude Robosuite and MAST presets from integration tests

* Exclude LSTM and MAST presets from golden tests

* Fix mistakenly removed import

* Revert debug changes in ReaderWriterLock

* Try another way to exclude LSTM/MAST golden tests

* Remove debug prints

* Remove PreDense heads, unused in the end

* Missed removing an instance of PreDense head

* Remove MAST, not required for this PR

* Undo unused concat option in ObservationStackingFilter

* Remove LSTM updates, not required in this PR

* Update README.md

* code changes for the exploration flow to work with robosuite master branch

* code cleanup + documentation

* jupyter tutorial for the goal-based exploration + scatter plot

* typo fix

* Update README.md

* separate parameter for the obs-goal observation + small fixes

* code clarity fixes

* adjustment in tutorial 5

* Update tutorial

* Update tutorial

Co-authored-by: Guy Jacob <guy.jacob@intel.com>
Co-authored-by: Gal Leibovich <gal.leibovich@intel.com>
Co-authored-by: shadi.endrawis <sendrawi@aipg-ra-skx-03.ra.intel.com>
Author: shadiendrawis
Date: 2021-06-01 00:34:19 +03:00 (committed by GitHub)
Parent: 235a259223
Commit: 0896f43097
25 changed files with 1905 additions and 46 deletions

View File

@@ -257,7 +257,6 @@ class Agent(AgentInterface):
:return: None
"""
# Loading a memory from a CSV file, requires an input filter to filter through the data.
# The filter needs a session before it can be used.
if self.ap.memory.load_memory_from_file_path:
@@ -418,10 +417,11 @@ class Agent(AgentInterface):
self.num_successes_across_evaluation_episodes = 0
self.num_evaluation_episodes_completed = 0
# TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
# if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
if self.ap.is_a_highest_level_agent:
screen.log_title("{}: Starting evaluation phase".format(self.name))
if self.ap.task_parameters.evaluate_only is None:
# TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
# if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
if self.ap.is_a_highest_level_agent:
screen.log_title("{}: Starting evaluation phase".format(self.name))
elif ending_evaluation:
# we write to the next episode, because it could be that the current episode was already written
@@ -439,11 +439,12 @@ class Agent(AgentInterface):
"Success Rate",
success_rate)
# TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
# if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
if self.ap.is_a_highest_level_agent:
screen.log_title("{}: Finished evaluation phase. Success rate = {}, Avg Total Reward = {}"
.format(self.name, np.round(success_rate, 2), np.round(evaluation_reward, 2)))
if self.ap.task_parameters.evaluate_only is None:
# TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
# if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
if self.ap.is_a_highest_level_agent:
screen.log_title("{}: Finished evaluation phase. Success rate = {}, Avg Total Reward = {}"
.format(self.name, np.round(success_rate, 2), np.round(evaluation_reward, 2)))
def call_memory(self, func, args=()):
"""
@@ -568,7 +569,7 @@ class Agent(AgentInterface):
for transition in self.current_episode_buffer.transitions:
self.discounted_return.add_sample(transition.n_step_discounted_rewards)
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only is not None:
self.current_episode += 1
if self.phase != RunPhase.TEST:
@@ -828,7 +829,7 @@ class Agent(AgentInterface):
return None
# count steps (only when training or if we are in the evaluation worker)
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only is not None:
self.total_steps_counter += 1
self.current_episode_steps_counter += 1

View File

@@ -15,6 +15,7 @@
#
import copy
import math
from collections import OrderedDict
from random import shuffle
from typing import Union
@@ -156,8 +157,17 @@ class ClippedPPOAgent(ActorCriticAgent):
def fill_advantages(self, batch):
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
current_state_values = self.networks['main'].online_network.predict(batch.states(network_keys))[0]
current_state_values = current_state_values.squeeze()
state_values = []
for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size) + 1):
start = i * self.ap.network_wrappers['main'].batch_size
end = (i + 1) * self.ap.network_wrappers['main'].batch_size
if start == batch.size:
break
state_values.append(self.networks['main'].online_network.predict(
{k: v[start:end] for k, v in batch.states(network_keys).items()})[0])
current_state_values = np.concatenate(state_values)
self.state_values.add_sample(current_state_values)
# calculate advantages
@@ -213,9 +223,7 @@ class ClippedPPOAgent(ActorCriticAgent):
self.networks['main'].online_network.output_heads[1].likelihood_ratio,
self.networks['main'].online_network.output_heads[1].clipped_likelihood_ratio]
# TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer, we do not train on
# some of the data
for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size)):
for i in range(math.ceil(batch.size / self.ap.network_wrappers['main'].batch_size)):
start = i * self.ap.network_wrappers['main'].batch_size
end = (i + 1) * self.ap.network_wrappers['main'].batch_size

View File

@@ -0,0 +1,410 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from typing import Union
from collections import OrderedDict
from random import shuffle
import os
from PIL import Image
import joblib
import numpy as np
from rl_coach.agents.agent import Agent
from rl_coach.agents.td3_agent import TD3Agent, TD3CriticNetworkParameters, TD3ActorNetworkParameters, \
TD3AlgorithmParameters, TD3AgentExplorationParameters
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.base_parameters import NetworkParameters, AgentParameters, MiddlewareScheme
from rl_coach.core_types import Transition, Batch
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.architectures.head_parameters import RNDHeadParameters
from rl_coach.utilities.shared_running_stats import NumpySharedRunningStats
from rl_coach.logger import screen
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.schedules import LinearSchedule
class RNDNetworkParameters(NetworkParameters):
def __init__(self):
super().__init__()
self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='leaky_relu',
input_rescaling={'image': 1.0})}
self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Empty)
self.heads_parameters = [RNDHeadParameters()]
self.create_target_network = False
self.optimizer_type = 'Adam'
self.batch_size = 100
self.learning_rate = 0.0001
self.should_get_softmax_probabilities = False
class TD3ExplorationAlgorithmParameters(TD3AlgorithmParameters):
"""
:param rnd_sample_size: (int)
The number of states in each RND training iteration.
:param rnd_batch_size: (int)
Batch size for the RND optimization cycle.
:param rnd_optimization_epochs: (int)
Number of epochs for the RND optimization cycle.
:param td3_training_ratio: (float)
The ratio between TD3 training steps and the number of steps in each episode (must be a positive number).
:param identity_goal_sample_rate: (float)
For the goal-based agent, this number indicates the probability to sample a goal that is the identity
(must be a number between 0 and 1).
:param env_obs_key: (str)
The name of the state key for the camera observation from the environment.
:param agent_obs_key: (str)
The name of the state key for the camera observation for the agent. This key has to be different
from env_obs_key in case the agent modifies the observation from the environment. For example,
the goal-based agent concatenates a goal image to the image observation from the environment.
:param replay_buffer_save_steps: (int)
The number of steps to periodically save the replay buffer.
:param replay_buffer_save_path: (str or None)
A path to save the replay buffer to. if set to None, the replay buffer will be saved in the
experiment directory.
"""
def __init__(self):
super().__init__()
self.rnd_sample_size = 2000
self.rnd_batch_size = 500
self.rnd_optimization_epochs = 4
self.td3_training_ratio = 1.0
self.identity_goal_sample_rate = 0.0
self.env_obs_key = 'camera'
self.agent_obs_key = 'camera'
self.replay_buffer_save_steps = 25000
self.replay_buffer_save_path = None
class TD3ExplorationAgentParameters(AgentParameters):
def __init__(self):
td3_exp_algorithm_params = TD3ExplorationAlgorithmParameters()
super().__init__(algorithm=td3_exp_algorithm_params,
exploration=TD3AgentExplorationParameters(),
memory=EpisodicExperienceReplayParameters(),
networks=OrderedDict([("actor", TD3ActorNetworkParameters()),
("critic",
TD3CriticNetworkParameters(td3_exp_algorithm_params.num_q_networks)),
("predictor", RNDNetworkParameters()),
("constant", RNDNetworkParameters())]))
@property
def path(self):
return 'rl_coach.agents.td3_exp_agent:TD3ExplorationAgent'
class TD3ExplorationAgent(TD3Agent):
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
super().__init__(agent_parameters, parent)
self.rnd_stats = NumpySharedRunningStats(name='RND_normalization', epsilon=1e-8)
self.rnd_stats.set_params()
self.rnd_obs_stats = NumpySharedRunningStats(name='RND_observation_normalization', epsilon=1e-8)
self.intrinsic_returns_estimate = None
def update_intrinsic_returns_estimate(self, rewards):
returns = np.zeros_like(rewards)
for i, r in enumerate(rewards):
if self.intrinsic_returns_estimate is None:
self.intrinsic_returns_estimate = r
else:
self.intrinsic_returns_estimate = \
self.intrinsic_returns_estimate * self.ap.algorithm.discount + r
returns[i] = self.intrinsic_returns_estimate
return returns
def prepare_rnd_inputs(self, batch):
env_obs_key = self.ap.algorithm.env_obs_key
next_states = batch.next_states([env_obs_key])
inputs = {env_obs_key: self.rnd_obs_stats.normalize(next_states[env_obs_key])}
return inputs
def handle_self_supervised_reward(self, batch):
"""
Allows agents to update the batch for self supervised learning
:param batch: original training batch
:return: updated training batch
"""
return batch
def update_transition_before_adding_to_replay_buffer(self, transition: Transition) -> Transition:
"""
Allows agents to update the transition just before adding it to the replay buffer.
Can be useful for agents that want to tweak the reward, termination signal, etc.
:param transition: the transition to update
:return: the updated transition
"""
transition = super().update_transition_before_adding_to_replay_buffer(transition)
image = np.array(transition.state[self.ap.algorithm.env_obs_key])
if self.rnd_obs_stats.n < 1:
self.rnd_obs_stats.set_params(shape=image.shape, clip_values=[-5, 5])
self.rnd_obs_stats.push_val(np.expand_dims(image, 0))
return transition
def train_rnd(self):
if self.memory.num_transitions() == 0:
return
transitions = self.memory.transitions[-self.ap.algorithm.rnd_sample_size:]
dataset = Batch(transitions)
dataset_order = list(range(dataset.size))
batch_size = self.ap.algorithm.rnd_batch_size
for epoch in range(self.ap.algorithm.rnd_optimization_epochs):
shuffle(dataset_order)
total_loss = 0
total_grads = 0
for i in range(int(dataset.size / batch_size)):
start = i * batch_size
end = (i + 1) * batch_size
batch = Batch(list(np.array(dataset.transitions)[dataset_order[start:end]]))
inputs = self.prepare_rnd_inputs(batch)
const_embedding = self.networks['constant'].online_network.predict(inputs)
res = self.networks['predictor'].train_and_sync_networks(inputs, [const_embedding])
total_loss += res[0]
total_grads += res[2]
screen.log_dict(
OrderedDict([
("training epoch", epoch),
("dataset size", dataset.size),
("mean loss", total_loss / dataset.size),
("mean gradients", total_grads / dataset.size)
]),
prefix="RND Training"
)
def learn_from_batch(self, batch):
batch = self.handle_self_supervised_reward(batch)
return super().learn_from_batch(batch)
def train(self):
self.ap.algorithm.num_consecutive_training_steps = \
int(self.current_episode_steps_counter * self.ap.algorithm.td3_training_ratio)
return Agent.train(self)
def calculate_novelty(self, batch):
inputs = self.prepare_rnd_inputs(batch)
embedding = self.networks['constant'].online_network.predict(inputs)
prediction = self.networks['predictor'].online_network.predict(inputs)
prediction_error = np.mean((embedding - prediction) ** 2, axis=1)
return prediction_error
def save_replay_buffer(self, dir_path=None):
if dir_path is None:
dir_path = os.path.join(self.parent_level_manager.parent_graph_manager.task_parameters.experiment_path,
'replay_buffer')
if not os.path.exists(dir_path):
os.mkdir(dir_path)
path = os.path.join(dir_path, 'RB_{}.joblib.bz2'.format(type(self).__name__))
joblib.dump(self.memory.get_all_complete_episodes(), path, compress=('bz2', 1))
screen.log('Saved replay buffer to: \"{}\" - Number of transitions: {}'.format(path,
self.memory.num_transitions()))
def handle_episode_ended(self) -> None:
super().handle_episode_ended()
if self.total_steps_counter % self.ap.algorithm.rnd_sample_size == 0:
self.train_rnd()
if self.total_steps_counter % self.ap.algorithm.replay_buffer_save_steps == 0:
self.save_replay_buffer(self.ap.algorithm.replay_buffer_save_path)
self.save_rnd_images(self.ap.algorithm.replay_buffer_save_path)
def save_rnd_images(self, dir_path=None):
if dir_path is None:
dir_path = os.path.join(self.parent_level_manager.parent_graph_manager.task_parameters.experiment_path,
'rnd_images')
else:
dir_path = os.path.join(dir_path, 'rnd_images')
if not os.path.exists(dir_path):
os.mkdir(dir_path)
transitions = self.memory.transitions
dataset = Batch(transitions)
batch_size = self.ap.algorithm.rnd_batch_size
novelties = []
for i in range(int(dataset.size / batch_size)):
start = i * batch_size
end = (i + 1) * batch_size
batch = Batch(dataset[start:end])
novelty = self.calculate_novelty(batch)
novelties.append(novelty)
novelties = np.concatenate(novelties)
sorted_indices = np.argsort(novelties)
sample_indices = sorted_indices[np.round(np.linspace(0, len(sorted_indices) - 1, 100)).astype(np.uint32)]
images = []
for si in sample_indices:
images.append(np.flip(transitions[si].next_state[self.ap.algorithm.env_obs_key], 0))
rows = []
for i in range(10):
rows.append(np.hstack(images[(i * 10):((i + 1) * 10)]))
image = np.vstack(rows)
image = Image.fromarray(image)
image.save('{}/{}_{}.jpeg'.format(dir_path, 'rnd_samples', len(transitions)))
class TD3IntrinsicRewardAgentParameters(TD3ExplorationAgentParameters):
@property
def path(self):
return 'rl_coach.agents.td3_exp_agent:TD3IntrinsicRewardAgent'
class TD3IntrinsicRewardAgent(TD3ExplorationAgent):
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
super().__init__(agent_parameters, parent)
def handle_self_supervised_reward(self, batch):
novelty = self.calculate_novelty(batch)
for i, t in enumerate(batch.transitions):
t.reward = novelty[i] / self.rnd_stats.std[0]
return batch
def handle_episode_ended(self) -> None:
super().handle_episode_ended()
novelty = self.calculate_novelty(Batch(self.memory.get_last_complete_episode().transitions))
self.rnd_stats.push_val(np.expand_dims(self.update_intrinsic_returns_estimate(novelty), -1))
class RandomAgentParameters(TD3ExplorationAgentParameters):
def __init__(self):
super().__init__()
self.exploration = EGreedyParameters()
self.exploration.epsilon_schedule = LinearSchedule(1.0, 1.0, 500000000)
@property
def path(self):
return 'rl_coach.agents.td3_exp_agent:RandomAgent'
class RandomAgent(TD3ExplorationAgent):
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
super().__init__(agent_parameters, parent)
self.ap.algorithm.periodic_exploration_noise = None
self.ap.algorithm.rnd_sample_size = 100000000000
def train(self):
return 0
class TD3GoalBasedAgentParameters(TD3ExplorationAgentParameters):
@property
def path(self):
return 'rl_coach.agents.td3_exp_agent:TD3GoalBasedAgent'
class TD3GoalBasedAgent(TD3ExplorationAgent):
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
super().__init__(agent_parameters, parent)
self.goal = None
self.ap.algorithm.use_non_zero_discount_for_terminal_states = False
def concat_goal(self, state, goal_state):
ret = np.concatenate([state[self.ap.algorithm.env_obs_key], goal_state[self.ap.algorithm.env_obs_key]], axis=2)
return ret
def handle_self_supervised_reward(self, batch):
batch_size = self.ap.network_wrappers['actor'].batch_size
episode_indices = np.random.randint(self.memory.num_complete_episodes(), size=batch_size)
transitions = []
for e_idx in episode_indices:
episode = self.memory.get_all_complete_episodes()[e_idx]
transition_idx = np.random.randint(episode.length())
t = copy.copy(episode[transition_idx])
if np.random.rand(1) < self.ap.algorithm.identity_goal_sample_rate:
t.state[self.ap.algorithm.agent_obs_key] = self.concat_goal(t.state, t.state)
# this doesn't matter for learning but is set anyway so that the agent can pass it through the network
t.next_state[self.ap.algorithm.agent_obs_key] = self.concat_goal(t.next_state, t.state)
t.game_over = True
t.reward = 0
t.action = np.zeros_like(t.action)
else:
if transition_idx == episode.length() - 1:
goal = t
t.state[self.ap.algorithm.agent_obs_key] = self.concat_goal(t.state, t.next_state)
t.next_state[self.ap.algorithm.agent_obs_key] = self.concat_goal(t.next_state, t.next_state)
else:
goal_idx = np.random.randint(transition_idx, episode.length())
goal = episode.transitions[goal_idx]
t.state[self.ap.algorithm.agent_obs_key] = self.concat_goal(t.state, episode.transitions[goal_idx].next_state)
t.next_state[self.ap.algorithm.agent_obs_key] = self.concat_goal(t.next_state,
episode.transitions[goal_idx].next_state)
camera_equal = np.alltrue(np.equal(t.next_state[self.ap.algorithm.env_obs_key],
goal.next_state[self.ap.algorithm.env_obs_key]))
measurements_equal = np.alltrue(np.isclose(t.next_state['measurements'],
goal.next_state['measurements']))
t.game_over = camera_equal and measurements_equal
t.reward = -1
transitions.append(t)
return Batch(transitions)
def choose_action(self, curr_state):
if self.goal:
curr_state[self.ap.algorithm.agent_obs_key] = self.concat_goal(curr_state, self.goal.next_state)
else:
curr_state[self.ap.algorithm.agent_obs_key] = self.concat_goal(curr_state, curr_state)
return super().choose_action(curr_state)
def generate_goal(self):
if self.memory.num_transitions() == 0:
return
transitions = list(np.random.choice(self.memory.transitions,
min(self.ap.algorithm.rnd_sample_size,
self.memory.num_transitions()),
replace=False))
dataset = Batch(transitions)
batch_size = self.ap.algorithm.rnd_batch_size
self.goal = dataset[0]
max_novelty = 0
for i in range(int(dataset.size / batch_size)):
start = i * batch_size
end = (i + 1) * batch_size
novelty = self.calculate_novelty(Batch(dataset[start:end]))
curr_max = np.max(novelty)
if curr_max > max_novelty:
max_novelty = curr_max
idx = start + np.argmax(novelty)
self.goal = dataset[idx]
def handle_episode_ended(self) -> None:
super().handle_episode_ended()
self.generate_goal()

View File

@@ -258,3 +258,9 @@ class TD3VHeadParameters(HeadParameters):
loss_weight=loss_weight)
self.initializer = initializer
self.output_bias_initializer = output_bias_initializer
class RNDHeadParameters(HeadParameters):
def __init__(self, name: str = 'rnd_head_params', dense_layer=None, is_predictor=False):
super().__init__(parameterized_class_name="RNDHead", name=name, dense_layer=dense_layer)
self.is_predictor = is_predictor

View File

@@ -56,4 +56,4 @@ class LSTMMiddlewareParameters(MiddlewareParameters):
super().__init__(parameterized_class_name="LSTMMiddleware", activation_function=activation_function,
scheme=scheme, batchnorm=batchnorm, dropout_rate=dropout_rate, name=name, dense_layer=dense_layer,
is_training=is_training)
self.number_of_lstm_cells = number_of_lstm_cells
self.number_of_lstm_cells = number_of_lstm_cells

View File

@@ -0,0 +1,54 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
import numpy as np
from rl_coach.architectures.tensorflow_components.layers import Conv2d, BatchnormActivationDropout
from rl_coach.architectures.tensorflow_components.heads.head import Head, Orthogonal
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import Embedding
from rl_coach.spaces import SpacesDefinition
class RNDHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, is_local: bool = True, is_predictor: bool = False):
super().__init__(agent_parameters, spaces, network_name, head_idx, is_local)
self.name = 'rnd_head'
self.return_type = Embedding
self.is_predictor = is_predictor
self.activation_function = tf.nn.leaky_relu
self.loss_type = tf.losses.mean_squared_error
def _build_module(self, input_layer):
weight_init = Orthogonal(gain=np.sqrt(2))
input_layer = Conv2d(num_filters=32, kernel_size=8, strides=4)(input_layer, kernel_initializer=weight_init)
input_layer = BatchnormActivationDropout(activation_function=self.activation_function)(input_layer)[-1]
input_layer = Conv2d(num_filters=64, kernel_size=4, strides=2)(input_layer, kernel_initializer=weight_init)
input_layer = BatchnormActivationDropout(activation_function=self.activation_function)(input_layer)[-1]
input_layer = Conv2d(num_filters=64, kernel_size=3, strides=1)(input_layer, kernel_initializer=weight_init)
input_layer = BatchnormActivationDropout(activation_function=self.activation_function)(input_layer)[-1]
input_layer = tf.contrib.layers.flatten(input_layer)
if self.is_predictor:
input_layer = self.dense_layer(512)(input_layer, kernel_initializer=weight_init)
input_layer = BatchnormActivationDropout(activation_function=tf.nn.relu)(input_layer)[-1]
input_layer = self.dense_layer(512)(input_layer, kernel_initializer=weight_init)
input_layer = BatchnormActivationDropout(activation_function=tf.nn.relu)(input_layer)[-1]
self.output = self.dense_layer(512)(input_layer, name='output', kernel_initializer=weight_init)

View File

@@ -19,6 +19,7 @@ from .cil_head import RegressionHead
from .td3_v_head import TD3VHead
from .ddpg_v_head import DDPGVHead
from .wolpertinger_actor_head import WolpertingerActorHead
from .RND_head import RNDHead
__all__ = [
'CategoricalQHead',
@@ -41,5 +42,6 @@ __all__ = [
'RegressionHead',
'TD3VHead',
'DDPGVHead',
'WolpertingerActorHead'
'WolpertingerActorHead',
'RNDHead'
]

View File

@@ -23,6 +23,7 @@ from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list
from rl_coach.architectures.tensorflow_components.utils import squeeze_tensor
# Used to initialize weights for policy and value output layers
def normalized_columns_initializer(std=1.0):
def _initializer(shape, dtype=None, partition_info=None):
@@ -32,6 +33,29 @@ def normalized_columns_initializer(std=1.0):
return _initializer
# Used to initialize RND network parameters
class Orthogonal(tf.initializers.orthogonal):
def __init__(self, gain=1.0):
super().__init__(gain=gain)
def __call__(self, shape, dtype=None, partition_info=None):
shape = tuple(shape)
if len(shape) == 2:
flat_shape = shape
elif len(shape) == 4: # assumes NHWC
flat_shape = (np.prod(shape[:-1]), shape[-1])
else:
raise NotImplementedError
a = np.random.normal(0.0, 1.0, flat_shape)
u, _, v = np.linalg.svd(a, full_matrices=False)
q = u if u.shape == flat_shape else v # pick the one with the correct shape
q = q.reshape(shape)
return (self.gain * q[:shape[0], :shape[1]]).astype(np.float32)
def get_config(self):
return {"gain": self.gain}
class Head(object):
"""
A head is the final part of the network. It takes the embedding from the middleware embedder and passes it through

View File

@@ -109,7 +109,7 @@ class Conv2d(layers.Conv2d):
def __init__(self, num_filters: int, kernel_size: int, strides: int):
super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)
def __call__(self, input_layer, name: str=None, is_training=None):
def __call__(self, input_layer, name: str=None, is_training=None, kernel_initializer=None):
"""
returns a tensorflow conv2d layer
:param input_layer: previous layer
@@ -117,7 +117,8 @@ class Conv2d(layers.Conv2d):
:return: conv2d layer
"""
return tf.layers.conv2d(input_layer, filters=self.num_filters, kernel_size=self.kernel_size,
strides=self.strides, data_format='channels_last', name=name)
strides=self.strides, data_format='channels_last', name=name,
kernel_initializer=kernel_initializer)
@staticmethod
@reg_to_tf_instance(layers.Conv2d)
@@ -153,7 +154,7 @@ class BatchnormActivationDropout(layers.BatchnormActivationDropout):
@staticmethod
@reg_to_tf_instance(layers.BatchnormActivationDropout)
def to_tf_instance(base: layers.BatchnormActivationDropout):
return BatchnormActivationDropout, BatchnormActivationDropout(
return BatchnormActivationDropout(
batchnorm=base.batchnorm,
activation_function=base.activation_function,
dropout_rate=base.dropout_rate)

View File

@@ -37,7 +37,8 @@ import subprocess
from glob import glob
from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, \
get_base_dir, set_gpu
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
@@ -49,12 +50,40 @@ from rl_coach.data_stores.redis_data_store import RedisDataStoreParameters
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
from rl_coach.training_worker import training_worker
from rl_coach.rollout_worker import rollout_worker
from rl_coach.schedules import *
from rl_coach.exploration_policies.e_greedy import *
if len(set(failed_imports)) > 0:
screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
def _get_cuda_available_devices():
import ctypes
try:
devices = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
return [] if devices[0] == '' else [int(i) for i in devices]
except KeyError:
pass
try:
cuda_lib = ctypes.CDLL('libcuda.so')
except OSError:
return []
CUDA_SUCCESS = 0
num_gpus = ctypes.c_int()
result = cuda_lib.cuInit(0)
if result != CUDA_SUCCESS:
return []
result = cuda_lib.cuDeviceGetCount(ctypes.byref(num_gpus))
if result != CUDA_SUCCESS:
return []
return list(range(num_gpus.value))
def add_items_to_dict(target_dict, source_dict):
updated_task_parameters = copy.copy(source_dict)
updated_task_parameters.update(target_dict)
@@ -215,6 +244,8 @@ class CoachLauncher(object):
and handle absolutely everything for a job.
"""
gpus = _get_cuda_available_devices()
def launch(self):
"""
Main entry point for the class, and the standard way to run coach from the command line.
@@ -440,6 +471,9 @@ class CoachLauncher(object):
screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
"The --export_onnx_graph will have no effect.")
if args.use_cpu or not CoachLauncher.gpus:
CoachLauncher.gpus = [None]
return args
def get_argument_parser(self) -> argparse.ArgumentParser:
@@ -609,9 +643,9 @@ class CoachLauncher(object):
# Single-threaded runs
if args.num_workers == 1:
self.start_single_threaded(task_parameters, graph_manager, args)
self.start_single_process(task_parameters, graph_manager, args)
else:
self.start_multi_threaded(graph_manager, args)
self.start_multi_process(graph_manager, args)
@staticmethod
def create_task_parameters(graph_manager: 'GraphManager', args: argparse.Namespace):
@@ -669,12 +703,12 @@ class CoachLauncher(object):
return task_parameters
@staticmethod
def start_single_threaded(task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
def start_single_process(task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
# Start the training or evaluation
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
@staticmethod
def start_multi_threaded(graph_manager: 'GraphManager', args: argparse.Namespace):
def start_multi_process(graph_manager: 'GraphManager', args: argparse.Namespace):
total_tasks = args.num_workers
if args.evaluation_worker:
total_tasks += 1
@@ -695,7 +729,8 @@ class CoachLauncher(object):
"and not from a file. ")
def start_distributed_task(job_type, task_index, evaluation_worker=False,
shared_memory_scratchpad=shared_memory_scratchpad):
shared_memory_scratchpad=shared_memory_scratchpad,
gpu_id=None):
task_parameters = DistributedTaskParameters(
framework_type=args.framework,
parameters_server_hosts=ps_hosts,
@@ -715,6 +750,8 @@ class CoachLauncher(object):
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)
if gpu_id is not None:
set_gpu(gpu_id)
# we assume that only the evaluation workers are rendering
graph_manager.visualization_parameters.render = args.render and evaluation_worker
p = Process(target=start_graph, args=(graph_manager, task_parameters))
@@ -723,25 +760,30 @@ class CoachLauncher(object):
return p
# parameter server
parameter_server = start_distributed_task("ps", 0)
parameter_server = start_distributed_task("ps", 0, gpu_id=CoachLauncher.gpus[0])
# training workers
# wait a bit before spawning the non chief workers in order to make sure the session is already created
curr_gpu_idx = 0
workers = []
workers.append(start_distributed_task("worker", 0))
workers.append(start_distributed_task("worker", 0, gpu_id=CoachLauncher.gpus[curr_gpu_idx]))
time.sleep(2)
for task_index in range(1, args.num_workers):
workers.append(start_distributed_task("worker", task_index))
curr_gpu_idx = (curr_gpu_idx + 1) % len(CoachLauncher.gpus)
workers.append(start_distributed_task("worker", task_index, gpu_id=CoachLauncher.gpus[curr_gpu_idx]))
# evaluation worker
if args.evaluation_worker or args.render:
evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)
curr_gpu_idx = (curr_gpu_idx + 1) % len(CoachLauncher.gpus)
evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True,
gpu_id=CoachLauncher.gpus[curr_gpu_idx])
# wait for all workers
[w.join() for w in workers]
if args.evaluation_worker:
evaluation_worker.terminate()
parameter_server.terminate()
class CoachInterface(CoachLauncher):

View File

@@ -151,8 +151,8 @@ class DoomEnvironment(Environment):
Each camera should be an enum from CameraTypes, and there are several options like an RGB observation,
a depth map, a segmentation map, and a top down map of the environment.
:param target_success_rate: (float)
Stop experiment if given target success rate was achieved.
:param target_success_rate: (float)
Stop experiment if given target success rate was achieved.
"""
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)

View File

@@ -47,20 +47,21 @@ class LevelSelection(object):
class SingleLevelSelection(LevelSelection):
def __init__(self, levels: Union[str, List[str], Dict[str, str]]):
def __init__(self, levels: Union[str, List[str], Dict[str, str]], force_lower=True):
super().__init__(None)
self.levels = levels
if isinstance(levels, list):
self.levels = {level: level for level in levels}
if isinstance(levels, str):
self.levels = {levels: levels}
self.force_lower = force_lower
def __str__(self):
if self.selected_level is None:
logger.screen.error("No level has been selected. Please select a level using the -lvl command line flag, "
"or change the level in the preset. \nThe available levels are: \n{}"
.format(', '.join(sorted(self.levels.keys()))), crash=True)
selected_level = self.selected_level.lower()
selected_level = self.selected_level.lower() if self.force_lower else self.selected_level
if selected_level not in self.levels.keys():
logger.screen.error("The selected level ({}) is not part of the available levels ({})"
.format(selected_level, ', '.join(self.levels.keys())), crash=True)

View File

@@ -0,0 +1,187 @@
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
from robosuite.utils.mjcf_utils import CustomMaterial
from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
from robosuite.environments.manipulation.lift import Lift
from robosuite.models.arenas import TableArena
from robosuite.models.objects import BoxObject
from robosuite.models.tasks import ManipulationTask
from robosuite.utils.placement_samplers import UniformRandomSampler
TABLE_TOP_SIZE = (0.84, 1.25, 0.05)
TABLE_OFFSET = (0, 0, 0.82)
class CubeExp(Lift):
"""
This class corresponds to multi-colored cube exploration for a single robot arm.
"""
def __init__(
self,
robots,
table_full_size=TABLE_TOP_SIZE,
table_offset=TABLE_OFFSET,
placement_initializer=None,
penalize_reward_on_collision=False,
end_episode_on_collision=False,
**kwargs
):
"""
Args:
robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
(e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
Note: Must be a single single-arm robot!
table_full_size (3-tuple): x, y, and z dimensions of the table.
placement_initializer (ObjectPositionSampler instance): if provided, will
be used to place objects on every reset, else a UniformRandomSampler
is used by default.
Rest of kwargs follow Lift class arguments
"""
if placement_initializer is None:
placement_initializer = UniformRandomSampler(
name="ObjectSampler",
x_range=[0.0, 0.0],
y_range=[0.0, 0.0],
rotation=(0.0, 0.0),
ensure_object_boundary_in_range=False,
ensure_valid_placement=True,
reference_pos=table_offset,
z_offset=0.9,
)
super().__init__(
robots=robots,
table_full_size=table_full_size,
placement_initializer=placement_initializer,
initialization_noise=None,
**kwargs
)
self._max_episode_steps = self.horizon
def _load_model(self):
"""
Loads an xml model, puts it in self.model
"""
SingleArmEnv._load_model(self)
# Adjust base pose accordingly
xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
self.robots[0].robot_model.set_base_xpos(xpos)
# load model for table top workspace
mujoco_arena = TableArena(
table_full_size=self.table_full_size,
table_friction=self.table_friction,
table_offset=self.table_offset,
)
# Arena always gets set to zero origin
mujoco_arena.set_origin([0, 0, 0])
cube_material = self._get_cube_material()
self.cube = BoxObject(
name="cube",
size_min=(0.025, 0.025, 0.025),
size_max=(0.025, 0.025, 0.025),
rgba=[1, 0, 0, 1],
material=cube_material,
)
self.placement_initializer.reset()
self.placement_initializer.add_objects(self.cube)
# task includes arena, robot, and objects of interest
self.model = ManipulationTask(
mujoco_arena=mujoco_arena,
mujoco_robots=[robot.robot_model for robot in self.robots],
mujoco_objects=self.cube,
)
@property
def action_spec(self):
"""
Action space (low, high) for this environment
"""
low, high = super().action_spec
return low[:3], high[:3]
def _get_cube_material(self):
from robosuite.utils.mjcf_utils import array_to_string
rgba = (1, 0, 0, 1)
cube_material = CustomMaterial(
texture=rgba,
tex_name="solid",
mat_name="solid_mat",
)
cube_material.tex_attrib.pop('file')
cube_material.tex_attrib["type"] = "cube"
cube_material.tex_attrib["builtin"] = "flat"
cube_material.tex_attrib["rgb1"] = array_to_string(rgba[:3])
cube_material.tex_attrib["rgb2"] = array_to_string(rgba[:3])
cube_material.tex_attrib["width"] = "100"
cube_material.tex_attrib["height"] = "100"
return cube_material
def _reset_internal(self):
"""
Resets simulation internal configurations.
"""
from robosuite.utils.mjmod import Texture
super()._reset_internal()
self._action_dim = 3
geom_id = self.sim.model.geom_name2id('cube_g0_vis')
mat_id = self.sim.model.geom_matid[geom_id]
tex_id = self.sim.model.mat_texid[mat_id]
texture = Texture(self.sim.model, tex_id)
bitmap_to_set = texture.bitmap
bitmap = np.zeros_like(bitmap_to_set)
bitmap[:100, :, :] = 255
bitmap[100:200, :, 0] = 255
bitmap[200:300, :, 1] = 255
bitmap[300:400, :, 2] = 255
bitmap[400:500, :, :2] = 255
bitmap[500:, :, 1:] = 255
bitmap_to_set[:] = bitmap
for render_context in self.sim.render_contexts:
render_context.upload_texture(texture.id)
def _pre_action(self, action, policy_step=False):
""" explicitly shut the gripper """
joined_action = np.append(action, [0., 0., 0., 1.])
self._action_dim = 7
super()._pre_action(joined_action, policy_step)
def _post_action(self, action):
ret = super()._post_action(action)
self._action_dim = 3
return ret
def reward(self, action=None):
return 0
def _check_success(self):
return False

View File

@@ -0,0 +1,18 @@
{
"type": "OSC_POSE",
"input_max": 1,
"input_min": -1,
"output_max": [0.125, 0.125, 0.125, 0.5, 0.5, 0.5],
"output_min": [-0.125, -0.125, -0.125, -0.5, -0.5, -0.5],
"kp": 150,
"damping_ratio": 1,
"impedance_mode": "fixed",
"kp_limits": [0, 300],
"damping_ratio_limits": [0, 10],
"position_limits": [[-0.22, -0.35, 0.82], [0.22, 0.35, 1.3]],
"orientation_limits": null,
"uncouple_pos_ori": true,
"control_delta": true,
"interpolation": null,
"ramp_ratio": 0.2
}

View File

@@ -0,0 +1,321 @@
#
# Copyright (c) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Union, Dict, Any
from enum import Enum, Flag, auto
from copy import deepcopy
import numpy as np
import random
from collections import namedtuple
try:
import robosuite
from robosuite.wrappers import Wrapper, DomainRandomizationWrapper
except ImportError:
from rl_coach.logger import failed_imports
failed_imports.append("Robosuite")
from rl_coach.base_parameters import Parameters, VisualizationParameters
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
from rl_coach.spaces import BoxActionSpace, VectorObservationSpace, StateSpace, PlanarMapsObservationSpace
# Importing our custom Robosuite environments here so that they are properly
# registered in Robosuite, and so recognized by 'robosuite.make()' and included
# in 'robosuite.ALL_ENVIRONMENTS'
import rl_coach.environments.robosuite.cube_exp
robosuite_environments = list(robosuite.ALL_ENVIRONMENTS)
robosuite_robots = list(robosuite.ALL_ROBOTS)
robosuite_controllers = list(robosuite.ALL_CONTROLLERS)
def get_robosuite_env_extra_parameters(env_name: str):
import inspect
assert env_name in robosuite_environments
env_params = inspect.signature(robosuite.environments.REGISTERED_ENVS[env_name]).parameters
base_params = list(RobosuiteBaseParameters().env_kwargs_dict().keys()) + ['robots', 'controller_configs']
return {n: p.default for n, p in env_params.items() if n not in base_params}
class OptionalObservations(Flag):
NONE = 0
CAMERA = auto()
OBJECT = auto()
class RobosuiteBaseParameters(Parameters):
def __init__(self, optional_observations: OptionalObservations = OptionalObservations.NONE):
super(RobosuiteBaseParameters, self).__init__()
# NOTE: Attribute names should exactly match the attribute names in Robosuite
self.horizon = 1000 # Every episode lasts for exactly horizon timesteps
self.ignore_done = True # True if never terminating the environment (ignore horizon)
self.reward_shaping = True # if True, use dense rewards.
# How many control signals to receive in every simulated second. This sets the amount of simulation time
# that passes between every action input (this is NOT the same as frame_skip)
self.control_freq = 10
# Optional observations (robot state is always returned)
# if True, every observation includes a rendered image
self.use_camera_obs = bool(optional_observations & OptionalObservations.CAMERA)
# if True, include object (cube/etc.) information in the observation
self.use_object_obs = bool(optional_observations & OptionalObservations.OBJECT)
# Camera parameters
self.has_renderer = False # Set to true to use Mujoco native viewer for on-screen rendering
self.render_camera = 'frontview' # name of camera to use for on-screen rendering
self.has_offscreen_renderer = self.use_camera_obs
self.render_collision_mesh = False # True if rendering collision meshes in camera. False otherwise
self.render_visual_mesh = True # True if rendering visual meshes in camera. False otherwise
self.camera_names = 'agentview' # name of camera for rendering camera observations
self.camera_heights = 84 # height of camera frame.
self.camera_widths = 84 # width of camera frame.
self.camera_depths = False # True if rendering RGB-D, and RGB otherwise.
# Collision
self.penalize_reward_on_collision = True
self.end_episode_on_collision = False
@property
def optional_observations(self):
flag = OptionalObservations.NONE
if self.use_camera_obs:
flag = OptionalObservations.CAMERA
if self.use_object_obs:
flag |= OptionalObservations.OBJECT
elif self.use_object_obs:
flag = OptionalObservations.OBJECT
return flag
@optional_observations.setter
def optional_observations(self, value):
self.use_camera_obs = bool(value & OptionalObservations.CAMERA)
if self.use_camera_obs:
self.has_offscreen_renderer = True
self.use_object_obs = bool(value & OptionalObservations.OBJECT)
def env_kwargs_dict(self):
res = {k: (v.value if isinstance(v, Enum) else v) for k, v in vars(self).items()}
return res
class RobosuiteEnvironmentParameters(EnvironmentParameters):
def __init__(self, level, robot=None, controller=None, apply_dr: bool = False,
dr_every_n_steps_min: int = 10, dr_every_n_steps_max: int = 20,
use_joint_vel_obs=False):
super().__init__(level=level)
self.base_parameters = RobosuiteBaseParameters()
self.extra_parameters = {}
self.robot = robot
self.controller = controller
self.apply_dr = apply_dr
self.dr_every_n_steps_min = dr_every_n_steps_min
self.dr_every_n_steps_max = dr_every_n_steps_max
self.use_joint_vel_obs = use_joint_vel_obs
self.custom_controller_config_fpath = None
@property
def path(self):
return 'rl_coach.environments.robosuite_environment:RobosuiteEnvironment'
DEFAULT_REWARD_SCALES = {
'Lift': 2.25,
'LiftLab': 2.25,
}
RobosuiteStepResult = namedtuple('RobosuiteStepResult', ['observation', 'reward', 'done', 'info'])
# Environment
class RobosuiteEnvironment(Environment):
def __init__(self, level: LevelSelection,
seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float, None],
visualization_parameters: VisualizationParameters,
base_parameters: RobosuiteBaseParameters,
extra_parameters: Dict[str, Any],
robot: str, controller: str,
target_success_rate: float = 1.0, apply_dr: bool = False,
dr_every_n_steps_min: int = 10, dr_every_n_steps_max: int = 20, use_joint_vel_obs=False,
custom_controller_config_fpath=None, **kwargs):
super(RobosuiteEnvironment, self).__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
visualization_parameters, target_success_rate)
# Validate arguments
self.frame_skip = max(1, self.frame_skip)
def validate_input(input, supported, name):
if input not in supported:
raise ValueError("Unknown Robosuite {0} passed: '{1}' ; Supported {0}s are: {2}".format(
name, input, ' | '.join(supported)
))
validate_input(self.env_id, robosuite_environments, 'environment')
validate_input(robot, robosuite_robots, 'robot')
self.robot = robot
if controller is not None:
validate_input(controller, robosuite_controllers, 'controller')
self.controller = controller
self.base_parameters = base_parameters
self.base_parameters.has_renderer = self.is_rendered and self.native_rendering
self.base_parameters.has_offscreen_renderer = self.base_parameters.use_camera_obs or (self.is_rendered and not
self.native_rendering)
# Seed
if self.seed is not None:
np.random.seed(self.seed)
random.seed(self.seed)
# Load and initialize environment
env_args = self.base_parameters.env_kwargs_dict()
env_args.update(extra_parameters)
if 'reward_scale' not in env_args and self.env_id in DEFAULT_REWARD_SCALES:
env_args['reward_scale'] = DEFAULT_REWARD_SCALES[self.env_id]
env_args['robots'] = self.robot
controller_cfg = None
if self.controller is not None:
controller_cfg = robosuite.controllers.load_controller_config(default_controller=self.controller)
elif custom_controller_config_fpath is not None:
controller_cfg = robosuite.controllers.load_controller_config(custom_fpath=custom_controller_config_fpath)
env_args['controller_configs'] = controller_cfg
self.env = robosuite.make(self.env_id, **env_args)
# TODO: Generalize this to filter any observation by name
if not use_joint_vel_obs:
self.env.modify_observable('robot0_joint_vel', 'active', False)
# Wrap with a dummy wrapper so we get a consistent API (there are subtle changes between
# wrappers and actual environments in Robosuite, for example action_spec as property vs. function)
self.env = Wrapper(self.env)
if apply_dr:
self.env = DomainRandomizationWrapper(self.env, seed=self.seed, randomize_every_n_steps_min=dr_every_n_steps_min,
randomize_every_n_steps_max=dr_every_n_steps_max)
# State space
self.state_space = self._setup_state_space()
# Action space
low, high = self.env.unwrapped.action_spec
self.action_space = BoxActionSpace(low.shape, low=low, high=high)
self.reset_internal_state()
if self.is_rendered:
image = self.get_rendered_image()
self.renderer.create_screen(image.shape[1], image.shape[0])
# TODO: Other environments call rendering here, why? reset_internal_state does it
def _setup_state_space(self):
state_space = StateSpace({})
dummy_obs = self._process_observation(self.env.observation_spec())
state_space['measurements'] = VectorObservationSpace(dummy_obs['measurements'].shape[0])
if self.base_parameters.use_camera_obs:
state_space['camera'] = PlanarMapsObservationSpace(dummy_obs['camera'].shape, 0, 255)
return state_space
def _process_observation(self, raw_obs):
new_obs = {}
# TODO: Support multiple cameras, this assumes a single camera
camera_name = self.base_parameters.camera_names
camera_obs = raw_obs.get(camera_name + '_image', None)
if camera_obs is not None:
depth_obs = raw_obs.get(camera_name + '_depth', None)
if depth_obs is not None:
depth_obs = np.expand_dims(depth_obs, axis=2)
camera_obs = np.concatenate([camera_obs, depth_obs], axis=2)
new_obs['camera'] = camera_obs
measurements = raw_obs['robot0_proprio-state']
object_obs = raw_obs.get('object-state', None)
if object_obs is not None:
measurements = np.concatenate([measurements, object_obs])
new_obs['measurements'] = measurements
return new_obs
def _take_action(self, action):
action = self.action_space.clip_action_to_space(action)
# We mimic the "action_repeat" mechanism of RobosuiteWrapper in Surreal.
# Same concept as frame_skip, only returning the average reward across repeated actions instead
# of the total reward.
rewards = []
for _ in range(self.frame_skip):
obs, reward, done, info = self.env.step(action)
rewards.append(reward)
if done:
break
reward = np.mean(rewards)
self.last_result = RobosuiteStepResult(obs, reward, done, info)
def _update_state(self):
obs = self._process_observation(self.last_result.observation)
self.state = {k: obs[k] for k in self.state_space.sub_spaces}
self.reward = self.last_result.reward or 0
self.done = self.last_result.done
self.info = self.last_result.info
def _restart_environment_episode(self, force_environment_reset=False):
reset_obs = self.env.reset()
self.last_result = RobosuiteStepResult(reset_obs, 0.0, False, {})
def _render(self):
self.env.render()
def get_rendered_image(self):
img: np.ndarray = self.env.sim.render(camera_name=self.base_parameters.render_camera,
height=512, width=512, depth=False)
return np.flip(img, 0)
def close(self):
self.env.close()
class RobosuiteGoalBasedExpEnvironmentParameters(RobosuiteEnvironmentParameters):
@property
def path(self):
return 'rl_coach.environments.robosuite_environment:RobosuiteGoalBasedExpEnvironment'
class RobosuiteGoalBasedExpEnvironment(RobosuiteEnvironment):
def _process_observation(self, raw_obs):
new_obs = super()._process_observation(raw_obs)
new_obs['obs-goal'] = None
return new_obs
def _setup_state_space(self):
state_space = super()._setup_state_space()
goal_based_shape = list(state_space['camera'].shape)
goal_based_shape[2] *= 2
state_space['obs-goal'] = PlanarMapsObservationSpace(tuple(goal_based_shape), 0, 255)
return state_space

View File

@@ -114,7 +114,8 @@ class StarCraft2Environment(Environment):
observation_type: StarcraftObservationType=StarcraftObservationType.Features,
disable_fog: bool=False, auto_select_all_army: bool=True,
use_full_action_space: bool=False, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters,
target_success_rate)
self.screen_size = screen_size
self.minimap_size = minimap_size

View File

@@ -222,7 +222,8 @@ class GraphManager(object):
if isinstance(task_parameters, DistributedTaskParameters):
# the distributed tensorflow setting
from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session
if hasattr(self.task_parameters, 'checkpoint_restore_path') and self.task_parameters.checkpoint_restore_path:
if hasattr(self.task_parameters,
'checkpoint_restore_path') and self.task_parameters.checkpoint_restore_path:
checkpoint_dir = os.path.join(task_parameters.experiment_path, 'checkpoint')
if os.path.exists(checkpoint_dir):
remove_tree(checkpoint_dir)
@@ -438,7 +439,8 @@ class GraphManager(object):
# perform several steps of playing
count_end = self.current_step_counter + steps
result = None
while self.current_step_counter < count_end or (wait_for_full_episodes and result is not None and not result.game_over):
while self.current_step_counter < count_end or (
wait_for_full_episodes and result is not None and not result.game_over):
# reset the environment if the previous episode was terminated
if self.reset_required:
self.reset_internal_state()
@@ -506,8 +508,14 @@ class GraphManager(object):
# act for at least `steps`, though don't interrupt an episode
count_end = self.current_step_counter + steps
while self.current_step_counter < count_end:
# In case of an evaluation-only worker, fake a phase transition before and after every
# episode to make sure results are logged correctly
if self.task_parameters.evaluate_only is not None:
self.phase = RunPhase.TEST
self.act(EnvironmentEpisodes(1))
self.sync()
if self.task_parameters.evaluate_only is not None:
self.phase = RunPhase.TRAIN
if self.should_stop():
self.flush_finished()
screen.success("Reached required success rate. Exiting.")
@@ -555,7 +563,7 @@ class GraphManager(object):
if self.task_parameters.checkpoint_restore_path:
if os.path.isdir(self.task_parameters.checkpoint_restore_path):
# a checkpoint dir
if self.task_parameters.framework_type == Frameworks.tensorflow and\
if self.task_parameters.framework_type == Frameworks.tensorflow and \
'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_path):
# TODO-fixme checkpointing
# MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
@@ -717,7 +725,8 @@ class GraphManager(object):
self.memory_backend = get_memory_backend(self.agent_params.memory.memory_backend_params)
def should_stop(self) -> bool:
return self.task_parameters.apply_stop_condition and all([manager.should_stop() for manager in self.level_managers])
return self.task_parameters.apply_stop_condition and all(
[manager.should_stop() for manager in self.level_managers])
def get_data_store(self, param):
if self.data_store:
@@ -727,10 +736,10 @@ class GraphManager(object):
def signal_ready(self):
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.TRAINER_READY.value), 'w').close()
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.TRAINER_READY.value), 'w').close()
if hasattr(self, 'data_store_params'):
data_store = self.get_data_store(self.data_store_params)
data_store.save_to_store()
data_store = self.get_data_store(self.data_store_params)
data_store.save_to_store()
def close(self) -> None:
"""

View File

@@ -70,7 +70,7 @@ class ScreenLogger(object):
"""
if not self.log_file:
self.log_file = open(os.path.join(experiment_path, "log.txt"), "a")
self.log_file.write(",".join([t for t in text]))
self.log_file.write(",".join([str(t) for t in text]))
self.log_file.write("\n")
self.log_file.flush()
print(*text, flush=True)
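# Context for the fix above: str.join() raises a TypeError on non-string items, e.g.
#     ",".join(["reward", 1.5])  # TypeError: sequence item 1: expected str instance, float found
# so numeric log values must be cast with str() before being written to log.txt.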

View File

@@ -0,0 +1,98 @@
from rl_coach.agents.td3_exp_agent import RandomAgentParameters
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.layers import Dense, Conv2d, BatchnormActivationDropout, Flatten
from rl_coach.base_parameters import EmbedderScheme
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.robosuite_environment import RobosuiteGoalBasedExpEnvironmentParameters, \
OptionalObservations
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.architectures.head_parameters import RNDHeadParameters
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(300000)
schedule_params.steps_between_evaluation_periods = TrainingSteps(300000)
schedule_params.evaluation_steps = EnvironmentEpisodes(0)
schedule_params.heatup_steps = EnvironmentSteps(0)
#########
# Agent #
#########
agent_params = RandomAgentParameters()
agent_params.algorithm.use_non_zero_discount_for_terminal_states = True
agent_params.input_filter = NoInputFilter()
agent_params.output_filter = NoOutputFilter()
# Camera observation pre-processing network scheme
camera_obs_scheme = [
Conv2d(32, 8, 4),
BatchnormActivationDropout(activation_function='relu'),
Conv2d(64, 4, 2),
BatchnormActivationDropout(activation_function='relu'),
Conv2d(64, 3, 1),
BatchnormActivationDropout(activation_function='relu'),
Flatten(),
Dense(256),
BatchnormActivationDropout(activation_function='relu')
]
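# For the 84x84 camera frames configured below (and assuming 'valid'/no padding), this stack yields
# feature maps of 20x20x32 -> 9x9x64 -> 7x7x64, i.e. 3136 features after Flatten(), projected to
# 256 by the final Dense layer.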
# Actor
actor_network = agent_params.network_wrappers['actor']
actor_network.input_embedders_parameters = {
'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')
}
actor_network.middleware_parameters.scheme = [Dense(300), Dense(200)]
actor_network.learning_rate = 1e-4
# Critic
critic_network = agent_params.network_wrappers['critic']
critic_network.input_embedders_parameters = {
'action': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')
}
critic_network.middleware_parameters.scheme = [Dense(400), Dense(300)]
critic_network.learning_rate = 1e-4
# RND
agent_params.network_wrappers['predictor'].input_embedders_parameters = \
{agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,
input_rescaling={'image': 1.0},
flatten=False)}
agent_params.network_wrappers['constant'].input_embedders_parameters = \
{agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,
input_rescaling={'image': 1.0},
flatten=False)}
agent_params.network_wrappers['predictor'].heads_parameters = [RNDHeadParameters(is_predictor=True)]
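# The 'constant' network acts as the fixed, randomly initialized RND target, while the 'predictor'
# network is trained to match its output; the prediction error on camera observations is presumably
# what the agent uses as its novelty signal (standard RND setup).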
###############
# Environment #
###############
env_params = RobosuiteGoalBasedExpEnvironmentParameters(level='CubeExp')
env_params.robot = 'Panda'
env_params.custom_controller_config_fpath = './rl_coach/environments/robosuite/osc_pose.json'
env_params.base_parameters.optional_observations = OptionalObservations.CAMERA
env_params.base_parameters.render_camera = 'frontview'
env_params.base_parameters.camera_names = 'agentview'
env_params.base_parameters.camera_depths = False
env_params.base_parameters.horizon = 200
env_params.base_parameters.ignore_done = False
env_params.base_parameters.use_object_obs = True
env_params.frame_skip = 1
env_params.base_parameters.control_freq = 2
env_params.base_parameters.camera_heights = 84
env_params.base_parameters.camera_widths = 84
env_params.extra_parameters = {'hard_reset': False}
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, schedule_params=schedule_params)
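# Assuming standard Coach preset discovery, this preset would typically be launched with the
# coach CLI, e.g. `coach -p RoboSuite_CubeExp_Random` (preset name taken from the test list below).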

View File

@@ -0,0 +1,111 @@
from rl_coach.agents.td3_exp_agent import TD3GoalBasedAgentParameters
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.layers import Dense, Conv2d, BatchnormActivationDropout, Flatten
from rl_coach.base_parameters import EmbedderScheme
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.robosuite_environment import RobosuiteGoalBasedExpEnvironmentParameters, \
OptionalObservations
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.architectures.head_parameters import RNDHeadParameters
from rl_coach.schedules import LinearSchedule
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(300000)
schedule_params.steps_between_evaluation_periods = TrainingSteps(300000)
schedule_params.evaluation_steps = EnvironmentEpisodes(0)
schedule_params.heatup_steps = EnvironmentSteps(1000)
#########
# Agent #
#########
agent_params = TD3GoalBasedAgentParameters()
agent_params.algorithm.use_non_zero_discount_for_terminal_states = False
agent_params.algorithm.identity_goal_sample_rate = 0.04
agent_params.exploration.noise_schedule = LinearSchedule(1.5, 0.5, 300000)
agent_params.algorithm.rnd_sample_size = 2000
agent_params.algorithm.rnd_batch_size = 500
agent_params.algorithm.rnd_optimization_epochs = 4
agent_params.algorithm.td3_training_ratio = 1.0
agent_params.algorithm.identity_goal_sample_rate = 0.0
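# Note that this assignment overrides the identity_goal_sample_rate of 0.04 set a few lines above.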
agent_params.algorithm.env_obs_key = 'camera'
agent_params.algorithm.agent_obs_key = 'obs-goal'
agent_params.algorithm.replay_buffer_save_steps = 25000
agent_params.algorithm.replay_buffer_save_path = './tutorials'
agent_params.input_filter = NoInputFilter()
agent_params.output_filter = NoOutputFilter()
# Camera observation pre-processing network scheme
camera_obs_scheme = [
Conv2d(32, 8, 4),
BatchnormActivationDropout(activation_function='relu'),
Conv2d(64, 4, 2),
BatchnormActivationDropout(activation_function='relu'),
Conv2d(64, 3, 1),
BatchnormActivationDropout(activation_function='relu'),
Flatten(),
Dense(256),
BatchnormActivationDropout(activation_function='relu')
]
# Actor
actor_network = agent_params.network_wrappers['actor']
actor_network.input_embedders_parameters = {
'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')
}
actor_network.middleware_parameters.scheme = [Dense(300), Dense(200)]
actor_network.learning_rate = 1e-4
# Critic
critic_network = agent_params.network_wrappers['critic']
critic_network.input_embedders_parameters = {
'action': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')
}
critic_network.middleware_parameters.scheme = [Dense(400), Dense(300)]
critic_network.learning_rate = 1e-4
# RND
agent_params.network_wrappers['predictor'].input_embedders_parameters = \
{agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,
input_rescaling={'image': 1.0},
flatten=False)}
agent_params.network_wrappers['constant'].input_embedders_parameters = \
{agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,
input_rescaling={'image': 1.0},
flatten=False)}
agent_params.network_wrappers['predictor'].heads_parameters = [RNDHeadParameters(is_predictor=True)]
###############
# Environment #
###############
env_params = RobosuiteGoalBasedExpEnvironmentParameters(level='CubeExp')
env_params.robot = 'Panda'
env_params.custom_controller_config_fpath = './rl_coach/environments/robosuite/osc_pose.json'
env_params.base_parameters.optional_observations = OptionalObservations.CAMERA
env_params.base_parameters.render_camera = 'frontview'
env_params.base_parameters.camera_names = 'agentview'
env_params.base_parameters.camera_depths = False
env_params.base_parameters.horizon = 200
env_params.base_parameters.ignore_done = False
env_params.base_parameters.use_object_obs = True
env_params.frame_skip = 1
env_params.base_parameters.control_freq = 2
env_params.base_parameters.camera_heights = 84
env_params.base_parameters.camera_widths = 84
env_params.extra_parameters = {'hard_reset': False}
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, schedule_params=schedule_params)

View File

@@ -0,0 +1,100 @@
from rl_coach.agents.td3_exp_agent import TD3IntrinsicRewardAgentParameters
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.layers import Dense, Conv2d, BatchnormActivationDropout, Flatten
from rl_coach.base_parameters import EmbedderScheme
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.robosuite_environment import RobosuiteGoalBasedExpEnvironmentParameters, \
OptionalObservations
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.architectures.head_parameters import RNDHeadParameters
from rl_coach.schedules import LinearSchedule
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(300000)
schedule_params.steps_between_evaluation_periods = TrainingSteps(300000)
schedule_params.evaluation_steps = EnvironmentEpisodes(0)
schedule_params.heatup_steps = EnvironmentSteps(1000)
#########
# Agent #
#########
agent_params = TD3IntrinsicRewardAgentParameters()
agent_params.algorithm.use_non_zero_discount_for_terminal_states = True
agent_params.exploration.noise_schedule = LinearSchedule(1.5, 0.5, 300000)
agent_params.input_filter = NoInputFilter()
agent_params.output_filter = NoOutputFilter()
# Camera observation pre-processing network scheme
camera_obs_scheme = [
Conv2d(32, 8, 4),
BatchnormActivationDropout(activation_function='relu'),
Conv2d(64, 4, 2),
BatchnormActivationDropout(activation_function='relu'),
Conv2d(64, 3, 1),
BatchnormActivationDropout(activation_function='relu'),
Flatten(),
Dense(256),
BatchnormActivationDropout(activation_function='relu')
]
# Actor
actor_network = agent_params.network_wrappers['actor']
actor_network.input_embedders_parameters = {
'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')
}
actor_network.middleware_parameters.scheme = [Dense(300), Dense(200)]
actor_network.learning_rate = 1e-4
# Critic
critic_network = agent_params.network_wrappers['critic']
critic_network.input_embedders_parameters = {
'action': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')
}
critic_network.middleware_parameters.scheme = [Dense(400), Dense(300)]
critic_network.learning_rate = 1e-4
# RND
agent_params.network_wrappers['predictor'].input_embedders_parameters = \
{agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,
input_rescaling={'image': 1.0},
flatten=False)}
agent_params.network_wrappers['constant'].input_embedders_parameters = \
{agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,
input_rescaling={'image': 1.0},
flatten=False)}
agent_params.network_wrappers['predictor'].heads_parameters = [RNDHeadParameters(is_predictor=True)]
###############
# Environment #
###############
env_params = RobosuiteGoalBasedExpEnvironmentParameters(level='CubeExp')
env_params.robot = 'Panda'
env_params.custom_controller_config_fpath = './rl_coach/environments/robosuite/osc_pose.json'
env_params.base_parameters.optional_observations = OptionalObservations.CAMERA
env_params.base_parameters.render_camera = 'frontview'
env_params.base_parameters.camera_names = 'agentview'
env_params.base_parameters.camera_depths = False
env_params.base_parameters.horizon = 200
env_params.base_parameters.ignore_done = False
env_params.base_parameters.use_object_obs = True
env_params.frame_skip = 1
env_params.base_parameters.control_freq = 2
env_params.base_parameters.camera_heights = 84
env_params.base_parameters.camera_widths = 84
env_params.extra_parameters = {'hard_reset': False}
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, schedule_params=schedule_params)

View File

@@ -22,6 +22,9 @@ FAILING_PRESETS = [
'CARLA_3_Cameras_DDPG',
'Starcraft_CollectMinerals_A3C',
'Starcraft_CollectMinerals_Dueling_DDQN',
'RoboSuite_CubeExp_Random',
'RoboSuite_CubeExp_TD3_Goal_Based',
'RoboSuite_CubeExp_TD3_Intrinsic_Reward',
]
def all_presets():