1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

load and save function for non-episodic replay buffers + carla improvements + network bug fixes

This commit is contained in:
itaicaspi-intel
2018-09-06 16:46:57 +03:00
parent d59a700248
commit a9bd1047c4
8 changed files with 50 additions and 18 deletions

View File

@@ -74,12 +74,12 @@ class Agent(AgentInterface):
self.memory = self.shared_memory_scratchpad.get(self.memory_lookup_name)
else:
# modules
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
if agent_parameters.memory.load_memory_from_file_path:
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
.format(agent_parameters.memory.load_memory_from_file_path))
self.memory = read_pickle(agent_parameters.memory.load_memory_from_file_path)
else:
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
self.memory.load(agent_parameters.memory.load_memory_from_file_path)
if self.shared_memory and self.is_chief:
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
@@ -149,6 +149,7 @@ class Agent(AgentInterface):
self.unclipped_grads = self.register_signal('Grads (unclipped)')
self.reward = self.register_signal('Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
self.shaped_reward = self.register_signal('Shaped Reward', dump_one_value_per_episode=False, dump_one_value_per_step=True)
self.discounted_return = self.register_signal('Discounted Return')
if isinstance(self.in_action_space, GoalsSpace):
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
@@ -427,6 +428,10 @@ class Agent(AgentInterface):
:return: None
"""
self.current_episode_buffer.is_complete = True
self.current_episode_buffer.update_returns()
for transition in self.current_episode_buffer.transitions:
self.discounted_return.add_sample(transition.total_return)
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
self.current_episode += 1
@@ -435,7 +440,6 @@ class Agent(AgentInterface):
if isinstance(self.memory, EpisodicExperienceReplay):
self.call_memory('store_episode', self.current_episode_buffer)
elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
self.current_episode_buffer.update_returns()
for transition in self.current_episode_buffer.transitions:
self.call_memory('store', transition)

View File

@@ -32,6 +32,7 @@ from rl_coach.core_types import ActionInfo
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
class HumanAlgorithmParameters(AlgorithmParameters):
@@ -57,7 +58,7 @@ class HumanAgentParameters(AgentParameters):
def __init__(self):
super().__init__(algorithm=HumanAlgorithmParameters(),
exploration=EGreedyParameters(),
memory=EpisodicExperienceReplayParameters(),
memory=ExperienceReplayParameters(),
networks={"main": BCNetworkParameters()})
@property
@@ -103,7 +104,7 @@ class HumanAgent(Agent):
def save_replay_buffer_and_exit(self):
replay_buffer_path = os.path.join(self.agent_logger.experiments_path, 'replay_buffer.p')
self.memory.tp = None
to_pickle(self.memory, replay_buffer_path)
self.memory.save(replay_buffer_path)
screen.log_title("Replay buffer was stored in {}".format(replay_buffer_path))
exit()

View File

@@ -61,7 +61,7 @@ class Conv2d(object):
"""
self.params = params
def __call__(self, input_layer, name: str):
def __call__(self, input_layer, name: str=None):
"""
returns a tensorflow conv2d layer
:param input_layer: previous layer
@@ -79,7 +79,7 @@ class Dense(object):
"""
self.params = force_list(params)
def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None):
def __call__(self, input_layer, name: str=None, kernel_initializer=None, activation=None):
"""
returns a tensorflow dense layer
:param input_layer: previous layer

View File

@@ -253,9 +253,11 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
else:
# if we use a single network with multiple embedders, then the head type is the current head idx
head_type_idx = head_idx
# create output head and add it to the output heads list
self.output_heads.append(
self.get_output_head(self.network_parameters.heads_parameters[head_type_idx],
head_copy_idx,
head_idx*self.network_parameters.num_output_head_copies + head_copy_idx,
self.network_parameters.loss_weights[head_type_idx])
)

View File

@@ -59,8 +59,8 @@ class Head(object):
self.loss = []
self.loss_type = []
self.regularizations = []
# self.loss_weight = force_list(loss_weight)
self.loss_weight = tf.Variable(force_list(loss_weight), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
self.loss_weight = tf.Variable([float(w) for w in force_list(loss_weight)],
trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
self.loss_weight_placeholder = tf.placeholder("float")
self.set_loss_weight = tf.assign(self.loss_weight, self.loss_weight_placeholder)
self.target = []

View File

@@ -49,7 +49,7 @@ class PolicyHead(Head):
# a scalar weight that penalizes low entropy values to encourage exploration
if hasattr(agent_parameters.algorithm, 'beta_entropy'):
# we set the beta value as a tf variable so it can be updated later if needed
self.beta = tf.Variable(agent_parameters.algorithm.beta_entropy,
self.beta = tf.Variable(float(agent_parameters.algorithm.beta_entropy),
trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
self.beta_placeholder = tf.placeholder('float')
self.set_beta = tf.assign(self.beta, self.beta_placeholder)

View File

@@ -18,6 +18,7 @@ import random
import sys
from os import path, environ
from rl_coach.logger import screen
from rl_coach.filters.action.partial_discrete_action_space_map import PartialDiscreteActionSpaceMap
from rl_coach.filters.observation.observation_rgb_to_y_filter import ObservationRGBToYFilter
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
@@ -25,6 +26,8 @@ from rl_coach.filters.observation.observation_to_uint8_filter import Observation
try:
if 'CARLA_ROOT' in environ:
sys.path.append(path.join(environ.get('CARLA_ROOT'), 'PythonClient'))
else:
screen.error("CARLA_ROOT was not defined. Please set it to point to the CARLA root directory and try again.")
from carla.client import CarlaClient
from carla.settings import CarlaSettings
from carla.tcp import TCPConnectionError
@@ -237,25 +240,28 @@ class CarlaEnvironment(Environment):
# add a front facing camera
if CameraTypes.FRONT in cameras:
camera = Camera(CameraTypes.FRONT.value)
camera.set(FOV=100)
camera.set_image_size(camera_width, camera_height)
camera.set_position(0.2, 0, 1.3)
camera.set_rotation(8, 0, 0)
camera.set_position(2.0, 0, 1.4)
camera.set_rotation(-15.0, 0, 0)
settings.add_sensor(camera)
# add a left facing camera
if CameraTypes.LEFT in cameras:
camera = Camera(CameraTypes.LEFT.value)
camera.set(FOV=100)
camera.set_image_size(camera_width, camera_height)
camera.set_position(0.2, 0, 1.3)
camera.set_rotation(8, -30, 0)
camera.set_position(2.0, 0, 1.4)
camera.set_rotation(-15.0, -30, 0)
settings.add_sensor(camera)
# add a right facing camera
if CameraTypes.RIGHT in cameras:
camera = Camera(CameraTypes.RIGHT.value)
camera.set(FOV=100)
camera.set_image_size(camera_width, camera_height)
camera.set_position(0.2, 0, 1.3)
camera.set_rotation(8, 30, 0)
camera.set_position(2.0, 0, 1.4)
camera.set_rotation(-15.0, 30, 0)
settings.add_sensor(camera)
# add a front facing depth camera

View File

@@ -15,6 +15,7 @@
#
from typing import List, Tuple, Union, Dict, Any
import pickle
import numpy as np
@@ -218,3 +219,21 @@ class ExperienceReplay(Memory):
self.reader_writer_lock.release_writing()
return mean
def save(self, file_path: str) -> None:
"""
Save the replay buffer contents to a pickle file
:param file_path: the path to the file that will be used to store the pickled transitions
"""
with open(file_path, 'wb') as file:
pickle.dump(self.transitions, file)
def load(self, file_path: str) -> None:
"""
Restore the replay buffer contents from a pickle file.
The pickle file is assumed to include a list of transitions.
:param file_path: The path to a pickle file to restore
"""
with open(file_path, 'rb') as file:
self.transitions = pickle.load(file)
self._num_transitions = len(self.transitions)