diff --git a/rl_coach/environments/carla_environment.py b/rl_coach/environments/carla_environment.py index a2bd57f..2a2fb3d 100644 --- a/rl_coach/environments/carla_environment.py +++ b/rl_coach/environments/carla_environment.py @@ -378,7 +378,7 @@ class CarlaEnvironment(Environment): self.control.brake = np.abs(np.clip(action[0], -1, 0)) # prevent braking - if not self.allow_braking: + if not self.allow_braking or self.control.brake < 0.1: self.control.brake = 0 # prevent over speeding diff --git a/rl_coach/memories/non_episodic/experience_replay.py b/rl_coach/memories/non_episodic/experience_replay.py index 415706b..a73e44a 100644 --- a/rl_coach/memories/non_episodic/experience_replay.py +++ b/rl_coach/memories/non_episodic/experience_replay.py @@ -22,8 +22,9 @@ import time import numpy as np from rl_coach.core_types import Transition +from rl_coach.logger import screen from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters -from rl_coach.utils import ReaderWriterLock +from rl_coach.utils import ReaderWriterLock, ProgressBar class ExperienceReplayParameters(MemoryParameters): @@ -239,15 +240,17 @@ class ExperienceReplay(Memory): with open(file_path, 'rb') as file: transitions = pickle.load(file) num_transitions = len(transitions) - start_time = time.time() + if num_transitions > self.max_size[1]: + screen.warning("Warning! The number of transition to load into the replay buffer ({}) is " + "bigger than the max size of the replay buffer ({}). The excessive transitions will " + "not be stored.".format(num_transitions, self.max_size[1])) + + progress_bar = ProgressBar(num_transitions) for transition_idx, transition in enumerate(transitions): self.store(transition) # print progress if transition_idx % 100 == 0: - percentage = int((100 * transition_idx) / num_transitions) - sys.stdout.write("\rProgress: ({}/{})".format(transition_idx, num_transitions)) - sys.stdout.write(' Time (sec): {}'.format(round(time.time() - start_time, 2))) - sys.stdout.write(' {}%|{}{}| '.format(percentage, '#' * int(percentage / 10), - ' ' * (10 - int(percentage / 10)))) - sys.stdout.flush() + progress_bar.update(transition_idx) + + progress_bar.close() diff --git a/rl_coach/utils.py b/rl_coach/utils.py index dc1c071..88f74c5 100644 --- a/rl_coach/utils.py +++ b/rl_coach/utils.py @@ -20,6 +20,7 @@ import inspect import json import os import signal +import sys import threading import time from multiprocessing import Manager @@ -548,4 +549,24 @@ class ReaderWriterLock(object): def release_writing(self): self.num_readers_lock.acquire() self.num_readers -= 1 - self.num_readers_lock.release() \ No newline at end of file + self.num_readers_lock.release() + + +class ProgressBar(object): + def __init__(self, max_value): + self.start_time = time.time() + self.max_value = max_value + self.current_value = 0 + + def update(self, current_value): + self.current_value = current_value + percentage = int((100 * current_value) / self.max_value) + sys.stdout.write("\rProgress: ({}/{}) Time: {} sec {}%|{}{}| " + .format(current_value, self.max_value, + round(time.time() - self.start_time, 2), + percentage, '#' * int(percentage / 10), + ' ' * (10 - int(percentage / 10)))) + sys.stdout.flush() + + def close(self): + print("")