
Moved coach to its top-level module.

Roman Dobosz
2018-04-25 12:03:36 +02:00
parent 7e61bb5685
commit 676c69e391
76 changed files with 214 additions and 1437 deletions
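
The hunks below show the pattern of the change: the formerly flat top-level modules (agents, architectures, environments, exploration_policies, memories, presets, utils, logger, configurations) become subpackages and submodules of a single coach package, so every absolute import gains a coach. prefix while the classes themselves stay the same. A minimal sketch of what this implies for downstream code, assuming the coach package is importable (installed or on PYTHONPATH); the layout comment is inferred from the import changes in this commit, not copied from it:

# Inferred layout after the move (not an exhaustive listing):
#   coach/
#       __init__.py
#       utils.py
#       logger.py
#       configurations.py
#       agents/                  (DQNAgent, PPOAgent, ... re-exported in __init__.py)
#       architectures/
#       environments/
#       exploration_policies/
#       memories/
#
# Old, flat-module style imports:
#   import utils
#   import logger
#   from agents.dqn_agent import DQNAgent
#
# New, package-qualified imports:
from coach import utils
from coach import logger
from coach.agents.dqn_agent import DQNAgent

agent_cls = DQNAgent  # only the import paths move; the public classes are unchanged
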

View File

@@ -13,28 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from agents.actor_critic_agent import ActorCriticAgent
from agents.agent import Agent
from agents.bc_agent import BCAgent
from agents.bootstrapped_dqn_agent import BootstrappedDQNAgent
from agents.categorical_dqn_agent import CategoricalDQNAgent
from agents.clipped_ppo_agent import ClippedPPOAgent
from agents.ddpg_agent import DDPGAgent
from agents.ddqn_agent import DDQNAgent
from agents.dfp_agent import DFPAgent
from agents.dqn_agent import DQNAgent
from agents.human_agent import HumanAgent
from agents.imitation_agent import ImitationAgent
from agents.mmc_agent import MixedMonteCarloAgent
from agents.n_step_q_agent import NStepQAgent
from agents.naf_agent import NAFAgent
from agents.nec_agent import NECAgent
from agents.pal_agent import PALAgent
from agents.policy_gradients_agent import PolicyGradientsAgent
from agents.policy_optimization_agent import PolicyOptimizationAgent
from agents.ppo_agent import PPOAgent
from agents.qr_dqn_agent import QuantileRegressionDQNAgent
from agents.value_optimization_agent import ValueOptimizationAgent
from coach.agents.actor_critic_agent import ActorCriticAgent
from coach.agents.agent import Agent
from coach.agents.bc_agent import BCAgent
from coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgent
from coach.agents.categorical_dqn_agent import CategoricalDQNAgent
from coach.agents.clipped_ppo_agent import ClippedPPOAgent
from coach.agents.ddpg_agent import DDPGAgent
from coach.agents.ddqn_agent import DDQNAgent
from coach.agents.dfp_agent import DFPAgent
from coach.agents.dqn_agent import DQNAgent
from coach.agents.human_agent import HumanAgent
from coach.agents.imitation_agent import ImitationAgent
from coach.agents.mmc_agent import MixedMonteCarloAgent
from coach.agents.n_step_q_agent import NStepQAgent
from coach.agents.naf_agent import NAFAgent
from coach.agents.nec_agent import NECAgent
from coach.agents.pal_agent import PALAgent
from coach.agents.policy_gradients_agent import PolicyGradientsAgent
from coach.agents.policy_optimization_agent import PolicyOptimizationAgent
from coach.agents.ppo_agent import PPOAgent
from coach.agents.qr_dqn_agent import QuantileRegressionDQNAgent
from coach.agents.value_optimization_agent import ValueOptimizationAgent
__all__ = [ActorCriticAgent,
Agent,

View File

@@ -16,9 +16,9 @@
import numpy as np
from scipy import signal
from agents import policy_optimization_agent as poa
import utils
import logger
from coach.agents import policy_optimization_agent as poa
from coach import utils
from coach import logger
# Actor Critic - https://arxiv.org/abs/1602.01783

View File

@@ -18,7 +18,7 @@ import copy
import random
import time
import logger
from coach import logger
try:
import matplotlib.pyplot as plt
except ImportError:
@@ -29,13 +29,12 @@ from pandas.io import pickle
from six.moves import range
import scipy
from architectures.tensorflow_components import shared_variables as sv
import configurations
import exploration_policies as ep # noqa, used in eval()
import memories # noqa, used in eval()
from memories import memory
import renderer
import utils
from coach.architectures.tensorflow_components import shared_variables as sv
from coach import configurations
from coach import exploration_policies as ep # noqa, used in eval()
from coach import memories # noqa, used in eval()
from coach.memories import memory
from coach import utils
class Agent(object):
@@ -100,7 +99,6 @@ class Agent(object):
self.main_network = None
self.networks = []
self.last_episode_images = []
self.renderer = renderer.Renderer()
# signals
self.signals = []
@@ -234,13 +232,6 @@ class Agent(object):
r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
observation = 0.2989 * r + 0.5870 * g + 0.1140 * b
# Render the processed observation which is how the agent will see it
# Warning: this cannot currently be done in parallel to rendering the environment
if self.tp.visualization.render_observation:
if not self.renderer.is_open:
self.renderer.create_screen(observation.shape[0], observation.shape[1])
self.renderer.render_image(observation)
return observation.astype('uint8')
else:
if self.tp.env.normalize_observation:

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from agents import imitation_agent
from coach.agents import imitation_agent
# Behavioral Cloning Agent

View File

@@ -15,8 +15,9 @@
#
import numpy as np
from agents import value_optimization_agent as voa
import utils
from coach.agents import value_optimization_agent as voa
from coach import utils
# Bootstrapped DQN - https://arxiv.org/pdf/1602.04621.pdf
class BootstrappedDQNAgent(voa.ValueOptimizationAgent):

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from agents import value_optimization_agent as voa
from coach.agents import value_optimization_agent as voa
# Categorical Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf

View File

@@ -19,10 +19,10 @@ from random import shuffle
import numpy as np
from agents import actor_critic_agent as aca
from agents import policy_optimization_agent as poa
import logger
import utils
from coach.agents import actor_critic_agent as aca
from coach.agents import policy_optimization_agent as poa
from coach import logger
from coach import utils
# Clipped Proximal Policy Optimization - https://arxiv.org/abs/1707.06347

View File

@@ -17,11 +17,11 @@ import copy
import numpy as np
from agents import actor_critic_agent as aca
from agents import agent
from architectures import network_wrapper as nw
import configurations as conf
import utils
from coach.agents import actor_critic_agent as aca
from coach.agents import agent
from coach.architectures import network_wrapper as nw
from coach import configurations as conf
from coach import utils
# Deep Deterministic Policy Gradients Network - https://arxiv.org/pdf/1509.02971.pdf

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from agents import value_optimization_agent as voa
from coach.agents import value_optimization_agent as voa
# Double DQN - https://arxiv.org/abs/1509.06461

View File

@@ -15,9 +15,9 @@
#
import numpy as np
from agents import agent
from architectures import network_wrapper as nw
import utils
from coach.agents import agent
from coach.architectures import network_wrapper as nw
from coach import utils
# Direct Future Prediction Agent - http://vladlen.info/papers/learning-to-act.pdf

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from agents import value_optimization_agent as voa
from coach.agents import value_optimization_agent as voa
# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from agents import value_optimization_agent as voa
from coach.agents import value_optimization_agent as voa
# Deep Q Network - https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf

View File

@@ -19,9 +19,9 @@ import os
import pygame
from pandas.io import pickle
from agents import agent
import logger
import utils
from coach.agents import agent
from coach import logger
from coach import utils
class HumanAgent(agent.Agent):

View File

@@ -15,10 +15,10 @@
#
import collections
from agents import agent
from architectures import network_wrapper as nw
import utils
import logging
from coach.agents import agent
from coach.architectures import network_wrapper as nw
from coach import utils
from coach import logger
# Imitation Agent
@@ -55,7 +55,7 @@ class ImitationAgent(agent.Agent):
# log to screen
if phase == utils.RunPhase.TRAIN:
# for the training phase - we log during the episode to visualize the progress in training
logging.screen.log_dict(
logger.screen.log_dict(
collections.OrderedDict([
("Worker", self.task_id),
("Episode", self.current_episode),
@@ -65,5 +65,5 @@ class ImitationAgent(agent.Agent):
prefix="Training"
)
else:
# for the evaluation phase - logging as in regular RL
# for the evaluation phase - logger as in regular RL
agent.Agent.log_to_screen(self, phase)

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from agents import value_optimization_agent as voa
from coach.agents import value_optimization_agent as voa
class MixedMonteCarloAgent(voa.ValueOptimizationAgent):

View File

@@ -15,10 +15,10 @@
#
import numpy as np
from agents import value_optimization_agent as voa
from agents import policy_optimization_agent as poa
import logger
import utils
from coach.agents import value_optimization_agent as voa
from coach.agents import policy_optimization_agent as poa
from coach import logger
from coach import utils
# N Step Q Learning Agent - https://arxiv.org/abs/1602.01783

View File

@@ -15,8 +15,8 @@
#
import numpy as np
from agents.value_optimization_agent import ValueOptimizationAgent
import utils
from coach.agents.value_optimization_agent import ValueOptimizationAgent
from coach import utils
# Normalized Advantage Functions - https://arxiv.org/pdf/1603.00748.pdf

View File

@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from agents import value_optimization_agent as voa
from logger import screen
import utils
from coach.agents import value_optimization_agent as voa
from coach.logger import screen
from coach import utils
# Neural Episodic Control - https://arxiv.org/pdf/1703.01988.pdf

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from agents import value_optimization_agent as voa
from coach.agents import value_optimization_agent as voa
# Persistent Advantage Learning - https://arxiv.org/pdf/1512.04860.pdf

View File

@@ -15,9 +15,9 @@
#
import numpy as np
from agents import policy_optimization_agent as poa
import logger
import utils
from coach.agents import policy_optimization_agent as poa
from coach import logger
from coach import utils
class PolicyGradientsAgent(poa.PolicyOptimizationAgent):

View File

@@ -17,10 +17,10 @@ import collections
import numpy as np
from agents import agent
from architectures import network_wrapper as nw
import logger
import utils
from coach.agents import agent
from coach.architectures import network_wrapper as nw
from coach import logger
from coach import utils
class PolicyGradientRescaler(utils.Enum):

View File

@@ -18,12 +18,12 @@ import copy
import numpy as np
from agents import actor_critic_agent as aca
from agents import policy_optimization_agent as poa
from architectures import network_wrapper as nw
import configurations
import logger
import utils
from coach.agents import actor_critic_agent as aca
from coach.agents import policy_optimization_agent as poa
from coach.architectures import network_wrapper as nw
from coach import configurations
from coach import logger
from coach import utils
# Proximal Policy Optimization - https://arxiv.org/pdf/1707.06347.pdf

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from agents import value_optimization_agent as voa
from coach.agents import value_optimization_agent as voa
# Quantile Regression Deep Q Network - https://arxiv.org/pdf/1710.10044v1.pdf

View File

@@ -15,9 +15,9 @@
#
import numpy as np
from agents import agent
from architectures import network_wrapper as nw
import utils
from coach.agents import agent
from coach.architectures import network_wrapper as nw
from coach import utils
class ValueOptimizationAgent(agent.Agent):

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logger
from coach import logger
try:
from architectures.tensorflow_components import general_network as ts_gn

View File

@@ -16,8 +16,8 @@
import ngraph as ng
import numpy as np
from architectures import architecture
import utils
from coach.architectures import architecture
from coach import utils
class NeonArchitecture(architecture.Architecture):

View File

@@ -17,11 +17,11 @@ import ngraph as ng
from ngraph.frontends import neon
from ngraph.util import names as ngraph_names
from architectures.neon_components import architecture
from architectures.neon_components import embedders
from architectures.neon_components import middleware
from architectures.neon_components import heads
import configurations as conf
from coach.architectures.neon_components import architecture
from coach.architectures.neon_components import embedders
from coach.architectures.neon_components import middleware
from coach.architectures.neon_components import heads
from coach import configurations as conf
class GeneralNeonNetwork(architecture.NeonArchitecture):

View File

@@ -17,8 +17,8 @@ import ngraph as ng
from ngraph.frontends import neon
from ngraph.util import names as ngraph_names
import utils
from architectures.neon_components import losses
from coach import utils
from coach.architectures.neon_components import losses
class Head(object):

View File

@@ -16,16 +16,16 @@
import os
import collections
import configurations as conf
import logger
from coach import configurations as conf
from coach import logger
try:
import tensorflow as tf
from architectures.tensorflow_components import general_network as tf_net #import GeneralTensorFlowNetwork
from coach.architectures.tensorflow_components import general_network as tf_net
except ImportError:
logger.failed_imports.append("TensorFlow")
try:
from architectures.neon_components import general_network as neon_net
from coach.architectures.neon_components import general_network as neon_net
except ImportError:
logger.failed_imports.append("Neon")

View File

@@ -17,9 +17,10 @@ import time
import tensorflow as tf
from architectures import architecture
import configurations as conf
import utils
from coach.architectures import architecture
from coach import configurations as conf
from coach import utils
def variable_summaries(var):
"""Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
@@ -36,6 +37,7 @@ def variable_summaries(var):
tf.summary.scalar('min', tf.reduce_min(var))
tf.summary.histogram('histogram', var)
class TensorFlowArchitecture(architecture.Architecture):
def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
"""

View File

@@ -15,7 +15,7 @@
#
import tensorflow as tf
from configurations import EmbedderComplexity
from coach.configurations import EmbedderComplexity
class InputEmbedder(object):

View File

@@ -15,11 +15,11 @@
#
import tensorflow as tf
from architectures.tensorflow_components import architecture
from architectures.tensorflow_components import embedders
from architectures.tensorflow_components import middleware
from architectures.tensorflow_components import heads
import configurations as conf
from coach.architectures.tensorflow_components import architecture
from coach.architectures.tensorflow_components import embedders
from coach.architectures.tensorflow_components import middleware
from coach.architectures.tensorflow_components import heads
from coach import configurations as conf
class GeneralTensorFlowNetwork(architecture.TensorFlowArchitecture):

View File

@@ -16,7 +16,7 @@
import tensorflow as tf
import numpy as np
import utils
from coach import utils
# Used to initialize weights for policy and value output layers

View File

@@ -14,9 +14,9 @@
# limitations under the License.
#
import json
import types
import utils
from coach import utils
class Frameworks(utils.Enum):

View File

@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from environments.gym_environment_wrapper import GymEnvironmentWrapper
from environments.doom_environment_wrapper import DoomEnvironmentWrapper
from environments.carla_environment_wrapper import CarlaEnvironmentWrapper
import utils
from coach.environments.gym_environment_wrapper import GymEnvironmentWrapper
from coach.environments.doom_environment_wrapper import DoomEnvironmentWrapper
from coach.environments.carla_environment_wrapper import CarlaEnvironmentWrapper
from coach import utils
class EnvTypes(utils.Enum):

View File

@@ -6,7 +6,7 @@ import sys
import numpy as np
import logger
from coach import logger
try:
if 'CARLA_ROOT' in os.environ:
sys.path.append(os.path.join(os.environ.get('CARLA_ROOT'),
@@ -16,8 +16,8 @@ try:
from carla import sensor as carla_sensor
except ImportError:
logger.failed_imports.append("CARLA")
from environments import environment_wrapper as ew
import utils
from coach.environments import environment_wrapper as ew
from coach import utils
# enum of the available levels and their path

View File

@@ -13,19 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import enum
import os
import numpy as np
import logger
from coach import logger
try:
import vizdoom
except ImportError:
logger.failed_imports.append("ViZDoom")
from environments import environment_wrapper as ew
import utils
from coach.environments import environment_wrapper as ew
from coach import utils
# enum of the available levels and their path

View File

@@ -18,8 +18,7 @@ import time
import numpy as np
import renderer
import utils
from coach import utils
class EnvironmentWrapper(object):
@@ -62,7 +61,6 @@ class EnvironmentWrapper(object):
self.wait_for_explicit_human_action = False
self.is_rendered = self.is_rendered or self.human_control
self.game_is_open = True
self.renderer = renderer.Renderer()
@property
def measurements(self):
@@ -106,26 +104,6 @@ class EnvironmentWrapper(object):
Get an action from the user keyboard
:return: action index
"""
if self.wait_for_explicit_human_action:
while len(self.renderer.pressed_keys) == 0:
self.renderer.get_events()
if self.key_to_action == {}:
# the keys are the numbers on the keyboard corresponding to the action index
if len(self.renderer.pressed_keys) > 0:
action_idx = self.renderer.pressed_keys[0] - ord("1")
if 0 <= action_idx < self.action_space_size:
return action_idx
else:
# the keys are mapped through the environment to more intuitive keyboard keys
# key = tuple(self.renderer.pressed_keys)
# for key in self.renderer.pressed_keys:
for env_keys in self.key_to_action.keys():
if set(env_keys) == set(self.renderer.pressed_keys):
return self.key_to_action[env_keys]
# return the default action 0 so that the environment will continue running
return self.default_action
def step(self, action_idx):
"""

View File

@@ -18,8 +18,8 @@ import random
import gym
import numpy as np
from environments import environment_wrapper as ew
import utils
from coach.environments import environment_wrapper as ew
from coach import utils
class GymEnvironmentWrapper(ew.EnvironmentWrapper):

View File

@@ -13,18 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from exploration_policies.additive_noise import AdditiveNoise
from exploration_policies.approximated_thompson_sampling_using_dropout import ApproximatedThompsonSamplingUsingDropout
from exploration_policies.bayesian import Bayesian
from exploration_policies.boltzmann import Boltzmann
from exploration_policies.bootstrapped import Bootstrapped
from exploration_policies.categorical import Categorical
from exploration_policies.continuous_entropy import ContinuousEntropy
from exploration_policies.e_greedy import EGreedy
from exploration_policies.exploration_policy import ExplorationPolicy
from exploration_policies.greedy import Greedy
from exploration_policies.ou_process import OUProcess
from exploration_policies.thompson_sampling import ThompsonSampling
from coach.exploration_policies.additive_noise import AdditiveNoise
from coach.exploration_policies.approximated_thompson_sampling_using_dropout import ApproximatedThompsonSamplingUsingDropout
from coach.exploration_policies.bayesian import Bayesian
from coach.exploration_policies.boltzmann import Boltzmann
from coach.exploration_policies.bootstrapped import Bootstrapped
from coach.exploration_policies.categorical import Categorical
from coach.exploration_policies.continuous_entropy import ContinuousEntropy
from coach.exploration_policies.e_greedy import EGreedy
from coach.exploration_policies.exploration_policy import ExplorationPolicy
from coach.exploration_policies.greedy import Greedy
from coach.exploration_policies.ou_process import OUProcess
from coach.exploration_policies.thompson_sampling import ThompsonSampling
__all__ = [AdditiveNoise,

View File

@@ -15,8 +15,8 @@
#
import numpy as np
from exploration_policies import exploration_policy
import utils
from coach.exploration_policies import exploration_policy
from coach import utils
class AdditiveNoise(exploration_policy.ExplorationPolicy):

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from exploration_policies import exploration_policy
from coach.exploration_policies import exploration_policy
class ApproximatedThompsonSamplingUsingDropout(exploration_policy.ExplorationPolicy):

View File

@@ -15,8 +15,8 @@
#
import numpy as np
from exploration_policies import exploration_policy
import utils
from coach.exploration_policies import exploration_policy
from coach import utils
class Bayesian(exploration_policy.ExplorationPolicy):

View File

@@ -15,8 +15,9 @@
#
import numpy as np
from exploration_policies import exploration_policy
import utils
from coach.exploration_policies import exploration_policy
from coach import utils
class Boltzmann(exploration_policy.ExplorationPolicy):
def __init__(self, tuning_parameters):

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from exploration_policies import e_greedy
from coach.exploration_policies import e_greedy
class Bootstrapped(e_greedy.EGreedy):

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from exploration_policies import exploration_policy
from coach.exploration_policies import exploration_policy
class Categorical(exploration_policy.ExplorationPolicy):

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from exploration_policies import exploration_policy
from coach.exploration_policies import exploration_policy
class ContinuousEntropy(exploration_policy.ExplorationPolicy):

View File

@@ -15,8 +15,8 @@
#
import numpy as np
from exploration_policies import exploration_policy
import utils
from coach.exploration_policies import exploration_policy
from coach import utils
class EGreedy(exploration_policy.ExplorationPolicy):

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import utils
from coach import utils
class ExplorationPolicy(object):

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from exploration_policies import exploration_policy
from coach.exploration_policies import exploration_policy
class Greedy(exploration_policy.ExplorationPolicy):

View File

@@ -15,11 +15,12 @@
#
import numpy as np
from exploration_policies import exploration_policy
from coach.exploration_policies import exploration_policy
# Based on on the description in:
# https://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
# Ornstein-Uhlenbeck process
class OUProcess(exploration_policy.ExplorationPolicy):
def __init__(self, tuning_parameters):

View File

@@ -15,7 +15,7 @@
#
import numpy as np
from exploration_policies import exploration_policy
from coach.exploration_policies import exploration_policy
class ThompsonSampling(exploration_policy.ExplorationPolicy):

View File

@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from memories.differentiable_neural_dictionary import AnnoyDictionary
from memories.differentiable_neural_dictionary import AnnoyIndex
from memories.differentiable_neural_dictionary import QDND
from memories.episodic_experience_replay import EpisodicExperienceReplay
from memories.memory import Episode
from memories.memory import Memory
from memories.memory import Transition
from coach.memories.differentiable_neural_dictionary import AnnoyDictionary
from coach.memories.differentiable_neural_dictionary import AnnoyIndex
from coach.memories.differentiable_neural_dictionary import QDND
from coach.memories.episodic_experience_replay import EpisodicExperienceReplay
from coach.memories.memory import Episode
from coach.memories.memory import Memory
from coach.memories.memory import Transition
__all__ = [AnnoyDictionary,
AnnoyIndex,

View File

@@ -17,7 +17,7 @@ import typing
import numpy as np
from memories import memory
from coach.memories import memory
class EpisodicExperienceReplay(memory.Memory):

View File

@@ -17,11 +17,11 @@ import ast
import json
import sys
import agents
import configurations as conf
import environments as env
import exploration_policies as ep
import presets
from coach import agents
from coach import configurations as conf
from coach import environments as env
from coach import exploration_policies as ep
from coach import presets
def json_to_preset(json_path):

View File

@@ -1,972 +0,0 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
To run Coach Dashboard, run the following command:
python3 dashboard.py
"""
import colorsys
import datetime
import enum
import itertools
import os
import random
from bokeh import palettes
from bokeh import layouts as bl
from bokeh import models as bm
from bokeh.models import widgets as bw
from bokeh import plotting as bp
import numpy as np
import pandas as pd
from pandas.io import pandas_common
import wx
import utils
class DialogApp(wx.App):
def getFileDialog(self):
with wx.FileDialog(None, "Open CSV file", wildcard="CSV files (*.csv)|*.csv",
style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_CHANGE_DIR | wx.FD_MULTIPLE) as fileDialog:
if fileDialog.ShowModal() == wx.ID_CANCEL:
return None # the user changed their mind
else:
# Proceed loading the file chosen by the user
return fileDialog.GetPaths()
def getDirDialog(self):
with wx.DirDialog (None, "Choose input directory", "",
style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_CHANGE_DIR) as dirDialog:
if dirDialog.ShowModal() == wx.ID_CANCEL:
return None # the user changed their mind
else:
# Proceed loading the dir chosen by the user
return dirDialog.GetPath()
class Signal:
def __init__(self, name, parent):
self.name = name
self.full_name = "{}/{}".format(parent.filename, self.name)
self.selected = False
self.color = random.choice(palettes.Dark2[8])
self.line = None
self.bands = None
self.bokeh_source = parent.bokeh_source
self.min_val = 0
self.max_val = 0
self.axis = 'default'
self.sub_signals = []
for name in self.bokeh_source.data.keys():
if (len(name.split('/')) == 1 and name == self.name) or '/'.join(name.split('/')[:-1]) == self.name:
self.sub_signals.append(name)
if len(self.sub_signals) > 1:
self.mean_signal = utils.squeeze_list([name for name in self.sub_signals if 'Mean' in name.split('/')[-1]])
self.stdev_signal = utils.squeeze_list([name for name in self.sub_signals if 'Stdev' in name.split('/')[-1]])
self.min_signal = utils.squeeze_list([name for name in self.sub_signals if 'Min' in name.split('/')[-1]])
self.max_signal = utils.squeeze_list([name for name in self.sub_signals if 'Max' in name.split('/')[-1]])
else:
self.mean_signal = utils.squeeze_list(self.name)
self.stdev_signal = None
self.min_signal = None
self.max_signal = None
self.has_bollinger_bands = False
if self.mean_signal and self.stdev_signal and self.min_signal and self.max_signal:
self.has_bollinger_bands = True
self.show_bollinger_bands = False
self.bollinger_bands_source = None
self.update_range()
def set_color(self, color):
self.color = color
if self.line:
self.line.glyph.line_color = color
if self.bands:
self.bands.glyph.fill_color = color
def set_selected(self, val):
global current_color
if self.selected != val:
self.selected = val
if self.line:
# self.set_color(palettes.Dark2[8][current_color])
# current_color = (current_color + 1) % len(palettes.Dark2[8])
self.line.visible = self.selected
if self.bands:
self.bands.visible = self.selected and self.show_bollinger_bands
elif self.selected:
# lazy plotting - plot only when selected for the first time
show_spinner()
self.set_color(palettes.Dark2[8][current_color])
current_color = (current_color + 1) % len(palettes.Dark2[8])
if self.has_bollinger_bands:
self.set_bands_source()
self.create_bands()
self.line = plot.line('index', self.mean_signal, source=self.bokeh_source,
line_color=self.color, line_width=2)
self.line.visible = True
hide_spinner()
def set_dash(self, dash):
self.line.glyph.line_dash = dash
def create_bands(self):
self.bands = plot.patch(x='band_x', y='band_y', source=self.bollinger_bands_source,
color=self.color, fill_alpha=0.4, alpha=0.1, line_width=0)
self.bands.visible = self.show_bollinger_bands
# self.min_line = plot.line('index', self.min_signal, source=self.bokeh_source,
# line_color=self.color, line_width=3, line_dash="4 4")
# self.max_line = plot.line('index', self.max_signal, source=self.bokeh_source,
# line_color=self.color, line_width=3, line_dash="4 4")
# self.min_line.visible = self.show_bollinger_bands
# self.max_line.visible = self.show_bollinger_bands
def set_bands_source(self):
x_ticks = self.bokeh_source.data['index']
mean_values = self.bokeh_source.data[self.mean_signal]
stdev_values = self.bokeh_source.data[self.stdev_signal]
band_x = np.append(x_ticks, x_ticks[::-1])
band_y = np.append(mean_values - stdev_values, mean_values[::-1] + stdev_values[::-1])
source_data = {'band_x': band_x, 'band_y': band_y}
if self.bollinger_bands_source:
self.bollinger_bands_source.data = source_data
else:
self.bollinger_bands_source = bm.ColumnDataSource(source_data)
def change_bollinger_bands_state(self, new_state):
self.show_bollinger_bands = new_state
if self.bands and self.selected:
self.bands.visible = new_state
# self.min_line.visible = new_state
# self.max_line.visible = new_state
def update_range(self):
self.min_val = np.min(self.bokeh_source.data[self.mean_signal])
self.max_val = np.max(self.bokeh_source.data[self.mean_signal])
def set_axis(self, axis):
self.axis = axis
self.line.y_range_name = axis
def toggle_axis(self):
if self.axis == 'default':
self.set_axis('secondary')
else:
self.set_axis('default')
class SignalsFileBase:
def __init__(self):
self.full_csv_path = ""
self.dir = ""
self.filename = ""
self.signals_averaging_window = 1
self.show_bollinger_bands = False
self.csv = None
self.bokeh_source = None
self.bokeh_source_orig = None
self.last_modified = None
self.signals = {}
self.separate_files = False
def load_csv(self):
pass
def update_source_and_signals(self):
# create bokeh data sources
self.bokeh_source_orig = bm.ColumnDataSource(self.csv)
self.bokeh_source_orig.data['index'] = self.bokeh_source_orig.data[x_axis]
if self.bokeh_source is None:
self.bokeh_source = bm.ColumnDataSource(self.csv)
else:
# self.bokeh_source.data = self.bokeh_source_orig.data
# smooth the data if necessary
self.change_averaging_window(self.signals_averaging_window, force=True)
# create all the signals
if len(self.signals.keys()) == 0:
self.signals = {}
unique_signal_names = []
for name in self.csv.columns:
if len(name.split('/')) == 1:
unique_signal_names.append(name)
else:
unique_signal_names.append('/'.join(name.split('/')[:-1]))
unique_signal_names = list(set(unique_signal_names))
for signal_name in unique_signal_names:
self.signals[signal_name] = Signal(signal_name, self)
def load(self):
self.load_csv()
self.update_source_and_signals()
def reload_data(self, signals):
# this function is a workaround to reload the data of all the signals
# if the data doesn't change, bokeh does not refreshes the line
self.change_averaging_window(self.signals_averaging_window + 1, force=True)
self.change_averaging_window(self.signals_averaging_window - 1, force=True)
def change_averaging_window(self, new_size, force=False, signals=None):
if force or self.signals_averaging_window != new_size:
self.signals_averaging_window = new_size
win = np.ones(new_size) / new_size
temp_data = self.bokeh_source_orig.data.copy()
for col in self.bokeh_source.data.keys():
if col == 'index' or col in x_axis_options \
or (signals and not any(col in signal for signal in signals)):
temp_data[col] = temp_data[col][:-new_size]
continue
temp_data[col] = np.convolve(self.bokeh_source_orig.data[col], win, mode='same')[:-new_size]
self.bokeh_source.data = temp_data
# smooth bollinger bands
for signal in self.signals.values():
if signal.has_bollinger_bands:
signal.set_bands_source()
def hide_all_signals(self):
for signal_name in self.signals.keys():
self.set_signal_selection(signal_name, False)
def set_signal_selection(self, signal_name, val):
self.signals[signal_name].set_selected(val)
def change_bollinger_bands_state(self, new_state):
self.show_bollinger_bands = new_state
for signal in self.signals.values():
signal.change_bollinger_bands_state(new_state)
def file_was_modified_on_disk(self):
pass
def get_range_of_selected_signals_on_axis(self, axis, selected_signal=None):
max_val = -float('inf')
min_val = float('inf')
for signal in self.signals.values():
if (selected_signal and signal.name == selected_signal) or (signal.selected and signal.axis == axis):
max_val = max(max_val, signal.max_val)
min_val = min(min_val, signal.min_val)
return min_val, max_val
def get_selected_signals(self):
signals = []
for signal in self.signals.values():
if signal.selected:
signals.append(signal)
return signals
def show_files_separately(self, val):
pass
class SignalsFile(SignalsFileBase):
def __init__(self, csv_path, load=True):
SignalsFileBase.__init__(self)
self.full_csv_path = csv_path
self.dir, self.filename, _ = utils.break_file_path(csv_path)
if load:
self.load()
# this helps set the correct x axis
self.change_averaging_window(1, force=True)
def load_csv(self):
# load csv and fix sparse data.
# csv can be in the middle of being written so we use try - except
self.csv = None
while self.csv is None:
try:
self.csv = pd.read_csv(self.full_csv_path)
break
except pandas_common.EmptyDataError:
self.csv = None
continue
self.csv = self.csv.interpolate()
self.csv.fillna(value=0, inplace=True)
self.last_modified = os.path.getmtime(self.full_csv_path)
def file_was_modified_on_disk(self):
return self.last_modified != os.path.getmtime(self.full_csv_path)
class SignalsFilesGroup(SignalsFileBase):
def __init__(self, csv_paths):
SignalsFileBase.__init__(self)
self.full_csv_paths = csv_paths
self.signals_files = []
if len(csv_paths) == 1 and os.path.isdir(csv_paths[0]):
self.signals_files = [SignalsFile(str(file), load=False) for file in add_directory_csv_files(csv_paths[0])]
else:
for csv_path in csv_paths:
if os.path.isdir(csv_path):
self.signals_files.append(SignalsFilesGroup(add_directory_csv_files(csv_path)))
else:
self.signals_files.append(SignalsFile(str(csv_path), load=False))
if len(csv_paths) == 1:
# get the parent directory name (since the current directory is the timestamp directory)
self.dir = os.path.abspath(os.path.join(os.path.dirname(csv_paths[0]), '..'))
else:
# get the common directory for all the experiments
self.dir = os.path.dirname(os.path.commonprefix(csv_paths))
self.filename = '{} - Group({})'.format(os.path.basename(self.dir), len(self.signals_files))
self.load()
# this helps set the correct x axis
self.change_averaging_window(1, force=True)
def load_csv(self):
corrupted_files_idx = []
for idx, signal_file in enumerate(self.signals_files):
signal_file.load_csv()
if not all(option in signal_file.csv.keys() for option in x_axis_options):
print("Warning: {} file seems to be corrupted and does contain the necessary columns "
"and will not be rendered".format(signal_file.filename))
corrupted_files_idx.append(idx)
for file_idx in corrupted_files_idx:
del self.signals_files[file_idx]
# get the stats of all the columns
csv_group = pd.concat([signals_file.csv for signals_file in self.signals_files])
columns_to_remove = [s for s in csv_group.columns if '/Stdev' in s] + \
[s for s in csv_group.columns if '/Min' in s] + \
[s for s in csv_group.columns if '/Max' in s]
for col in columns_to_remove:
del csv_group[col]
csv_group = csv_group.groupby(csv_group.index)
self.csv_mean = csv_group.mean()
self.csv_mean.columns = [s + '/Mean' for s in self.csv_mean.columns]
self.csv_stdev = csv_group.std()
self.csv_stdev.columns = [s + '/Stdev' for s in self.csv_stdev.columns]
self.csv_min = csv_group.min()
self.csv_min.columns = [s + '/Min' for s in self.csv_min.columns]
self.csv_max = csv_group.max()
self.csv_max.columns = [s + '/Max' for s in self.csv_max.columns]
# get the indices from the file with the least number of indices and which is not an evaluation worker
file_with_min_indices = self.signals_files[0]
for signals_file in self.signals_files:
if signals_file.csv.shape[0] < file_with_min_indices.csv.shape[0] and \
'Training reward' in signals_file.csv.keys():
file_with_min_indices = signals_file
self.index_columns = file_with_min_indices.csv[x_axis_options]
# concat the stats and the indices columns
num_rows = file_with_min_indices.csv.shape[0]
self.csv = pd.concat([self.index_columns, self.csv_mean.head(num_rows), self.csv_stdev.head(num_rows),
self.csv_min.head(num_rows), self.csv_max.head(num_rows)], axis=1)
# remove the stat columns for the indices columns
columns_to_remove = [s + '/Mean' for s in x_axis_options] + \
[s + '/Stdev' for s in x_axis_options] + \
[s + '/Min' for s in x_axis_options] + \
[s + '/Max' for s in x_axis_options]
for col in columns_to_remove:
del self.csv[col]
# remove NaNs
self.csv.fillna(value=0, inplace=True) # removing this line will make bollinger bands fail
for key in self.csv.keys():
if 'Stdev' in key and 'Evaluation' not in key:
self.csv[key] = self.csv[key].fillna(value=0)
for signal_file in self.signals_files:
signal_file.update_source_and_signals()
def change_averaging_window(self, new_size, force=False, signals=None):
for signal_file in self.signals_files:
signal_file.change_averaging_window(new_size, force, signals)
SignalsFileBase.change_averaging_window(self, new_size, force, signals)
def set_signal_selection(self, signal_name, val):
self.show_files_separately(self.separate_files)
SignalsFileBase.set_signal_selection(self, signal_name, val)
def file_was_modified_on_disk(self):
for signal_file in self.signals_files:
if signal_file.file_was_modified_on_disk():
return True
return False
def show_files_separately(self, val):
self.separate_files = val
for signal in self.signals.values():
if signal.selected:
if val:
signal.set_dash("4 4")
else:
signal.set_dash("")
for signal_file in self.signals_files:
try:
if val:
signal_file.set_signal_selection(signal.name, signal.selected)
else:
signal_file.set_signal_selection(signal.name, False)
except:
pass
class RunType(enum.Enum):
SINGLE_FOLDER_SINGLE_FILE = 1
SINGLE_FOLDER_MULTIPLE_FILES = 2
MULTIPLE_FOLDERS_SINGLE_FILES = 3
MULTIPLE_FOLDERS_MULTIPLE_FILES = 4
UNKNOWN = 0
class FolderType(enum.Enum):
SINGLE_FILE = 1
MULTIPLE_FILES = 2
MULTIPLE_FOLDERS = 3
EMPTY = 4
dialog = DialogApp()
# read data
patches = {}
signals_files = {}
selected_file = None
x_axis = 'Episode #'
x_axis_options = ['Episode #', 'Total steps', 'Wall-Clock Time']
current_color = 0
# spinner
root_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(root_dir, 'spinner.css'), 'r') as f:
spinner_style = """<style>{}</style>""".format(f.read())
spinner_html = """<ul class="spinner"><li></li><li></li><li></li><li></li></ul>"""
spinner = bw.Div(text="""""")
# file refresh time placeholder
refresh_info = bw.Div(text="""""", width=210)
# create figures
plot = bp.figure(plot_width=1200, plot_height=800,
tools='pan,box_zoom,wheel_zoom,crosshair,undo,redo,reset,save',
toolbar_location='above', x_axis_label='Episodes',
x_range=bm.Range1d(0, 10000), y_range=bm.Range1d(0, 100000))
plot.extra_y_ranges = {"secondary": bm.Range1d(start=-100, end=200)}
plot.add_layout(bm.LinearAxis(y_range_name="secondary"), 'right')
# legend
div = bw.Div(text="""""")
legend = bl.widgetbox([div])
bokeh_legend = bm.Legend(
items=[("12345678901234567890123456789012345678901234567890", [])], # 50 letters
# items=[(" ", [])], # 50 letters
location=(-20, 0), orientation="vertical",
border_line_color="black",
label_text_font_size={'value': '9pt'},
margin=30
)
plot.add_layout(bokeh_legend, "right")
def update_axis_range(name, range_placeholder):
max_val = -float('inf')
min_val = float('inf')
selected_signal = None
if name in x_axis_options:
selected_signal = name
for signals_file in signals_files.values():
curr_min_val, curr_max_val = signals_file.get_range_of_selected_signals_on_axis(name, selected_signal)
max_val = max(max_val, curr_max_val)
min_val = min(min_val, curr_min_val)
if min_val != float('inf'):
range = max_val - min_val
range_placeholder.start = min_val - 0.1 * range
range_placeholder.end = max_val + 0.1 * range
# update axes ranges
def update_ranges():
update_axis_range('default', plot.y_range)
update_axis_range('secondary', plot.extra_y_ranges['secondary'])
def get_all_selected_signals():
signals = []
for signals_file in signals_files.values():
signals += signals_file.get_selected_signals()
return signals
# update legend using the legend text dictionary
def update_legend():
legend_text = """<div></div>"""
selected_signals = get_all_selected_signals()
items = []
for signal in selected_signals:
side_sign = "<" if signal.axis == 'default' else ">"
legend_text += """<div style='color: {}'><b>{} {}</b></div>"""\
.format(signal.color, side_sign, signal.full_name)
items.append((signal.full_name, [signal.line]))
div.text = legend_text
# the visible=false => visible=true is a hack to make the legend render again
bokeh_legend.visible = False
bokeh_legend.items = items
bokeh_legend.visible = True
# select lines to display
def select_data(args, old, new):
if selected_file is None:
return
show_spinner()
selected_signals = new
for signal_name in selected_file.signals.keys():
is_selected = signal_name in selected_signals
selected_file.set_signal_selection(signal_name, is_selected)
# update axes ranges
update_ranges()
update_axis_range(x_axis, plot.x_range)
# update the legend
update_legend()
hide_spinner()
# add new lines to the plot
def plot_signals(signals_file, signals):
for idx, signal in enumerate(signals):
signal.line = plot.line('index', signal.name, source=signals_file.bokeh_source,
line_color=signal.color, line_width=2)
def open_file_dialog():
return dialog.getFileDialog()
def open_directory_dialog():
return dialog.getDirDialog()
def show_spinner():
spinner.text = spinner_style + spinner_html
def hide_spinner():
spinner.text = ""
# will create a group from the files
def create_files_group_signal(files):
global selected_file
signals_file = SignalsFilesGroup(files)
signals_files[signals_file.filename] = signals_file
filenames = [signals_file.filename]
files_selector.options += filenames
files_selector.value = filenames[0]
selected_file = signals_file
# load files from disk as a group
def load_files_group():
show_spinner()
files = open_file_dialog()
# no files selected
if not files or not files[0]:
hide_spinner()
return
change_displayed_doc()
if len(files) == 1:
create_files_signal(files)
else:
create_files_group_signal(files)
change_selected_signals_in_data_selector([""])
hide_spinner()
# classify the folder as containing a single file, multiple files or only folders
def classify_folder(dir_path):
files = [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f)) and f.endswith('.csv')]
folders = [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))]
if len(files) == 1:
return FolderType.SINGLE_FILE
elif len(files) > 1:
return FolderType.MULTIPLE_FILES
elif len(folders) >= 1:
return FolderType.MULTIPLE_FOLDERS
else:
return FolderType.EMPTY
# finds if this is single-threaded or multi-threaded
def get_run_type(dir_path):
folder_type = classify_folder(dir_path)
if folder_type == FolderType.SINGLE_FILE:
return RunType.SINGLE_FOLDER_SINGLE_FILE
elif folder_type == FolderType.MULTIPLE_FILES:
return RunType.SINGLE_FOLDER_MULTIPLE_FILES
elif folder_type == FolderType.MULTIPLE_FOLDERS:
# folder contains sub dirs -> we assume we can classify the folder using only the first sub dir
sub_dirs = [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))]
# checking only the first folder in the root dir for its type, since we assume that all sub dirs will share the
# same structure (i.e. if one is a result of multi-threaded run, so will all the other).
folder_type = classify_folder(os.path.join(dir_path, sub_dirs[0]))
if folder_type == FolderType.SINGLE_FILE:
folder_type = RunType.MULTIPLE_FOLDERS_SINGLE_FILES
elif folder_type == FolderType.MULTIPLE_FILES:
folder_type = RunType.MULTIPLE_FOLDERS_MULTIPLE_FILES
return folder_type
# takes path to dir and recursively adds all it's files to paths
def add_directory_csv_files(dir_path, paths=None):
if not paths:
paths = []
for p in os.listdir(dir_path):
path = os.path.join(dir_path, p)
if os.path.isdir(path):
# call recursively for each dir
paths = add_directory_csv_files(path, paths)
elif os.path.isfile(path) and path.endswith('.csv'):
# add every file to the list
paths.append(path)
return paths
# create a signal file from the directory path according to the directory underlying structure
def handle_dir(dir_path, run_type):
paths = add_directory_csv_files(dir_path)
if run_type == RunType.SINGLE_FOLDER_SINGLE_FILE:
create_files_signal(paths)
elif run_type == RunType.SINGLE_FOLDER_MULTIPLE_FILES:
create_files_group_signal(paths)
elif run_type == RunType.MULTIPLE_FOLDERS_SINGLE_FILES:
create_files_group_signal(paths)
elif run_type == RunType.MULTIPLE_FOLDERS_MULTIPLE_FILES:
sub_dirs = [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))]
# for d in sub_dirs:
# paths = add_directory_csv_files(os.path.join(dir_path, d))
# create_files_group_signal(paths)
create_files_group_signal([os.path.join(dir_path, d) for d in sub_dirs])
# load directory from disk as a group
def load_directory_group():
global selected_file
show_spinner()
directory = open_directory_dialog()
# no files selected
if not directory:
hide_spinner()
return
change_displayed_doc()
handle_dir(directory, get_run_type(directory))
change_selected_signals_in_data_selector([""])
hide_spinner()
def create_files_signal(files):
global selected_file
new_signal_files = []
for idx, file_path in enumerate(files):
signals_file = SignalsFile(str(file_path))
signals_files[signals_file.filename] = signals_file
new_signal_files.append(signals_file)
filenames = [f.filename for f in new_signal_files]
files_selector.options += filenames
files_selector.value = filenames[0]
selected_file = new_signal_files[0]
# load files from disk
def load_files():
show_spinner()
files = open_file_dialog()
# no files selected
if not files or not files[0]:
hide_spinner()
return
create_files_signal(files)
hide_spinner()
change_selected_signals_in_data_selector([""])
def unload_file():
global selected_file
global signals_files
if selected_file is None:
return
selected_file.hide_all_signals()
del signals_files[selected_file.filename]
data_selector.options = [""]
filenames = itertools.cycle(files_selector.options)
files_selector.options.remove(selected_file.filename)
if len(files_selector.options) > 0:
files_selector.value = next(filenames)
else:
files_selector.value = None
update_legend()
refresh_info.text = ""
# reload the selected csv file
def reload_all_files(force=False):
for file_to_load in signals_files.values():
if force or file_to_load.file_was_modified_on_disk():
file_to_load.load()
refresh_info.text = "last update: " + str(datetime.datetime.now()).split(".")[0]
# unselect the currently selected signals and then select the requested signals in the data selector
def change_selected_signals_in_data_selector(selected_signals):
# the default bokeh way is not working due to a bug since Bokeh 0.12.6 (https://github.com/bokeh/bokeh/issues/6501)
# this will currently cause the signals to change color
for value in list(data_selector.value):
if value in data_selector.options:
index = data_selector.options.index(value)
data_selector.options.remove(value)
data_selector.value.remove(value)
data_selector.options.insert(index, value)
data_selector.value = selected_signals
# change data options according to the selected file
def change_data_selector(args, old, new):
global selected_file
if new is None:
selected_file = None
return
show_spinner()
selected_file = signals_files[new]
data_selector.options = sorted(list(selected_file.signals.keys()))
selected_signal_names = [s.name for s in selected_file.signals.values() if s.selected]
if not selected_signal_names:
selected_signal_names = [""]
change_selected_signals_in_data_selector(selected_signal_names)
averaging_slider.value = selected_file.signals_averaging_window
group_cb.active = [0 if selected_file.show_bollinger_bands else None]
group_cb.active += [1 if selected_file.separate_files else None]
hide_spinner()
# smooth all the signals of the selected file
def update_averaging(args, old, new):
show_spinner()
selected_file.change_averaging_window(new)
hide_spinner()
def change_x_axis(val):
global x_axis
show_spinner()
x_axis = x_axis_options[val]
plot.xaxis.axis_label = x_axis
reload_all_files(force=True)
update_axis_range(x_axis, plot.x_range)
hide_spinner()
# move the signal between the main and secondary Y axes
def toggle_second_axis():
show_spinner()
signals = selected_file.get_selected_signals()
for signal in signals:
signal.toggle_axis()
update_ranges()
update_legend()
# this is just for redrawing the signals
selected_file.reload_data([signal.name for signal in signals])
hide_spinner()
def toggle_group_property(new):
# toggle show / hide Bollinger bands
selected_file.change_bollinger_bands_state(0 in new)
# show a separate signal for each file in a group
selected_file.show_files_separately(1 in new)
def change_displayed_doc():
if doc.roots[0] == landing_page:
doc.remove_root(landing_page)
doc.add_root(layout)
# Color selection - most of these functions are taken from bokeh examples (plotting/color_sliders.py)
def select_color(attr, old, new):
show_spinner()
signals = selected_file.get_selected_signals()
for signal in signals:
signal.set_color(rgb_to_hex(crRGBs[new['1d']['indices'][0]]))
hide_spinner()
def generate_color_range(N, I):
HSV_tuples = [(x*1.0/N, 0.5, I) for x in range(N)]
RGB_tuples = map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples)
for_conversion = []
for RGB_tuple in RGB_tuples:
for_conversion.append((int(RGB_tuple[0]*255), int(RGB_tuple[1]*255), int(RGB_tuple[2]*255)))
hex_colors = [rgb_to_hex(RGB_tuple) for RGB_tuple in for_conversion]
return hex_colors, for_conversion
# convert RGB tuple to hexadecimal code
def rgb_to_hex(rgb):
return '#%02x%02x%02x' % rgb
# convert hexadecimal to RGB tuple
def hex_to_dec(hex):
red = ''.join(hex.strip('#')[0:2])
green = ''.join(hex.strip('#')[2:4])
blue = ''.join(hex.strip('#')[4:6])
return int(red, 16), int(green, 16), int(blue,16)
color_resolution = 1000
brightness = 0.75 # change to have brighter/darker colors
crx = list(range(1, color_resolution+1)) # the resolution is 1000 colors
cry = [5 for i in range(len(crx))]
crcolor, crRGBs = generate_color_range(color_resolution, brightness) # produce spectrum
# ---------------- Build Website Layout -------------------
# select file
file_selection_button = bw.Button(label="Select Files", button_type="success", width=120)
file_selection_button.on_click(load_files_group)
files_selector_spacer = bl.Spacer(width=10)
group_selection_button = bw.Button(label="Select Directory", button_type="primary", width=140)
group_selection_button.on_click(load_directory_group)
unload_file_button = bw.Button(label="Unload", button_type="danger", width=50)
unload_file_button.on_click(unload_file)
# files selection box
files_selector = bw.Select(title="Files:", options=[], width=200)
files_selector.on_change('value', change_data_selector)
# data selection box
data_selector = bw.MultiSelect(title="Data:", options=[], size=12)
data_selector.on_change('value', select_data)
# x axis selection box
x_axis_selector_title = bw.Div(text="""X Axis:""")
x_axis_selector = bw.RadioButtonGroup(labels=x_axis_options, active=0)
x_axis_selector.on_click(change_x_axis)
# toggle second axis bw.button
toggle_second_axis_button = bw.Button(label="Toggle Second Axis", button_type="success")
toggle_second_axis_button.on_click(toggle_second_axis)
# averaging slider
averaging_slider = bw.Slider(title="Averaging window", start=1, end=101, step=10)
averaging_slider.on_change('value', update_averaging)
# group properties checkbox
group_cb = bw.CheckboxGroup(labels=["Show statistics bands", "Ungroup signals"], active=[])
group_cb.on_click(toggle_group_property)
# color selector
color_selector_title = bw.Div(text="""Select Color:""")
crsource = bm.ColumnDataSource(data=dict(x=crx, y=cry, crcolor=crcolor, RGBs=crRGBs))
color_selector = bp.figure(x_range=(0, color_resolution), y_range=(0, 10),
plot_width=300, plot_height=40,
tools='tap')
color_selector.axis.visible = False
color_range = color_selector.rect(x='x', y='y', width=1, height=10,
color='crcolor', source=crsource)
crsource.on_change('selected', select_color)
color_range.nonselection_glyph = color_range.glyph
color_selector.toolbar.logo = None
color_selector.toolbar_location = None
# title
title = bw.Div(text="""<h1>Coach Dashboard</h1>""")
# landing page
landing_page_description = bw.Div(text="""<h3>Start by selecting an experiment file or directory to open:</h3>""")
center = bw.Div(text="""<style>html { text-align: center; } </style>""")
center_buttons = bw.Div(text="""<style>.bk-grid-row .bk-layout-fixed { margin: 0 auto; }</style>""", width=0)
landing_page = bl.column(center,
title,
landing_page_description,
bl.row(center_buttons),
bl.row(file_selection_button, sizing_mode='scale_width'),
bl.row(group_selection_button, sizing_mode='scale_width'),
sizing_mode='scale_width')
# main layout of the document
layout = bl.row(file_selection_button, files_selector_spacer, group_selection_button, width=300)
layout = bl.column(layout, files_selector)
layout = bl.column(layout, bl.row(refresh_info, unload_file_button))
layout = bl.column(layout, data_selector)
layout = bl.column(layout, color_selector_title)
layout = bl.column(layout, color_selector)
layout = bl.column(layout, x_axis_selector_title)
layout = bl.column(layout, x_axis_selector)
layout = bl.column(layout, group_cb)
layout = bl.column(layout, toggle_second_axis_button)
layout = bl.column(layout, averaging_slider)
# layout = bl.column(layout, legend)
layout = bl.row(layout, plot)
layout = bl.column(title, layout)
layout = bl.column(layout, spinner)
doc = bp.curdoc()
doc.add_root(landing_page)
doc.add_periodic_callback(reload_all_files, 20000)
plot.y_range = bm.Range1d(0, 100)
plot.extra_y_ranges['secondary'] = bm.Range1d(0, 100)
# show load file dialog immediately on start
#doc.add_timeout_callback(load_files, 1000)
if __name__ == "__main__":
# find an open port and run the server
import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
port = 12345
while True:
try:
s.bind(("127.0.0.1", port))
break
except socket.error as e:
if e.errno == 98:
port += 1
s.close()
os.system('bokeh serve --show dashboard.py --port {}'.format(port))

View File

@@ -1,49 +0,0 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import matplotlib.pyplot as plt
import numpy as np
def show_observation_stack(stack, channels_last=False):
    if isinstance(stack, list):  # is list
        stack_size = len(stack)
    elif len(stack.shape) == 3:
        stack_size = stack.shape[0]  # is numpy array
    elif len(stack.shape) == 4:
        stack_size = stack.shape[1]  # ignore batch dimension
        stack = stack[0]
    else:
        assert False, "expected a list of frames or a 3D / 4D numpy array"
    if channels_last:
        stack = np.transpose(stack, (2, 0, 1))
        stack_size = stack.shape[0]
    for i in range(stack_size):
        plt.subplot(1, stack_size, i + 1)
        plt.imshow(stack[i], cmap='gray')
    plt.show()


def show_diff_between_two_observations(observation1, observation2):
    plt.imshow(observation1 - observation2, cmap='gray')
    plt.show()


def plot_grayscale_observation(observation):
    plt.imshow(observation, cmap='gray')
    plt.show()
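A short usage sketch for the helpers above (not part of the deleted file; the random frames are placeholder data only):

import numpy as np

# four 84x84 grayscale frames stacked channels-first: shape (4, 84, 84)
frames = np.random.rand(4, 84, 84)
show_observation_stack(frames)                            # one subplot per frame
show_diff_between_two_observations(frames[0], frames[1])  # pixel-wise difference
plot_grayscale_observation(frames[0])

# the same frames stored channels-last, shape (84, 84, 4)
show_observation_stack(np.transpose(frames, (1, 2, 0)), channels_last=True)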


@@ -1,205 +0,0 @@
#!/bin/bash -e
prompt () {
# prints a yes / no question to the user and returns the answer
# first argument is the prompt question
# second argument is the default answer - Y / N
local default_answer
# set the default value
case "${2}" in
y|Y ) default_answer=1; options="[Y/n]";;
n|N ) default_answer=0; options="[y/N]";;
"" ) default_answer=; options="[y/n]";;
* ) echo "invalid default value"; exit;;
esac
while true; do
# read the user choice
read -p "${1} ${options} " choice
# return the choice or the default value if an enter was pressed
case "${choice}" in
y|Y ) retval=1; return;;
n|N ) retval=0; return;;
"" ) if [ ! -z "${default_answer}" ]; then retval=${default_answer}; return; fi;;
esac
done
}
add_to_bashrc () {
# adds an env variable to the bashrc
# first argument is the variable name
# second argument is the variable value
    EXISTS_IN_BASHRC=$(awk -v pattern="${2}" '$0 ~ pattern {print $1}' ~/.bashrc)
if [ "${EXISTS_IN_BASHRC}" == "" ]; then
echo "export ${1}=${2}" >> ~/.bashrc
fi
}
GET_PREFERENCES_MANUALLY=1
INSTALL_COACH=0
INSTALL_DASHBOARD=0
INSTALL_GYM=0
INSTALL_NEON=0
INSTALL_VIRTUAL_ENVIRONMENT=1
# Get user preferences
TEMP=`getopt -o cpgvrmeNndh \
--long coach,dashboard,gym,no_virtual_environment,neon,debug,help \
-- "$@"`
eval set -- "$TEMP"
while true; do
#for i in "$@"
case ${1} in
-c|--coach)
INSTALL_COACH=1
GET_PREFERENCES_MANUALLY=0
shift;;
-p|--dashboard)
INSTALL_DASHBOARD=1
GET_PREFERENCES_MANUALLY=0;
shift;;
-g|--gym)
INSTALL_GYM=1
GET_PREFERENCES_MANUALLY=0;
shift;;
-N|--no_virtual_environment)
INSTALL_VIRTUAL_ENVIRONMENT=0
GET_PREFERENCES_MANUALLY=0;
shift;;
        -n|--neon)
INSTALL_NEON=1
GET_PREFERENCES_MANUALLY=0;
shift;;
-d|--debug) set -x; shift;;
-h|--help)
echo "Available command line arguments:"
echo ""
echo " -c | --coach - Install Coach requirements"
echo " -p | --dashboard - Install Dashboard requirements"
echo " -g | --gym - Install Gym support"
echo " -N | --no_virtual_environment - Do not install inside of a virtual environment"
echo " -d | --debug - Run in debug mode"
echo " -h | --help - Display this help message"
echo ""
exit;;
--) shift; break;;
*) break;; # unknown option;;
esac
done
if [ ${GET_PREFERENCES_MANUALLY} -eq 1 ]; then
prompt "Install Coach requirements?" Y
INSTALL_COACH=${retval}
prompt "Install Dashboard requirements?" Y
INSTALL_DASHBOARD=${retval}
prompt "Install Gym support?" Y
INSTALL_GYM=${retval}
prompt "Install neon support?" Y
INSTALL_NEON=${retval}
fi
IN_VIRTUAL_ENV=`python3 -c 'import sys; print("%i" % hasattr(sys, "real_prefix"))'`
# basic installations
sudo -E apt-get install python3-pip cmake zlib1g-dev python3-tk python-opencv -y
pip3 install --upgrade pip
# if we are not in a virtual environment, we will create one with the appropriate python version and then activate it
# if we are already in a virtual environment, we will keep using it as is
if [ ${INSTALL_VIRTUAL_ENVIRONMENT} -eq 1 ]; then
if [ ${IN_VIRTUAL_ENV} -eq 0 ]; then
sudo -E pip3 install virtualenv
virtualenv -p python3 coach_env
. coach_env/bin/activate
fi
fi
#------------------------------------------------
# From now on we are in a virtual environment
#------------------------------------------------
# get python local and global paths
python_version=python$(python -c "import sys; print (str(sys.version_info[0])+'.'+str(sys.version_info[1]))")
var=( $(which -a $python_version) )
get_python_lib_cmd="from distutils.sysconfig import get_python_lib; print (get_python_lib())"
lib_virtualenv_path=$(python -c "$get_python_lib_cmd")
lib_system_path=$(${var[-1]} -c "$get_python_lib_cmd")
# Boost libraries
sudo -E apt-get install libboost-all-dev -y
# Coach
if [ ${INSTALL_COACH} -eq 1 ]; then
echo "Installing Coach requirements"
pip3 install -r ./requirements_coach.txt
fi
# Dashboard
if [ ${INSTALL_DASHBOARD} -eq 1 ]; then
echo "Installing Dashboard requirements"
pip3 install -r ./requirements_dashboard.txt
sudo -E apt-get install dpkg-dev build-essential python3.5-dev libjpeg-dev libtiff-dev libsdl1.2-dev libnotify-dev \
freeglut3 freeglut3-dev libsm-dev libgtk2.0-dev libgtk-3-dev libwebkitgtk-dev libwebkitgtk-3.0-dev libgstreamer-plugins-base1.0-dev -y
sudo -E -H pip3 install -U --pre -f \
https://wxpython.org/Phoenix/snapshot-builds/linux/gtk3/ubuntu-16.04/wxPython-4.0.0a3.dev3059+4a5c5d9-cp35-cp35m-linux_x86_64.whl wxPython
# link the wxPython Phoenix library into the virtualenv, since sudo pip3 installs it system-wide where the virtualenv cannot see it
libs=( wx )
for lib in ${libs[@]}
do
ln -sf $lib_system_path/$lib $lib_virtualenv_path/$lib
done
fi
# Gym
if [ ${INSTALL_GYM} -eq 1 ]; then
echo "Installing Gym support"
sudo -E apt-get install libav-tools libsdl2-dev swig cmake -y
pip3 install box2d # for bipedal walker etc.
pip3 install gym
fi
# NGraph and Neon
if [ ${INSTALL_NEON} -eq 1 ]; then
echo "Installing neon requirements"
# MKL
git clone https://github.com/01org/mkl-dnn.git
cd mkl-dnn
cd scripts && ./prepare_mkl.sh && cd ..
mkdir -p build && cd build && cmake .. && make -j
sudo make install -j
cd ../..
export MKLDNN_ROOT=/usr/local/
add_to_bashrc MKLDNN_ROOT ${MKLDNN_ROOT}
export LD_LIBRARY_PATH=$MKLDNN_ROOT/lib:$LD_LIBRARY_PATH
add_to_bashrc LD_LIBRARY_PATH ${MKLDNN_ROOT}/lib:$LD_LIBRARY_PATH
# NGraph
git clone https://github.com/NervanaSystems/ngraph.git
cd ngraph
make install -j
cd ..
# Neon
sudo -E apt-get install libhdf5-dev libyaml-dev pkg-config clang virtualenv libcurl4-openssl-dev libopencv-dev libsox-dev -y
pip3 install nervananeon
fi
if ! [ -x "$(command -v nvidia-smi)" ]; then
# Intel Optimized TensorFlow
#pip3 install https://anaconda.org/intel/tensorflow/1.3.0/download/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl
pip3 install https://anaconda.org/intel/tensorflow/1.4.0/download/tensorflow-1.4.0-cp35-cp35m-linux_x86_64.whl
else
# GPU supported TensorFlow
pip3 install tensorflow-gpu==1.4.1
fi


@@ -7,3 +7,5 @@ pygame==1.9.3
PyOpenGL==3.1.0
scipy==0.19.0
scikit-image==0.13.0
gym
tensorflow


@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import atexit
import json
import os
@@ -22,13 +23,12 @@ import subprocess
import sys
import time
import agents
import argparse
import configurations as conf
import environments
import logger
import presets
import utils
from coach import agents # noqa
from coach import configurations as conf
from coach import environments
from coach import logger
from coach import presets
from coach import utils
if len(set(logger.failed_imports)) > 0:


@@ -20,10 +20,10 @@ import time
import tensorflow as tf
import agents
import environments
import logger
import presets
from coach import agents
from coach import environments
from coach import logger
from coach import presets
start_time = time.time()

setup.py Executable file

@@ -0,0 +1,28 @@
#!/usr/bin/env python3
"""
Setup for the coach project
"""
from setuptools import setup, find_packages
setup(name="coach",
      version="0.7",
      description="Reinforcement Learning Coach",
      author="Caspi, Itai and Leibovich, Gal and Novik, Gal",
      author_email="gal.novik@intel.com",
      url="https://github.com/NervanaSystems/coach",
      packages=find_packages(),
      download_url="https://github.com/NervanaSystems/coach",
      keywords=["reinforcement", "machine", "learning"],
      install_requires=["annoy", "Pillow", "matplotlib", "numpy", "pandas",
                        "pygame", "PyOpenGL", "scipy", "scikit-image",
                        "tensorflow", "gym"],
      scripts=["scripts/coach"],
      classifiers=["Programming Language :: Python :: 3",
                   "Development Status :: 4 - Beta",
                   "Environment :: Console",
                   "Intended Audience :: End Users/Desktop",
                   "License :: OSI Approved :: Apache Software License",
                   "Operating System :: OS Independent",
                   "Topic :: Scientific/Engineering :: Artificial "
                   "Intelligence"])