1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

added a simple progress bar implementation

This commit is contained in:
itaicaspi-intel
2018-09-13 14:21:38 +03:00
parent fa79d8d365
commit 607ef17431
3 changed files with 34 additions and 10 deletions

View File

@@ -378,7 +378,7 @@ class CarlaEnvironment(Environment):
self.control.brake = np.abs(np.clip(action[0], -1, 0)) self.control.brake = np.abs(np.clip(action[0], -1, 0))
# prevent braking # prevent braking
if not self.allow_braking: if not self.allow_braking or self.control.brake < 0.1:
self.control.brake = 0 self.control.brake = 0
# prevent over speeding # prevent over speeding

View File

@@ -22,8 +22,9 @@ import time
import numpy as np import numpy as np
from rl_coach.core_types import Transition from rl_coach.core_types import Transition
from rl_coach.logger import screen
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
from rl_coach.utils import ReaderWriterLock from rl_coach.utils import ReaderWriterLock, ProgressBar
class ExperienceReplayParameters(MemoryParameters): class ExperienceReplayParameters(MemoryParameters):
@@ -239,15 +240,17 @@ class ExperienceReplay(Memory):
with open(file_path, 'rb') as file: with open(file_path, 'rb') as file:
transitions = pickle.load(file) transitions = pickle.load(file)
num_transitions = len(transitions) num_transitions = len(transitions)
start_time = time.time() if num_transitions > self.max_size[1]:
screen.warning("Warning! The number of transition to load into the replay buffer ({}) is "
"bigger than the max size of the replay buffer ({}). The excessive transitions will "
"not be stored.".format(num_transitions, self.max_size[1]))
progress_bar = ProgressBar(num_transitions)
for transition_idx, transition in enumerate(transitions): for transition_idx, transition in enumerate(transitions):
self.store(transition) self.store(transition)
# print progress # print progress
if transition_idx % 100 == 0: if transition_idx % 100 == 0:
percentage = int((100 * transition_idx) / num_transitions) progress_bar.update(transition_idx)
sys.stdout.write("\rProgress: ({}/{})".format(transition_idx, num_transitions))
sys.stdout.write(' Time (sec): {}'.format(round(time.time() - start_time, 2))) progress_bar.close()
sys.stdout.write(' {}%|{}{}| '.format(percentage, '#' * int(percentage / 10),
' ' * (10 - int(percentage / 10))))
sys.stdout.flush()

View File

@@ -20,6 +20,7 @@ import inspect
import json import json
import os import os
import signal import signal
import sys
import threading import threading
import time import time
from multiprocessing import Manager from multiprocessing import Manager
@@ -548,4 +549,24 @@ class ReaderWriterLock(object):
def release_writing(self): def release_writing(self):
self.num_readers_lock.acquire() self.num_readers_lock.acquire()
self.num_readers -= 1 self.num_readers -= 1
self.num_readers_lock.release() self.num_readers_lock.release()
class ProgressBar(object):
def __init__(self, max_value):
self.start_time = time.time()
self.max_value = max_value
self.current_value = 0
def update(self, current_value):
self.current_value = current_value
percentage = int((100 * current_value) / self.max_value)
sys.stdout.write("\rProgress: ({}/{}) Time: {} sec {}%|{}{}| "
.format(current_value, self.max_value,
round(time.time() - self.start_time, 2),
percentage, '#' * int(percentage / 10),
' ' * (10 - int(percentage / 10))))
sys.stdout.flush()
def close(self):
print("")