# mirror of https://github.com/gryf/coach.git
# synced 2025-12-17 19:20:19 +01:00
# 574 lines, 18 KiB, Python
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import importlib
import importlib.util
import inspect
import json
import os
import signal
import sys
import threading
import time
from multiprocessing import Manager
from subprocess import Popen
from typing import List, Tuple

import numpy as np

killed_processes = []
|
|
|
|
eps = np.finfo(np.float32).eps
|
|
|
|
|
|
def lower_under_to_upper(s):
|
|
s = s.replace('_', ' ')
|
|
s = s.title()
|
|
return s.replace(' ', '')
|
|
|
|
|
|
def get_base_dir():
|
|
return os.path.dirname(os.path.realpath(__file__))
|
|
|
|
|
|
def list_all_presets():
|
|
presets_path = os.path.join(get_base_dir(), 'presets')
|
|
return [f.split('.')[0] for f in os.listdir(presets_path) if f.endswith('.py') and f != '__init__.py']
|
|
|
|
|
|
def list_all_classes_in_module(module):
|
|
return [k for k, v in inspect.getmembers(module, inspect.isclass) if v.__module__ == module.__name__]
|
|
|
|
|
|
def parse_bool(value):
|
|
return {'true': True, 'false': False}.get(value.strip().lower(), value)
|
|
|
|
|
|
def convert_to_ascii(data):
|
|
import collections
|
|
if isinstance(data, basestring):
|
|
return parse_bool(str(data))
|
|
elif isinstance(data, collections.Mapping):
|
|
return dict(map(convert_to_ascii, data.iteritems()))
|
|
elif isinstance(data, collections.Iterable):
|
|
return type(data)(map(convert_to_ascii, data))
|
|
else:
|
|
return data
|
|
|
|
|
|
def break_file_path(path):
|
|
base = os.path.splitext(os.path.basename(path))[0]
|
|
extension = os.path.splitext(os.path.basename(path))[1]
|
|
dir = os.path.dirname(path)
|
|
return dir, base, extension
|
|
|
|
|
|
def is_empty(str):
|
|
return str == 0 or len(str.replace("'", "").replace("\"", "")) == 0
|
|
|
|
|
|
def read_json(filename):
|
|
# read json file
|
|
with open(filename, 'r') as f:
|
|
dict = json.loads(f.read())
|
|
return dict
|
|
|
|
|
|
def write_json(filename, dict):
|
|
# read json file
|
|
with open(filename, 'w') as f:
|
|
f.write(json.dumps(dict, indent=4))
|
|
|
|
|
|
def path_is_valid_dir(path):
|
|
return os.path.isdir(path)
|
|
|
|
|
|
def remove_suffix(name, suffix_start):
|
|
for s in suffix_start:
|
|
split = name.find(s)
|
|
if split != -1:
|
|
name = name[:split]
|
|
return name
|
|
|
|
|
|
def parse_int(value):
|
|
import ast
|
|
try:
|
|
int_value = int(value)
|
|
return int_value if int_value == value else value
|
|
except:
|
|
pass
|
|
|
|
try:
|
|
return ast.literal_eval(value)
|
|
except:
|
|
return value
|
|
|
|
|
|
def set_gpu(gpu_id):
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
|
|
|
|
|
|
def set_cpu():
|
|
set_gpu("")
|
|
|
|
|
|
# dictionary to class
|
|
class DictToClass(object):
|
|
def __init__(self, data):
|
|
for name, value in data.iteritems():
|
|
setattr(self, name, self._wrap(value))
|
|
|
|
def _wrap(self, value):
|
|
if isinstance(value, (tuple, list, set, frozenset)):
|
|
return type(value)([self._wrap(v) for v in value])
|
|
else:
|
|
return DictToClass(value) if isinstance(value, dict) else value
|
|
|
|
|
|
# class to dictionary
|
|
def ClassToDict(x):
|
|
# return dict((key, getattr(x, key)) for key in dir(x) if key not in dir(x.__class__))
|
|
dictionary = x.__dict__
|
|
return {key: dictionary[key] for key in dictionary.keys() if not key.startswith('__')}
|
|
|
|
|
|
def cmd_line_run(result, run_cmd, id=-1):
|
|
p = Popen(run_cmd, shell=True, executable="/bin/bash")
|
|
while result[0] is None or result[0] == [None]:
|
|
if id in killed_processes:
|
|
p.kill()
|
|
result[0] = p.poll()
|
|
|
|
|
|
def threaded_cmd_line_run(run_cmd, id=-1):
|
|
runThread = []
|
|
result = [[None]]
|
|
try:
|
|
params = (result, run_cmd, id)
|
|
runThread = threading.Thread(name='runThread', target=cmd_line_run, args=params)
|
|
runThread.daemon = True
|
|
runThread.start()
|
|
except:
|
|
runThread.join()
|
|
return result
|
|
|
|
|
|
class Signal(object):
|
|
"""
|
|
Stores a stream of values and provides methods like get_mean and get_max
|
|
which returns the statistics about accumulated values.
|
|
"""
|
|
def __init__(self, name):
|
|
self.name = name
|
|
self.sample_count = 0
|
|
self.values = []
|
|
|
|
def reset(self):
|
|
self.sample_count = 0
|
|
self.values = []
|
|
|
|
def add_sample(self, sample):
|
|
"""
|
|
:param sample: either a single value or an array of values
|
|
"""
|
|
self.values.append(sample)
|
|
|
|
def _get_values(self):
|
|
if type(self.values[0]) == np.ndarray:
|
|
return np.concatenate(self.values)
|
|
else:
|
|
return self.values
|
|
|
|
def get_last_value(self):
|
|
if len(self.values) == 0:
|
|
return np.nan
|
|
else:
|
|
return self._get_values()[-1]
|
|
|
|
def get_mean(self):
|
|
if len(self.values) == 0:
|
|
return ''
|
|
return np.mean(self._get_values())
|
|
|
|
def get_max(self):
|
|
if len(self.values) == 0:
|
|
return ''
|
|
return np.max(self._get_values())
|
|
|
|
def get_min(self):
|
|
if len(self.values) == 0:
|
|
return ''
|
|
return np.min(self._get_values())
|
|
|
|
def get_stdev(self):
|
|
if len(self.values) == 0:
|
|
return ''
|
|
return np.std(self._get_values())
|
|
|
|
|
|
def force_list(var):
|
|
if isinstance(var, list):
|
|
return var
|
|
else:
|
|
return [var]
|
|
|
|
|
|
def squeeze_list(var):
|
|
if len(var) == 1:
|
|
return var[0]
|
|
else:
|
|
return var
|
|
|
|
|
|
# http://www.johndcook.com/blog/standard_deviation/
|
|
class RunningStat(object):
|
|
def __init__(self, shape):
|
|
self._shape = shape
|
|
self._num_samples = 0
|
|
self._mean = np.zeros(shape)
|
|
self._std = np.zeros(shape)
|
|
|
|
def reset(self):
|
|
self._num_samples = 0
|
|
self._mean = np.zeros(self._shape)
|
|
self._std = np.zeros(self._shape)
|
|
|
|
def push(self, sample):
|
|
sample = np.asarray(sample)
|
|
assert sample.shape == self._mean.shape, 'RunningStat input shape mismatch'
|
|
self._num_samples += 1
|
|
if self._num_samples == 1:
|
|
self._mean[...] = sample
|
|
else:
|
|
old_mean = self._mean.copy()
|
|
self._mean[...] = old_mean + (sample - old_mean) / self._num_samples
|
|
self._std[...] = self._std + (sample - old_mean) * (sample - self._mean)
|
|
|
|
@property
|
|
def n(self):
|
|
return self._num_samples
|
|
|
|
@property
|
|
def mean(self):
|
|
return self._mean
|
|
|
|
@property
|
|
def var(self):
|
|
return self._std / (self._num_samples - 1) if self._num_samples > 1 else np.square(self._mean)
|
|
|
|
@property
|
|
def std(self):
|
|
return np.sqrt(self.var)
|
|
|
|
@property
|
|
def shape(self):
|
|
return self._mean.shape
|
|
|
|
|
|
def get_open_port():
|
|
import socket
|
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
s.bind(("", 0))
|
|
s.listen(1)
|
|
port = s.getsockname()[1]
|
|
s.close()
|
|
return port
|
|
|
|
|
|
class timeout:
|
|
def __init__(self, seconds=1, error_message='Timeout'):
|
|
self.seconds = seconds
|
|
self.error_message = error_message
|
|
|
|
def _handle_timeout(self, signum, frame):
|
|
raise TimeoutError(self.error_message)
|
|
|
|
def __enter__(self):
|
|
signal.signal(signal.SIGALRM, self._handle_timeout)
|
|
signal.alarm(self.seconds)
|
|
|
|
def __exit__(self, type, value, traceback):
|
|
signal.alarm(0)
|
|
|
|
|
|
def switch_axes_order(observation, from_type='channels_first', to_type='channels_last'):
|
|
"""
|
|
transpose an observation axes from channels_first to channels_last or vice versa
|
|
:param observation: a numpy array
|
|
:param from_type: can be 'channels_first' or 'channels_last'
|
|
:param to_type: can be 'channels_first' or 'channels_last'
|
|
:return: a new observation with the requested axes order
|
|
"""
|
|
if from_type == to_type or len(observation.shape) == 1:
|
|
return observation
|
|
assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image'
|
|
assert type(observation) == np.ndarray, 'observation must be a numpy array'
|
|
if len(observation.shape) == 3:
|
|
if from_type == 'channels_first' and to_type == 'channels_last':
|
|
return np.transpose(observation, (1, 2, 0))
|
|
elif from_type == 'channels_last' and to_type == 'channels_first':
|
|
return np.transpose(observation, (2, 0, 1))
|
|
else:
|
|
return np.transpose(observation, (1, 0))
|
|
|
|
|
|
def stack_observation(curr_stack, observation, stack_size):
|
|
"""
|
|
Adds a new observation to an existing stack of observations from previous time-steps.
|
|
:param curr_stack: The current observations stack.
|
|
:param observation: The new observation
|
|
:param stack_size: The required stack size
|
|
:return: The updated observation stack
|
|
"""
|
|
|
|
if curr_stack == []:
|
|
# starting an episode
|
|
curr_stack = np.vstack(np.expand_dims([observation] * stack_size, 0))
|
|
curr_stack = switch_axes_order(curr_stack, from_type='channels_first', to_type='channels_last')
|
|
else:
|
|
curr_stack = np.append(curr_stack, np.expand_dims(np.squeeze(observation), axis=-1), axis=-1)
|
|
curr_stack = np.delete(curr_stack, 0, -1)
|
|
|
|
return curr_stack
|
|
|
|
|
|
def call_method_for_all(instances: List, method: str, args=[], kwargs={}) -> List:
|
|
"""
|
|
Calls the same function for all the class instances in the group
|
|
:param instances: a list of class instances to apply the method on
|
|
:param method: the name of the function to be called
|
|
:param args: the positional parameters of the method
|
|
:param kwargs: the named parameters of the method
|
|
:return: a list of the returns values for all the instances
|
|
"""
|
|
result = []
|
|
if not isinstance(args, list):
|
|
args = [args]
|
|
sub_methods = method.split('.') # we allow calling an internal method such as "as_level_manager.train"
|
|
for instance in instances:
|
|
sub_instance = instance
|
|
for sub_method in sub_methods:
|
|
if not hasattr(sub_instance, sub_method):
|
|
raise ValueError("The requested instance method {} does not exist for {}"
|
|
.format(sub_method, '.'.join([str(instance.__class__.__name__)] + sub_methods)))
|
|
sub_instance = getattr(sub_instance, sub_method)
|
|
result.append(sub_instance(*args, **kwargs))
|
|
return result
|
|
|
|
|
|
def set_member_values_for_all(instances: List, member: str, val) -> None:
|
|
"""
|
|
Calls the same function for all the class instances in the group
|
|
:param instances: a list of class instances to apply the method on
|
|
:param member: the name of the member to be changed
|
|
:param val: the new value to assign
|
|
:return: None
|
|
"""
|
|
for instance in instances:
|
|
if not hasattr(instance, member):
|
|
raise ValueError("The requested instance member does not exist")
|
|
setattr(instance, member, val)
|
|
|
|
|
|
def short_dynamic_import(module_path_and_attribute: str, ignore_module_case: bool=False):
|
|
"""
|
|
Import by "path:attribute"
|
|
:param module_path_and_attribute: a path to a python file (using dots to separate dirs), followed by a ":" and
|
|
an attribute name to import from the path
|
|
:return: the requested attribute
|
|
"""
|
|
if '/' in module_path_and_attribute:
|
|
"""
|
|
Imports a class from a module using the full path of the module. The path should be given as:
|
|
<full absolute module path with / including .py>:<class name to import>
|
|
And this will be the same as doing "from <full absolute module path> import <class name to import>"
|
|
"""
|
|
return dynamic_import_from_full_path(*module_path_and_attribute.split(':'),
|
|
ignore_module_case=ignore_module_case)
|
|
else:
|
|
"""
|
|
Imports a class from a module using the relative path of the module. The path should be given as:
|
|
<full absolute module path with . and not including .py>:<class name to import>
|
|
And this will be the same as doing "from <full relative module path> import <class name to import>"
|
|
"""
|
|
return dynamic_import(*module_path_and_attribute.split(':'),
|
|
ignore_module_case=ignore_module_case)
|
|
|
|
|
|
def dynamic_import(module_path: str, class_name: str, ignore_module_case: bool=False):
|
|
if ignore_module_case:
|
|
module_name = module_path.split(".")[-1]
|
|
available_modules = os.listdir(os.path.dirname(module_path.replace('.', '/')))
|
|
for module in available_modules:
|
|
curr_module_ext = module.split('.')[-1].lower()
|
|
curr_module_name = module.split('.')[0]
|
|
if curr_module_ext == "py" and curr_module_name.lower() == module_name.lower():
|
|
module_path = '.'.join(module_path.split(".")[:-1] + [curr_module_name])
|
|
module = importlib.import_module(module_path)
|
|
class_ref = getattr(module, class_name)
|
|
return class_ref
|
|
|
|
|
|
def dynamic_import_from_full_path(module_path: str, class_name: str, ignore_module_case: bool=False):
|
|
if ignore_module_case:
|
|
module_name = module_path.split("/")[-1]
|
|
available_modules = os.listdir(os.path.dirname(module_path))
|
|
for module in available_modules:
|
|
curr_module_ext = module.split('.')[-1].lower()
|
|
curr_module_name = module.split('.')[0]
|
|
if curr_module_ext == "py" and curr_module_name.lower() == module_name.lower():
|
|
module_path = '.'.join(module_path.split("/")[:-1] + [curr_module_name])
|
|
spec = importlib.util.spec_from_file_location("module", module_path)
|
|
module = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(module)
|
|
class_ref = getattr(module, class_name)
|
|
return class_ref
|
|
|
|
|
|
def dynamic_import_and_instantiate_module_from_params(module_parameters, path=None, positional_args=[],
|
|
extra_kwargs={}):
|
|
"""
|
|
A function dedicated for coach modules like memory, exploration policy, etc.
|
|
Given the module parameters, it imports it and instantiates it.
|
|
:param module_parameters:
|
|
:return:
|
|
"""
|
|
import inspect
|
|
if path is None:
|
|
path = module_parameters.path
|
|
module = short_dynamic_import(path)
|
|
args = set(inspect.getfullargspec(module).args).intersection(module_parameters.__dict__)
|
|
args = {k: module_parameters.__dict__[k] for k in args}
|
|
args = {**args, **extra_kwargs}
|
|
return short_dynamic_import(path)(*positional_args, **args)
|
|
|
|
|
|
def last_sample(state):
|
|
"""
|
|
given a batch of states, return the last sample of the batch with length 1
|
|
batch axis.
|
|
"""
|
|
return {
|
|
k: np.expand_dims(v[-1], 0)
|
|
for k, v in state.items()
|
|
}
|
|
|
|
|
|
def get_all_subclasses(cls):
|
|
if len(cls.__subclasses__()) == 0:
|
|
return []
|
|
ret = []
|
|
for drv in cls.__subclasses__():
|
|
ret.append(drv)
|
|
ret.extend(get_all_subclasses(drv))
|
|
|
|
return ret
|
|
|
|
|
|
class SharedMemoryScratchPad(object):
|
|
def __init__(self):
|
|
self.dict = {}
|
|
|
|
def add(self, key, value):
|
|
self.dict[key] = value
|
|
|
|
def get(self, key, timeout=30):
|
|
start_time = time.time()
|
|
timeout_passed = False
|
|
while key not in self.dict and not timeout_passed:
|
|
time.sleep(0.1)
|
|
timeout_passed = (time.time() - start_time) > timeout
|
|
|
|
if timeout_passed:
|
|
return None
|
|
return self.dict[key]
|
|
|
|
def internal_call(self, key, func, args: Tuple):
|
|
if type(args) != tuple:
|
|
args = (args,)
|
|
return getattr(self.dict[key], func)(*args)
|
|
|
|
|
|
class Timer(object):
|
|
def __init__(self, prefix):
|
|
self.prefix = prefix
|
|
|
|
def __enter__(self):
|
|
self.start = time.time()
|
|
|
|
def __exit__(self, type, value, traceback):
|
|
print(self.prefix, time.time() - self.start)
|
|
|
|
|
|
class ReaderWriterLock(object):
|
|
def __init__(self):
|
|
self.num_readers_lock = Manager().Lock()
|
|
self.writers_lock = Manager().Lock()
|
|
self.num_readers = 0
|
|
self.now_writing = False
|
|
|
|
def some_worker_is_reading(self):
|
|
return self.num_readers > 0
|
|
|
|
def some_worker_is_writing(self):
|
|
return self.now_writing is True
|
|
|
|
def lock_writing_and_reading(self):
|
|
self.writers_lock.acquire() # first things first - block all other writers
|
|
self.now_writing = True # block new readers who haven't started reading yet
|
|
while self.some_worker_is_reading(): # let existing readers finish their homework
|
|
time.sleep(0.05)
|
|
|
|
def release_writing_and_reading(self):
|
|
self.now_writing = False # release readers - guarantee no readers starvation
|
|
self.writers_lock.release() # release writers
|
|
|
|
def lock_writing(self):
|
|
while self.now_writing:
|
|
time.sleep(0.05)
|
|
|
|
self.num_readers_lock.acquire()
|
|
self.num_readers += 1
|
|
self.num_readers_lock.release()
|
|
|
|
def release_writing(self):
|
|
self.num_readers_lock.acquire()
|
|
self.num_readers -= 1
|
|
self.num_readers_lock.release()
|
|
|
|
|
|
class ProgressBar(object):
|
|
def __init__(self, max_value):
|
|
self.start_time = time.time()
|
|
self.max_value = max_value
|
|
self.current_value = 0
|
|
|
|
def update(self, current_value, additional_info=""):
|
|
self.current_value = current_value
|
|
percentage = int((100 * current_value) / self.max_value)
|
|
sys.stdout.write("\rProgress: ({}/{}) Time: {} sec {}%|{}{}| {}"
|
|
.format(current_value, self.max_value,
|
|
round(time.time() - self.start_time, 2),
|
|
percentage, '#' * int(percentage / 10),
|
|
' ' * (10 - int(percentage / 10)),
|
|
additional_info))
|
|
sys.stdout.flush()
|
|
|
|
def close(self):
|
|
print("")
|