mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
added a simple progress bar implementation
This commit is contained in:
@@ -378,7 +378,7 @@ class CarlaEnvironment(Environment):
|
|||||||
self.control.brake = np.abs(np.clip(action[0], -1, 0))
|
self.control.brake = np.abs(np.clip(action[0], -1, 0))
|
||||||
|
|
||||||
# prevent braking
|
# prevent braking
|
||||||
if not self.allow_braking:
|
if not self.allow_braking or self.control.brake < 0.1:
|
||||||
self.control.brake = 0
|
self.control.brake = 0
|
||||||
|
|
||||||
# prevent over speeding
|
# prevent over speeding
|
||||||
|
|||||||
@@ -22,8 +22,9 @@ import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.core_types import Transition
|
from rl_coach.core_types import Transition
|
||||||
|
from rl_coach.logger import screen
|
||||||
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
|
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
|
||||||
from rl_coach.utils import ReaderWriterLock
|
from rl_coach.utils import ReaderWriterLock, ProgressBar
|
||||||
|
|
||||||
|
|
||||||
class ExperienceReplayParameters(MemoryParameters):
|
class ExperienceReplayParameters(MemoryParameters):
|
||||||
@@ -239,15 +240,17 @@ class ExperienceReplay(Memory):
|
|||||||
with open(file_path, 'rb') as file:
|
with open(file_path, 'rb') as file:
|
||||||
transitions = pickle.load(file)
|
transitions = pickle.load(file)
|
||||||
num_transitions = len(transitions)
|
num_transitions = len(transitions)
|
||||||
start_time = time.time()
|
if num_transitions > self.max_size[1]:
|
||||||
|
screen.warning("Warning! The number of transition to load into the replay buffer ({}) is "
|
||||||
|
"bigger than the max size of the replay buffer ({}). The excessive transitions will "
|
||||||
|
"not be stored.".format(num_transitions, self.max_size[1]))
|
||||||
|
|
||||||
|
progress_bar = ProgressBar(num_transitions)
|
||||||
for transition_idx, transition in enumerate(transitions):
|
for transition_idx, transition in enumerate(transitions):
|
||||||
self.store(transition)
|
self.store(transition)
|
||||||
|
|
||||||
# print progress
|
# print progress
|
||||||
if transition_idx % 100 == 0:
|
if transition_idx % 100 == 0:
|
||||||
percentage = int((100 * transition_idx) / num_transitions)
|
progress_bar.update(transition_idx)
|
||||||
sys.stdout.write("\rProgress: ({}/{})".format(transition_idx, num_transitions))
|
|
||||||
sys.stdout.write(' Time (sec): {}'.format(round(time.time() - start_time, 2)))
|
progress_bar.close()
|
||||||
sys.stdout.write(' {}%|{}{}| '.format(percentage, '#' * int(percentage / 10),
|
|
||||||
' ' * (10 - int(percentage / 10))))
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from multiprocessing import Manager
|
from multiprocessing import Manager
|
||||||
@@ -548,4 +549,24 @@ class ReaderWriterLock(object):
|
|||||||
def release_writing(self):
|
def release_writing(self):
|
||||||
self.num_readers_lock.acquire()
|
self.num_readers_lock.acquire()
|
||||||
self.num_readers -= 1
|
self.num_readers -= 1
|
||||||
self.num_readers_lock.release()
|
self.num_readers_lock.release()
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressBar(object):
|
||||||
|
def __init__(self, max_value):
|
||||||
|
self.start_time = time.time()
|
||||||
|
self.max_value = max_value
|
||||||
|
self.current_value = 0
|
||||||
|
|
||||||
|
def update(self, current_value):
|
||||||
|
self.current_value = current_value
|
||||||
|
percentage = int((100 * current_value) / self.max_value)
|
||||||
|
sys.stdout.write("\rProgress: ({}/{}) Time: {} sec {}%|{}{}| "
|
||||||
|
.format(current_value, self.max_value,
|
||||||
|
round(time.time() - self.start_time, 2),
|
||||||
|
percentage, '#' * int(percentage / 10),
|
||||||
|
' ' * (10 - int(percentage / 10))))
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
print("")
|
||||||
|
|||||||
Reference in New Issue
Block a user