mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Release 0.9
Main changes are detailed below.

New features:
* CARLA 0.7 simulator integration
* Human control of the game play
* Recording of human game play and storing / loading the replay buffer
* Behavioral cloning agent and presets
* Golden tests for several presets
* Selecting between deep / shallow image embedders
* Rendering through pygame (with some boost in performance)

API changes:
* Improved environment wrapper API
* Added an evaluate flag to allow convenient evaluation of existing checkpoints
* Improved the frameskip definition in Gym

Bug fixes:
* Fixed loading of checkpoints for agents with more than one network
* Fixed Python 3 compatibility of the N-Step Q-learning agent
This commit is contained in:
63
utils.py
63
utils.py
@@ -20,6 +20,7 @@ import os
|
||||
import numpy as np
|
||||
import threading
|
||||
from subprocess import call, Popen
|
||||
import signal
|
||||
|
||||
killed_processes = []
|
||||
|
||||
@@ -54,9 +55,9 @@ class Enum(object):
|
||||
|
||||
|
||||
class RunPhase(Enum):
    # Phases of a run, identified by human-readable strings.
    # The earlier numeric values (0 / 1 / 2) were immediately overwritten by
    # these assignments, so only the effective string values are kept.
    HEATUP = "Heatup"
    TRAIN = "Training"
    TEST = "Testing"
|
||||
|
||||
|
||||
def list_all_classes_in_module(module):
|
||||
@@ -292,3 +293,59 @@ def get_open_port():
|
||||
s.close()
|
||||
return port
|
||||
|
||||
|
||||
class timeout:
    """
    Context manager that interrupts the enclosed block with a TimeoutError
    if it runs longer than the given number of seconds.

    It works by arming a SIGALRM alarm on entry and disarming it on exit, so it
    can only be used where signal.SIGALRM is available and signal handlers may
    be installed (signal.signal must be called from the main thread).

    :param seconds: number of seconds passed to signal.alarm before the alarm fires
    :param error_message: message carried by the raised TimeoutError
    """

    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        # handler that was installed before entering the context, restored on exit
        self._previous_handler = None

    def _handle_timeout(self, signum, frame):
        """Signal handler: turn the alarm into a TimeoutError."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces; remember it so that
        # leaving the context does not leak our handler to unrelated code
        self._previous_handler = signal.signal(signal.SIGALRM, self._handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, type, value, traceback):
        # cancel any pending alarm first so it cannot fire while restoring
        signal.alarm(0)
        # the previous handler is None when it was not installed from Python;
        # signal.signal does not accept None, so leave the handler untouched then
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None
|
||||
|
||||
|
||||
def switch_axes_order(observation, from_type='channels_first', to_type='channels_last'):
    """
    transpose an observation axes from channels_first to channels_last or vice versa
    :param observation: a numpy array
    :param from_type: can be 'channels_first' or 'channels_last'
    :param to_type: can be 'channels_first' or 'channels_last'
    :return: a new observation with the requested axes order. The input itself is returned
             unchanged when from_type equals to_type or when the observation has a single axis
    :raises ValueError: if a 3-axes observation is given with a from_type / to_type pair that is
                        not one of the two supported conversions
    """
    # nothing to reorder: same layout requested, or a flat vector
    if from_type == to_type or len(observation.shape) == 1:
        return observation
    assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image'
    assert isinstance(observation, np.ndarray), 'observation must be a numpy array'

    # with two axes both directions are the same plain transpose
    if len(observation.shape) == 2:
        return np.transpose(observation, (1, 0))

    if from_type == 'channels_first' and to_type == 'channels_last':
        return np.transpose(observation, (1, 2, 0))
    if from_type == 'channels_last' and to_type == 'channels_first':
        return np.transpose(observation, (2, 0, 1))

    # previously this case fell through and silently returned None
    raise ValueError("unsupported axes order conversion: from_type={!r}, to_type={!r}".format(from_type, to_type))
|
||||
|
||||
|
||||
def stack_observation(curr_stack, observation, stack_size):
    """
    Adds a new observation to an existing stack of observations from previous time-steps.
    :param curr_stack: The current observations stack. An empty stack ([], None or an empty
                       array) marks the start of an episode
    :param observation: The new observation
    :param stack_size: The required stack size
    :return: The updated observation stack
    """

    # Test for emptiness explicitly: `curr_stack == []` turns into an element-wise comparison
    # once curr_stack is a numpy array, which does not yield a usable boolean
    if curr_stack is None or np.size(curr_stack) == 0:
        # starting an episode - fill the whole stack with copies of the first observation
        curr_stack = np.vstack(np.expand_dims([observation] * stack_size, 0))
        # the copies are stacked along the first axis; move that axis to the end
        curr_stack = switch_axes_order(curr_stack, from_type='channels_first', to_type='channels_last')
    else:
        # append the new observation as the newest entry along the last axis...
        curr_stack = np.append(curr_stack, np.expand_dims(np.squeeze(observation), axis=-1), axis=-1)
        # ...and drop the oldest entry so the stack keeps its size
        curr_stack = np.delete(curr_stack, 0, -1)

    return curr_stack
|
||||
|
||||
Reference in New Issue
Block a user