Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)

Moved coach to its top level module.
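The commit is a mechanical rewrite of absolute imports: the flat top-level modules (agents, architectures, environments, exploration_policies, memories, configurations, utils, logger) become submodules of a single coach package. The sketch below only illustrates the kind of rewrite involved; the script, its PACKAGES/MODULES lists, and the coach/ path are assumptions for illustration, not part of the commit.

```python
import re
from pathlib import Path

# Hypothetical helper illustrating the import rewrite this commit applies.
# The package/module lists are assumptions based on the modules touched below.
PACKAGES = ("agents", "architectures", "environments",
            "exploration_policies", "memories", "presets")
MODULES = ("configurations", "utils", "logger", "renderer")

def rewrite(line: str) -> str:
    # from agents.x import Y  ->  from coach.agents.x import Y
    line = re.sub(r"^from ({})\b".format("|".join(PACKAGES)),
                  r"from coach.\1", line)
    # import utils  ->  from coach import utils
    line = re.sub(r"^import ({})\b".format("|".join(PACKAGES + MODULES)),
                  r"from coach import \1", line)
    return line

if __name__ == "__main__":
    for path in Path("coach").rglob("*.py"):  # assumed new layout
        path.write_text("".join(rewrite(l) for l in path.open()))
```

A script like this would only cover module-level imports anchored at the start of a line; the diff shows the same substitutions applied throughout, plus a few manual fixes such as logging -> logger in imitation_agent.py.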
@@ -13,28 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from agents.actor_critic_agent import ActorCriticAgent
-from agents.agent import Agent
-from agents.bc_agent import BCAgent
-from agents.bootstrapped_dqn_agent import BootstrappedDQNAgent
-from agents.categorical_dqn_agent import CategoricalDQNAgent
-from agents.clipped_ppo_agent import ClippedPPOAgent
-from agents.ddpg_agent import DDPGAgent
-from agents.ddqn_agent import DDQNAgent
-from agents.dfp_agent import DFPAgent
-from agents.dqn_agent import DQNAgent
-from agents.human_agent import HumanAgent
-from agents.imitation_agent import ImitationAgent
-from agents.mmc_agent import MixedMonteCarloAgent
-from agents.n_step_q_agent import NStepQAgent
-from agents.naf_agent import NAFAgent
-from agents.nec_agent import NECAgent
-from agents.pal_agent import PALAgent
-from agents.policy_gradients_agent import PolicyGradientsAgent
-from agents.policy_optimization_agent import PolicyOptimizationAgent
-from agents.ppo_agent import PPOAgent
-from agents.qr_dqn_agent import QuantileRegressionDQNAgent
-from agents.value_optimization_agent import ValueOptimizationAgent
+from coach.agents.actor_critic_agent import ActorCriticAgent
+from coach.agents.agent import Agent
+from coach.agents.bc_agent import BCAgent
+from coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgent
+from coach.agents.categorical_dqn_agent import CategoricalDQNAgent
+from coach.agents.clipped_ppo_agent import ClippedPPOAgent
+from coach.agents.ddpg_agent import DDPGAgent
+from coach.agents.ddqn_agent import DDQNAgent
+from coach.agents.dfp_agent import DFPAgent
+from coach.agents.dqn_agent import DQNAgent
+from coach.agents.human_agent import HumanAgent
+from coach.agents.imitation_agent import ImitationAgent
+from coach.agents.mmc_agent import MixedMonteCarloAgent
+from coach.agents.n_step_q_agent import NStepQAgent
+from coach.agents.naf_agent import NAFAgent
+from coach.agents.nec_agent import NECAgent
+from coach.agents.pal_agent import PALAgent
+from coach.agents.policy_gradients_agent import PolicyGradientsAgent
+from coach.agents.policy_optimization_agent import PolicyOptimizationAgent
+from coach.agents.ppo_agent import PPOAgent
+from coach.agents.qr_dqn_agent import QuantileRegressionDQNAgent
+from coach.agents.value_optimization_agent import ValueOptimizationAgent

__all__ = [ActorCriticAgent,
Agent,
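Because agents/__init__.py re-exports every agent class (as the hunk above shows), downstream code can address an agent either through the package or through the concrete module after the move. A minimal illustration, assuming the coach package is importable:

```python
# Both paths resolve to the same class once coach/agents/__init__.py
# re-exports it, as in the hunk above.
from coach.agents import DQNAgent                      # via the re-export
from coach.agents.dqn_agent import DQNAgent as Direct  # full module path

assert DQNAgent is Direct
```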
@@ -16,9 +16,9 @@
import numpy as np
from scipy import signal

-from agents import policy_optimization_agent as poa
-import utils
-import logger
+from coach.agents import policy_optimization_agent as poa
+from coach import utils
+from coach import logger


# Actor Critic - https://arxiv.org/abs/1602.01783
@@ -18,7 +18,7 @@ import copy
import random
import time

-import logger
+from coach import logger
try:
import matplotlib.pyplot as plt
except ImportError:
@@ -29,13 +29,12 @@ from pandas.io import pickle
from six.moves import range
import scipy

-from architectures.tensorflow_components import shared_variables as sv
-import configurations
-import exploration_policies as ep # noqa, used in eval()
-import memories # noqa, used in eval()
-from memories import memory
-import renderer
-import utils
+from coach.architectures.tensorflow_components import shared_variables as sv
+from coach import configurations
+from coach import exploration_policies as ep # noqa, used in eval()
+from coach import memories # noqa, used in eval()
+from coach.memories import memory
+from coach import utils


class Agent(object):
@@ -100,7 +99,6 @@ class Agent(object):
self.main_network = None
self.networks = []
self.last_episode_images = []
-self.renderer = renderer.Renderer()

# signals
self.signals = []
@@ -234,13 +232,6 @@ class Agent(object):
r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
observation = 0.2989 * r + 0.5870 * g + 0.1140 * b

-# Render the processed observation which is how the agent will see it
-# Warning: this cannot currently be done in parallel to rendering the environment
-if self.tp.visualization.render_observation:
-if not self.renderer.is_open:
-self.renderer.create_screen(observation.shape[0], observation.shape[1])
-self.renderer.render_image(observation)
-
return observation.astype('uint8')
else:
if self.tp.env.normalize_observation:
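The context lines of the last hunk convert an RGB observation to grayscale with the ITU-R BT.601 luma weights before casting to uint8; the removed lines only dropped the optional rendering of that processed frame. A standalone NumPy sketch of the conversion step (the 84x84 frame size is just an example):

```python
import numpy as np

def to_grayscale(observation: np.ndarray) -> np.ndarray:
    """Collapse an HxWx3 RGB frame to an HxW uint8 frame (BT.601 luma weights)."""
    r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return gray.astype('uint8')

frame = np.random.randint(0, 256, size=(84, 84, 3), dtype=np.uint8)
print(to_grayscale(frame).shape)  # (84, 84)
```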
@@ -15,7 +15,7 @@
#
import numpy as np

-from agents import imitation_agent
+from coach.agents import imitation_agent


# Behavioral Cloning Agent
@@ -15,8 +15,9 @@
#
import numpy as np

-from agents import value_optimization_agent as voa
-import utils
+from coach.agents import value_optimization_agent as voa
+from coach import utils


# Bootstrapped DQN - https://arxiv.org/pdf/1602.04621.pdf
class BootstrappedDQNAgent(voa.ValueOptimizationAgent):
@@ -15,7 +15,7 @@
#
import numpy as np

-from agents import value_optimization_agent as voa
+from coach.agents import value_optimization_agent as voa


# Categorical Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
@@ -19,10 +19,10 @@ from random import shuffle

import numpy as np

-from agents import actor_critic_agent as aca
-from agents import policy_optimization_agent as poa
-import logger
-import utils
+from coach.agents import actor_critic_agent as aca
+from coach.agents import policy_optimization_agent as poa
+from coach import logger
+from coach import utils


# Clipped Proximal Policy Optimization - https://arxiv.org/abs/1707.06347
@@ -17,11 +17,11 @@ import copy

import numpy as np

-from agents import actor_critic_agent as aca
-from agents import agent
-from architectures import network_wrapper as nw
-import configurations as conf
-import utils
+from coach.agents import actor_critic_agent as aca
+from coach.agents import agent
+from coach.architectures import network_wrapper as nw
+from coach import configurations as conf
+from coach import utils


# Deep Deterministic Policy Gradients Network - https://arxiv.org/pdf/1509.02971.pdf
@@ -15,7 +15,7 @@
#
import numpy as np

-from agents import value_optimization_agent as voa
+from coach.agents import value_optimization_agent as voa


# Double DQN - https://arxiv.org/abs/1509.06461
@@ -15,9 +15,9 @@
#
import numpy as np

-from agents import agent
-from architectures import network_wrapper as nw
-import utils
+from coach.agents import agent
+from coach.architectures import network_wrapper as nw
+from coach import utils


# Direct Future Prediction Agent - http://vladlen.info/papers/learning-to-act.pdf
@@ -15,7 +15,7 @@
#
import numpy as np

-from agents import value_optimization_agent as voa
+from coach.agents import value_optimization_agent as voa


# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
@@ -15,7 +15,7 @@
#
import numpy as np

-from agents import value_optimization_agent as voa
+from coach.agents import value_optimization_agent as voa


# Deep Q Network - https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
@@ -19,9 +19,9 @@ import os
import pygame
from pandas.io import pickle

-from agents import agent
-import logger
-import utils
+from coach.agents import agent
+from coach import logger
+from coach import utils


class HumanAgent(agent.Agent):
@@ -15,10 +15,10 @@
#
import collections

-from agents import agent
-from architectures import network_wrapper as nw
-import utils
-import logging
+from coach.agents import agent
+from coach.architectures import network_wrapper as nw
+from coach import utils
+from coach import logger


# Imitation Agent
@@ -55,7 +55,7 @@ class ImitationAgent(agent.Agent):
# log to screen
if phase == utils.RunPhase.TRAIN:
# for the training phase - we log during the episode to visualize the progress in training
-logging.screen.log_dict(
+logger.screen.log_dict(
collections.OrderedDict([
("Worker", self.task_id),
("Episode", self.current_episode),
@@ -65,5 +65,5 @@ class ImitationAgent(agent.Agent):
prefix="Training"
)
else:
-# for the evaluation phase - logging as in regular RL
+# for the evaluation phase - logger as in regular RL
agent.Agent.log_to_screen(self, phase)
@@ -15,7 +15,7 @@
#
import numpy as np

-from agents import value_optimization_agent as voa
+from coach.agents import value_optimization_agent as voa


class MixedMonteCarloAgent(voa.ValueOptimizationAgent):
@@ -15,10 +15,10 @@
#
import numpy as np

-from agents import value_optimization_agent as voa
-from agents import policy_optimization_agent as poa
-import logger
-import utils
+from coach.agents import value_optimization_agent as voa
+from coach.agents import policy_optimization_agent as poa
+from coach import logger
+from coach import utils


# N Step Q Learning Agent - https://arxiv.org/abs/1602.01783
@@ -15,8 +15,8 @@
#
import numpy as np

-from agents.value_optimization_agent import ValueOptimizationAgent
-import utils
+from coach.agents.value_optimization_agent import ValueOptimizationAgent
+from coach import utils


# Normalized Advantage Functions - https://arxiv.org/pdf/1603.00748.pdf
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from agents import value_optimization_agent as voa
-from logger import screen
-import utils
+from coach.agents import value_optimization_agent as voa
+from coach.logger import screen
+from coach import utils


# Neural Episodic Control - https://arxiv.org/pdf/1703.01988.pdf
@@ -15,7 +15,7 @@
#
import numpy as np

-from agents import value_optimization_agent as voa
+from coach.agents import value_optimization_agent as voa


# Persistent Advantage Learning - https://arxiv.org/pdf/1512.04860.pdf
@@ -15,9 +15,9 @@
#
import numpy as np

-from agents import policy_optimization_agent as poa
-import logger
-import utils
+from coach.agents import policy_optimization_agent as poa
+from coach import logger
+from coach import utils


class PolicyGradientsAgent(poa.PolicyOptimizationAgent):
@@ -17,10 +17,10 @@ import collections

import numpy as np

-from agents import agent
-from architectures import network_wrapper as nw
-import logger
-import utils
+from coach.agents import agent
+from coach.architectures import network_wrapper as nw
+from coach import logger
+from coach import utils


class PolicyGradientRescaler(utils.Enum):
@@ -18,12 +18,12 @@ import copy

import numpy as np

-from agents import actor_critic_agent as aca
-from agents import policy_optimization_agent as poa
-from architectures import network_wrapper as nw
-import configurations
-import logger
-import utils
+from coach.agents import actor_critic_agent as aca
+from coach.agents import policy_optimization_agent as poa
+from coach.architectures import network_wrapper as nw
+from coach import configurations
+from coach import logger
+from coach import utils


# Proximal Policy Optimization - https://arxiv.org/pdf/1707.06347.pdf
@@ -15,7 +15,7 @@
#
import numpy as np

-from agents import value_optimization_agent as voa
+from coach.agents import value_optimization_agent as voa


# Quantile Regression Deep Q Network - https://arxiv.org/pdf/1710.10044v1.pdf
@@ -15,9 +15,9 @@
#
import numpy as np

-from agents import agent
-from architectures import network_wrapper as nw
-import utils
+from coach.agents import agent
+from coach.architectures import network_wrapper as nw
+from coach import utils


class ValueOptimizationAgent(agent.Agent):
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import logger
+from coach import logger

try:
from architectures.tensorflow_components import general_network as ts_gn
@@ -16,8 +16,8 @@
import ngraph as ng
import numpy as np

-from architectures import architecture
-import utils
+from coach.architectures import architecture
+from coach import utils


class NeonArchitecture(architecture.Architecture):
@@ -17,11 +17,11 @@ import ngraph as ng
from ngraph.frontends import neon
from ngraph.util import names as ngraph_names

-from architectures.neon_components import architecture
-from architectures.neon_components import embedders
-from architectures.neon_components import middleware
-from architectures.neon_components import heads
-import configurations as conf
+from coach.architectures.neon_components import architecture
+from coach.architectures.neon_components import embedders
+from coach.architectures.neon_components import middleware
+from coach.architectures.neon_components import heads
+from coach import configurations as conf


class GeneralNeonNetwork(architecture.NeonArchitecture):
@@ -17,8 +17,8 @@ import ngraph as ng
from ngraph.frontends import neon
from ngraph.util import names as ngraph_names

-import utils
-from architectures.neon_components import losses
+from coach import utils
+from coach.architectures.neon_components import losses


class Head(object):
@@ -16,16 +16,16 @@
import os
import collections

-import configurations as conf
-import logger
+from coach import configurations as conf
+from coach import logger
try:
import tensorflow as tf
-from architectures.tensorflow_components import general_network as tf_net #import GeneralTensorFlowNetwork
+from coach.architectures.tensorflow_components import general_network as tf_net
except ImportError:
logger.failed_imports.append("TensorFlow")

try:
-from architectures.neon_components import general_network as neon_net
+from coach.architectures.neon_components import general_network as neon_net
except ImportError:
logger.failed_imports.append("Neon")

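The network_wrapper.py hunk above keeps the existing pattern for optional backends: each framework import is attempted in its own try block and a failure is recorded in logger.failed_imports rather than aborting. A self-contained sketch of that pattern (the failed_imports list here is local to the example, not the coach logger):

```python
# Optional-backend import pattern, as used in network_wrapper.py above.
failed_imports = []

try:
    import tensorflow as tf  # noqa: F401
except ImportError:
    failed_imports.append("TensorFlow")

try:
    import ngraph  # noqa: F401  (the neon backend used elsewhere in the diff)
except ImportError:
    failed_imports.append("Neon")

if failed_imports:
    print("Backends unavailable: " + ", ".join(failed_imports))
```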
@@ -17,9 +17,10 @@ import time

import tensorflow as tf

-from architectures import architecture
-import configurations as conf
-import utils
+from coach.architectures import architecture
+from coach import configurations as conf
+from coach import utils


def variable_summaries(var):
"""Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
@@ -36,6 +37,7 @@ def variable_summaries(var):
tf.summary.scalar('min', tf.reduce_min(var))
tf.summary.histogram('histogram', var)


class TensorFlowArchitecture(architecture.Architecture):
def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
"""
@@ -15,7 +15,7 @@
#
import tensorflow as tf

-from configurations import EmbedderComplexity
+from coach.configurations import EmbedderComplexity


class InputEmbedder(object):
@@ -15,11 +15,11 @@
#
import tensorflow as tf

-from architectures.tensorflow_components import architecture
-from architectures.tensorflow_components import embedders
-from architectures.tensorflow_components import middleware
-from architectures.tensorflow_components import heads
-import configurations as conf
+from coach.architectures.tensorflow_components import architecture
+from coach.architectures.tensorflow_components import embedders
+from coach.architectures.tensorflow_components import middleware
+from coach.architectures.tensorflow_components import heads
+from coach import configurations as conf


class GeneralTensorFlowNetwork(architecture.TensorFlowArchitecture):
@@ -16,7 +16,7 @@
import tensorflow as tf
import numpy as np

-import utils
+from coach import utils


# Used to initialize weights for policy and value output layers
@@ -14,9 +14,9 @@
# limitations under the License.
#
import json

import types
-import utils

+from coach import utils


class Frameworks(utils.Enum):
@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from environments.gym_environment_wrapper import GymEnvironmentWrapper
-from environments.doom_environment_wrapper import DoomEnvironmentWrapper
-from environments.carla_environment_wrapper import CarlaEnvironmentWrapper
-import utils
+from coach.environments.gym_environment_wrapper import GymEnvironmentWrapper
+from coach.environments.doom_environment_wrapper import DoomEnvironmentWrapper
+from coach.environments.carla_environment_wrapper import CarlaEnvironmentWrapper
+from coach import utils


class EnvTypes(utils.Enum):
@@ -6,7 +6,7 @@ import sys

import numpy as np

-import logger
+from coach import logger
try:
if 'CARLA_ROOT' in os.environ:
sys.path.append(os.path.join(os.environ.get('CARLA_ROOT'),
@@ -16,8 +16,8 @@ try:
from carla import sensor as carla_sensor
except ImportError:
logger.failed_imports.append("CARLA")
-from environments import environment_wrapper as ew
-import utils
+from coach.environments import environment_wrapper as ew
+from coach import utils


# enum of the available levels and their path
@@ -13,19 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import enum
import os

import numpy as np

-import logger
+from coach import logger
try:
import vizdoom
except ImportError:
logger.failed_imports.append("ViZDoom")

-from environments import environment_wrapper as ew
-import utils
+from coach.environments import environment_wrapper as ew
+from coach import utils


# enum of the available levels and their path
@@ -18,8 +18,7 @@ import time

import numpy as np

-import renderer
-import utils
+from coach import utils


class EnvironmentWrapper(object):
@@ -62,7 +61,6 @@ class EnvironmentWrapper(object):
self.wait_for_explicit_human_action = False
self.is_rendered = self.is_rendered or self.human_control
self.game_is_open = True
-self.renderer = renderer.Renderer()

@property
def measurements(self):
@@ -106,26 +104,6 @@ class EnvironmentWrapper(object):
Get an action from the user keyboard
:return: action index
"""
-if self.wait_for_explicit_human_action:
-while len(self.renderer.pressed_keys) == 0:
-self.renderer.get_events()
-
-if self.key_to_action == {}:
-# the keys are the numbers on the keyboard corresponding to the action index
-if len(self.renderer.pressed_keys) > 0:
-action_idx = self.renderer.pressed_keys[0] - ord("1")
-if 0 <= action_idx < self.action_space_size:
-return action_idx
-else:
-# the keys are mapped through the environment to more intuitive keyboard keys
-# key = tuple(self.renderer.pressed_keys)
-# for key in self.renderer.pressed_keys:
-for env_keys in self.key_to_action.keys():
-if set(env_keys) == set(self.renderer.pressed_keys):
-return self.key_to_action[env_keys]
-
-# return the default action 0 so that the environment will continue running
-return self.default_action

def step(self, action_idx):
"""
@@ -18,8 +18,8 @@ import random
import gym
import numpy as np

-from environments import environment_wrapper as ew
-import utils
+from coach.environments import environment_wrapper as ew
+from coach import utils


class GymEnvironmentWrapper(ew.EnvironmentWrapper):
@@ -13,18 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from exploration_policies.additive_noise import AdditiveNoise
-from exploration_policies.approximated_thompson_sampling_using_dropout import ApproximatedThompsonSamplingUsingDropout
-from exploration_policies.bayesian import Bayesian
-from exploration_policies.boltzmann import Boltzmann
-from exploration_policies.bootstrapped import Bootstrapped
-from exploration_policies.categorical import Categorical
-from exploration_policies.continuous_entropy import ContinuousEntropy
-from exploration_policies.e_greedy import EGreedy
-from exploration_policies.exploration_policy import ExplorationPolicy
-from exploration_policies.greedy import Greedy
-from exploration_policies.ou_process import OUProcess
-from exploration_policies.thompson_sampling import ThompsonSampling
+from coach.exploration_policies.additive_noise import AdditiveNoise
+from coach.exploration_policies.approximated_thompson_sampling_using_dropout import ApproximatedThompsonSamplingUsingDropout
+from coach.exploration_policies.bayesian import Bayesian
+from coach.exploration_policies.boltzmann import Boltzmann
+from coach.exploration_policies.bootstrapped import Bootstrapped
+from coach.exploration_policies.categorical import Categorical
+from coach.exploration_policies.continuous_entropy import ContinuousEntropy
+from coach.exploration_policies.e_greedy import EGreedy
+from coach.exploration_policies.exploration_policy import ExplorationPolicy
+from coach.exploration_policies.greedy import Greedy
+from coach.exploration_policies.ou_process import OUProcess
+from coach.exploration_policies.thompson_sampling import ThompsonSampling


__all__ = [AdditiveNoise,
@@ -15,8 +15,8 @@
#
import numpy as np

-from exploration_policies import exploration_policy
-import utils
+from coach.exploration_policies import exploration_policy
+from coach import utils


class AdditiveNoise(exploration_policy.ExplorationPolicy):
@@ -15,7 +15,7 @@
#
import numpy as np

-from exploration_policies import exploration_policy
+from coach.exploration_policies import exploration_policy


class ApproximatedThompsonSamplingUsingDropout(exploration_policy.ExplorationPolicy):
@@ -15,8 +15,8 @@
#
import numpy as np

-from exploration_policies import exploration_policy
-import utils
+from coach.exploration_policies import exploration_policy
+from coach import utils


class Bayesian(exploration_policy.ExplorationPolicy):
@@ -15,8 +15,9 @@
#
import numpy as np

-from exploration_policies import exploration_policy
-import utils
+from coach.exploration_policies import exploration_policy
+from coach import utils


class Boltzmann(exploration_policy.ExplorationPolicy):
def __init__(self, tuning_parameters):
@@ -15,7 +15,7 @@
#
import numpy as np

-from exploration_policies import e_greedy
+from coach.exploration_policies import e_greedy


class Bootstrapped(e_greedy.EGreedy):
@@ -15,7 +15,7 @@
#
import numpy as np

-from exploration_policies import exploration_policy
+from coach.exploration_policies import exploration_policy


class Categorical(exploration_policy.ExplorationPolicy):
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from exploration_policies import exploration_policy
+from coach.exploration_policies import exploration_policy


class ContinuousEntropy(exploration_policy.ExplorationPolicy):
@@ -15,8 +15,8 @@
#
import numpy as np

-from exploration_policies import exploration_policy
-import utils
+from coach.exploration_policies import exploration_policy
+from coach import utils


class EGreedy(exploration_policy.ExplorationPolicy):
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import utils
+from coach import utils


class ExplorationPolicy(object):
@@ -15,7 +15,7 @@
#
import numpy as np

-from exploration_policies import exploration_policy
+from coach.exploration_policies import exploration_policy


class Greedy(exploration_policy.ExplorationPolicy):
@@ -15,11 +15,12 @@
#
import numpy as np

-from exploration_policies import exploration_policy
+from coach.exploration_policies import exploration_policy

# Based on on the description in:
# https://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab


# Ornstein-Uhlenbeck process
class OUProcess(exploration_policy.ExplorationPolicy):
def __init__(self, tuning_parameters):
@@ -15,7 +15,7 @@
#
import numpy as np

-from exploration_policies import exploration_policy
+from coach.exploration_policies import exploration_policy


class ThompsonSampling(exploration_policy.ExplorationPolicy):
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from memories.differentiable_neural_dictionary import AnnoyDictionary
-from memories.differentiable_neural_dictionary import AnnoyIndex
-from memories.differentiable_neural_dictionary import QDND
-from memories.episodic_experience_replay import EpisodicExperienceReplay
-from memories.memory import Episode
-from memories.memory import Memory
-from memories.memory import Transition
+from coach.memories.differentiable_neural_dictionary import AnnoyDictionary
+from coach.memories.differentiable_neural_dictionary import AnnoyIndex
+from coach.memories.differentiable_neural_dictionary import QDND
+from coach.memories.episodic_experience_replay import EpisodicExperienceReplay
+from coach.memories.memory import Episode
+from coach.memories.memory import Memory
+from coach.memories.memory import Transition

__all__ = [AnnoyDictionary,
AnnoyIndex,
@@ -17,7 +17,7 @@ import typing

import numpy as np

-from memories import memory
+from coach.memories import memory


class EpisodicExperienceReplay(memory.Memory):
@@ -17,11 +17,11 @@ import ast
import json
import sys

-import agents
-import configurations as conf
-import environments as env
-import exploration_policies as ep
-import presets
+from coach import agents
+from coach import configurations as conf
+from coach import environments as env
+from coach import exploration_policies as ep
+from coach import presets


def json_to_preset(json_path):
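With everything under a coach package, the repository also needs coach to be importable (for example from an installed distribution or with the repository root on sys.path). A hedged sketch of the minimal setuptools metadata this implies; no packaging file appears in the diff shown here, so names and version are placeholders:

```python
# setup.py - illustrative only; not part of this commit.
from setuptools import setup, find_packages

setup(
    name="coach",
    version="0.0.0",  # placeholder
    packages=find_packages(include=["coach", "coach.*"]),
)
```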
dashboard.py (972 lines removed)
@@ -1,972 +0,0 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
To run Coach Dashboard, run the following command:
python3 dashboard.py
"""

import colorsys
import datetime
import enum
import itertools
import os
import random

from bokeh import palettes
from bokeh import layouts as bl
from bokeh import models as bm
from bokeh.models import widgets as bw
from bokeh import plotting as bp
import numpy as np
import pandas as pd
from pandas.io import pandas_common
import wx

import utils

class DialogApp(wx.App):
|
||||
def getFileDialog(self):
|
||||
with wx.FileDialog(None, "Open CSV file", wildcard="CSV files (*.csv)|*.csv",
|
||||
style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_CHANGE_DIR | wx.FD_MULTIPLE) as fileDialog:
|
||||
if fileDialog.ShowModal() == wx.ID_CANCEL:
|
||||
return None # the user changed their mind
|
||||
else:
|
||||
# Proceed loading the file chosen by the user
|
||||
return fileDialog.GetPaths()
|
||||
|
||||
def getDirDialog(self):
|
||||
with wx.DirDialog (None, "Choose input directory", "",
|
||||
style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_CHANGE_DIR) as dirDialog:
|
||||
if dirDialog.ShowModal() == wx.ID_CANCEL:
|
||||
return None # the user changed their mind
|
||||
else:
|
||||
# Proceed loading the dir chosen by the user
|
||||
return dirDialog.GetPath()
|
||||
class Signal:
|
||||
def __init__(self, name, parent):
|
||||
self.name = name
|
||||
self.full_name = "{}/{}".format(parent.filename, self.name)
|
||||
self.selected = False
|
||||
self.color = random.choice(palettes.Dark2[8])
|
||||
self.line = None
|
||||
self.bands = None
|
||||
self.bokeh_source = parent.bokeh_source
|
||||
self.min_val = 0
|
||||
self.max_val = 0
|
||||
self.axis = 'default'
|
||||
self.sub_signals = []
|
||||
for name in self.bokeh_source.data.keys():
|
||||
if (len(name.split('/')) == 1 and name == self.name) or '/'.join(name.split('/')[:-1]) == self.name:
|
||||
self.sub_signals.append(name)
|
||||
if len(self.sub_signals) > 1:
|
||||
self.mean_signal = utils.squeeze_list([name for name in self.sub_signals if 'Mean' in name.split('/')[-1]])
|
||||
self.stdev_signal = utils.squeeze_list([name for name in self.sub_signals if 'Stdev' in name.split('/')[-1]])
|
||||
self.min_signal = utils.squeeze_list([name for name in self.sub_signals if 'Min' in name.split('/')[-1]])
|
||||
self.max_signal = utils.squeeze_list([name for name in self.sub_signals if 'Max' in name.split('/')[-1]])
|
||||
else:
|
||||
self.mean_signal = utils.squeeze_list(self.name)
|
||||
self.stdev_signal = None
|
||||
self.min_signal = None
|
||||
self.max_signal = None
|
||||
self.has_bollinger_bands = False
|
||||
if self.mean_signal and self.stdev_signal and self.min_signal and self.max_signal:
|
||||
self.has_bollinger_bands = True
|
||||
self.show_bollinger_bands = False
|
||||
self.bollinger_bands_source = None
|
||||
self.update_range()
|
||||
|
||||
def set_color(self, color):
|
||||
self.color = color
|
||||
if self.line:
|
||||
self.line.glyph.line_color = color
|
||||
if self.bands:
|
||||
self.bands.glyph.fill_color = color
|
||||
|
||||
def set_selected(self, val):
|
||||
global current_color
|
||||
if self.selected != val:
|
||||
self.selected = val
|
||||
if self.line:
|
||||
# self.set_color(palettes.Dark2[8][current_color])
|
||||
# current_color = (current_color + 1) % len(palettes.Dark2[8])
|
||||
self.line.visible = self.selected
|
||||
if self.bands:
|
||||
self.bands.visible = self.selected and self.show_bollinger_bands
|
||||
elif self.selected:
|
||||
# lazy plotting - plot only when selected for the first time
|
||||
show_spinner()
|
||||
self.set_color(palettes.Dark2[8][current_color])
|
||||
current_color = (current_color + 1) % len(palettes.Dark2[8])
|
||||
if self.has_bollinger_bands:
|
||||
self.set_bands_source()
|
||||
self.create_bands()
|
||||
self.line = plot.line('index', self.mean_signal, source=self.bokeh_source,
|
||||
line_color=self.color, line_width=2)
|
||||
self.line.visible = True
|
||||
hide_spinner()
|
||||
|
||||
def set_dash(self, dash):
|
||||
self.line.glyph.line_dash = dash
|
||||
|
||||
def create_bands(self):
|
||||
self.bands = plot.patch(x='band_x', y='band_y', source=self.bollinger_bands_source,
|
||||
color=self.color, fill_alpha=0.4, alpha=0.1, line_width=0)
|
||||
self.bands.visible = self.show_bollinger_bands
|
||||
# self.min_line = plot.line('index', self.min_signal, source=self.bokeh_source,
|
||||
# line_color=self.color, line_width=3, line_dash="4 4")
|
||||
# self.max_line = plot.line('index', self.max_signal, source=self.bokeh_source,
|
||||
# line_color=self.color, line_width=3, line_dash="4 4")
|
||||
# self.min_line.visible = self.show_bollinger_bands
|
||||
# self.max_line.visible = self.show_bollinger_bands
|
||||
|
||||
def set_bands_source(self):
|
||||
x_ticks = self.bokeh_source.data['index']
|
||||
mean_values = self.bokeh_source.data[self.mean_signal]
|
||||
stdev_values = self.bokeh_source.data[self.stdev_signal]
|
||||
band_x = np.append(x_ticks, x_ticks[::-1])
|
||||
band_y = np.append(mean_values - stdev_values, mean_values[::-1] + stdev_values[::-1])
|
||||
source_data = {'band_x': band_x, 'band_y': band_y}
|
||||
if self.bollinger_bands_source:
|
||||
self.bollinger_bands_source.data = source_data
|
||||
else:
|
||||
self.bollinger_bands_source = bm.ColumnDataSource(source_data)
|
||||
|
||||
def change_bollinger_bands_state(self, new_state):
|
||||
self.show_bollinger_bands = new_state
|
||||
if self.bands and self.selected:
|
||||
self.bands.visible = new_state
|
||||
# self.min_line.visible = new_state
|
||||
# self.max_line.visible = new_state
|
||||
|
||||
def update_range(self):
|
||||
self.min_val = np.min(self.bokeh_source.data[self.mean_signal])
|
||||
self.max_val = np.max(self.bokeh_source.data[self.mean_signal])
|
||||
|
||||
def set_axis(self, axis):
|
||||
self.axis = axis
|
||||
self.line.y_range_name = axis
|
||||
|
||||
def toggle_axis(self):
|
||||
if self.axis == 'default':
|
||||
self.set_axis('secondary')
|
||||
else:
|
||||
self.set_axis('default')
|
||||
|
||||
|
||||
class SignalsFileBase:
|
||||
def __init__(self):
|
||||
self.full_csv_path = ""
|
||||
self.dir = ""
|
||||
self.filename = ""
|
||||
self.signals_averaging_window = 1
|
||||
self.show_bollinger_bands = False
|
||||
self.csv = None
|
||||
self.bokeh_source = None
|
||||
self.bokeh_source_orig = None
|
||||
self.last_modified = None
|
||||
self.signals = {}
|
||||
self.separate_files = False
|
||||
|
||||
def load_csv(self):
|
||||
pass
|
||||
|
||||
def update_source_and_signals(self):
|
||||
# create bokeh data sources
|
||||
self.bokeh_source_orig = bm.ColumnDataSource(self.csv)
|
||||
self.bokeh_source_orig.data['index'] = self.bokeh_source_orig.data[x_axis]
|
||||
|
||||
if self.bokeh_source is None:
|
||||
self.bokeh_source = bm.ColumnDataSource(self.csv)
|
||||
else:
|
||||
# self.bokeh_source.data = self.bokeh_source_orig.data
|
||||
# smooth the data if necessary
|
||||
self.change_averaging_window(self.signals_averaging_window, force=True)
|
||||
|
||||
# create all the signals
|
||||
if len(self.signals.keys()) == 0:
|
||||
self.signals = {}
|
||||
unique_signal_names = []
|
||||
for name in self.csv.columns:
|
||||
if len(name.split('/')) == 1:
|
||||
unique_signal_names.append(name)
|
||||
else:
|
||||
unique_signal_names.append('/'.join(name.split('/')[:-1]))
|
||||
unique_signal_names = list(set(unique_signal_names))
|
||||
for signal_name in unique_signal_names:
|
||||
self.signals[signal_name] = Signal(signal_name, self)
|
||||
|
||||
def load(self):
|
||||
self.load_csv()
|
||||
self.update_source_and_signals()
|
||||
|
||||
def reload_data(self, signals):
|
||||
# this function is a workaround to reload the data of all the signals
|
||||
# if the data doesn't change, bokeh does not refreshes the line
|
||||
self.change_averaging_window(self.signals_averaging_window + 1, force=True)
|
||||
self.change_averaging_window(self.signals_averaging_window - 1, force=True)
|
||||
|
||||
def change_averaging_window(self, new_size, force=False, signals=None):
|
||||
if force or self.signals_averaging_window != new_size:
|
||||
self.signals_averaging_window = new_size
|
||||
win = np.ones(new_size) / new_size
|
||||
temp_data = self.bokeh_source_orig.data.copy()
|
||||
for col in self.bokeh_source.data.keys():
|
||||
if col == 'index' or col in x_axis_options \
|
||||
or (signals and not any(col in signal for signal in signals)):
|
||||
temp_data[col] = temp_data[col][:-new_size]
|
||||
continue
|
||||
temp_data[col] = np.convolve(self.bokeh_source_orig.data[col], win, mode='same')[:-new_size]
|
||||
self.bokeh_source.data = temp_data
|
||||
|
||||
# smooth bollinger bands
|
||||
for signal in self.signals.values():
|
||||
if signal.has_bollinger_bands:
|
||||
signal.set_bands_source()
|
||||
|
||||
def hide_all_signals(self):
|
||||
for signal_name in self.signals.keys():
|
||||
self.set_signal_selection(signal_name, False)
|
||||
|
||||
def set_signal_selection(self, signal_name, val):
|
||||
self.signals[signal_name].set_selected(val)
|
||||
|
||||
def change_bollinger_bands_state(self, new_state):
|
||||
self.show_bollinger_bands = new_state
|
||||
for signal in self.signals.values():
|
||||
signal.change_bollinger_bands_state(new_state)
|
||||
|
||||
def file_was_modified_on_disk(self):
|
||||
pass
|
||||
|
||||
def get_range_of_selected_signals_on_axis(self, axis, selected_signal=None):
|
||||
max_val = -float('inf')
|
||||
min_val = float('inf')
|
||||
for signal in self.signals.values():
|
||||
if (selected_signal and signal.name == selected_signal) or (signal.selected and signal.axis == axis):
|
||||
max_val = max(max_val, signal.max_val)
|
||||
min_val = min(min_val, signal.min_val)
|
||||
return min_val, max_val
|
||||
|
||||
def get_selected_signals(self):
|
||||
signals = []
|
||||
for signal in self.signals.values():
|
||||
if signal.selected:
|
||||
signals.append(signal)
|
||||
return signals
|
||||
|
||||
def show_files_separately(self, val):
|
||||
pass
|
||||
|
||||
|
||||
class SignalsFile(SignalsFileBase):
|
||||
def __init__(self, csv_path, load=True):
|
||||
SignalsFileBase.__init__(self)
|
||||
self.full_csv_path = csv_path
|
||||
self.dir, self.filename, _ = utils.break_file_path(csv_path)
|
||||
if load:
|
||||
self.load()
|
||||
# this helps set the correct x axis
|
||||
self.change_averaging_window(1, force=True)
|
||||
|
||||
def load_csv(self):
|
||||
# load csv and fix sparse data.
|
||||
# csv can be in the middle of being written so we use try - except
|
||||
self.csv = None
|
||||
while self.csv is None:
|
||||
try:
|
||||
self.csv = pd.read_csv(self.full_csv_path)
|
||||
break
|
||||
except pandas_common.EmptyDataError:
|
||||
self.csv = None
|
||||
continue
|
||||
self.csv = self.csv.interpolate()
|
||||
self.csv.fillna(value=0, inplace=True)
|
||||
|
||||
self.last_modified = os.path.getmtime(self.full_csv_path)
|
||||
|
||||
def file_was_modified_on_disk(self):
|
||||
return self.last_modified != os.path.getmtime(self.full_csv_path)
|
||||
|
||||
|
||||
class SignalsFilesGroup(SignalsFileBase):
|
||||
def __init__(self, csv_paths):
|
||||
SignalsFileBase.__init__(self)
|
||||
self.full_csv_paths = csv_paths
|
||||
self.signals_files = []
|
||||
if len(csv_paths) == 1 and os.path.isdir(csv_paths[0]):
|
||||
self.signals_files = [SignalsFile(str(file), load=False) for file in add_directory_csv_files(csv_paths[0])]
|
||||
else:
|
||||
for csv_path in csv_paths:
|
||||
if os.path.isdir(csv_path):
|
||||
self.signals_files.append(SignalsFilesGroup(add_directory_csv_files(csv_path)))
|
||||
else:
|
||||
self.signals_files.append(SignalsFile(str(csv_path), load=False))
|
||||
if len(csv_paths) == 1:
|
||||
# get the parent directory name (since the current directory is the timestamp directory)
|
||||
self.dir = os.path.abspath(os.path.join(os.path.dirname(csv_paths[0]), '..'))
|
||||
else:
|
||||
# get the common directory for all the experiments
|
||||
self.dir = os.path.dirname(os.path.commonprefix(csv_paths))
|
||||
self.filename = '{} - Group({})'.format(os.path.basename(self.dir), len(self.signals_files))
|
||||
self.load()
|
||||
|
||||
# this helps set the correct x axis
|
||||
self.change_averaging_window(1, force=True)
|
||||
|
||||
def load_csv(self):
|
||||
corrupted_files_idx = []
|
||||
for idx, signal_file in enumerate(self.signals_files):
|
||||
signal_file.load_csv()
|
||||
if not all(option in signal_file.csv.keys() for option in x_axis_options):
|
||||
print("Warning: {} file seems to be corrupted and does contain the necessary columns "
|
||||
"and will not be rendered".format(signal_file.filename))
|
||||
corrupted_files_idx.append(idx)
|
||||
|
||||
for file_idx in corrupted_files_idx:
|
||||
del self.signals_files[file_idx]
|
||||
|
||||
# get the stats of all the columns
|
||||
csv_group = pd.concat([signals_file.csv for signals_file in self.signals_files])
|
||||
columns_to_remove = [s for s in csv_group.columns if '/Stdev' in s] + \
|
||||
[s for s in csv_group.columns if '/Min' in s] + \
|
||||
[s for s in csv_group.columns if '/Max' in s]
|
||||
for col in columns_to_remove:
|
||||
del csv_group[col]
|
||||
csv_group = csv_group.groupby(csv_group.index)
|
||||
self.csv_mean = csv_group.mean()
|
||||
self.csv_mean.columns = [s + '/Mean' for s in self.csv_mean.columns]
|
||||
self.csv_stdev = csv_group.std()
|
||||
self.csv_stdev.columns = [s + '/Stdev' for s in self.csv_stdev.columns]
|
||||
self.csv_min = csv_group.min()
|
||||
self.csv_min.columns = [s + '/Min' for s in self.csv_min.columns]
|
||||
self.csv_max = csv_group.max()
|
||||
self.csv_max.columns = [s + '/Max' for s in self.csv_max.columns]
|
||||
|
||||
# get the indices from the file with the least number of indices and which is not an evaluation worker
|
||||
file_with_min_indices = self.signals_files[0]
|
||||
for signals_file in self.signals_files:
|
||||
if signals_file.csv.shape[0] < file_with_min_indices.csv.shape[0] and \
|
||||
'Training reward' in signals_file.csv.keys():
|
||||
file_with_min_indices = signals_file
|
||||
self.index_columns = file_with_min_indices.csv[x_axis_options]
|
||||
|
||||
# concat the stats and the indices columns
|
||||
num_rows = file_with_min_indices.csv.shape[0]
|
||||
self.csv = pd.concat([self.index_columns, self.csv_mean.head(num_rows), self.csv_stdev.head(num_rows),
|
||||
self.csv_min.head(num_rows), self.csv_max.head(num_rows)], axis=1)
|
||||
|
||||
# remove the stat columns for the indices columns
|
||||
columns_to_remove = [s + '/Mean' for s in x_axis_options] + \
|
||||
[s + '/Stdev' for s in x_axis_options] + \
|
||||
[s + '/Min' for s in x_axis_options] + \
|
||||
[s + '/Max' for s in x_axis_options]
|
||||
for col in columns_to_remove:
|
||||
del self.csv[col]
|
||||
|
||||
# remove NaNs
|
||||
self.csv.fillna(value=0, inplace=True) # removing this line will make bollinger bands fail
|
||||
for key in self.csv.keys():
|
||||
if 'Stdev' in key and 'Evaluation' not in key:
|
||||
self.csv[key] = self.csv[key].fillna(value=0)
|
||||
|
||||
for signal_file in self.signals_files:
|
||||
signal_file.update_source_and_signals()
|
||||
|
||||
def change_averaging_window(self, new_size, force=False, signals=None):
|
||||
for signal_file in self.signals_files:
|
||||
signal_file.change_averaging_window(new_size, force, signals)
|
||||
SignalsFileBase.change_averaging_window(self, new_size, force, signals)
|
||||
|
||||
def set_signal_selection(self, signal_name, val):
|
||||
self.show_files_separately(self.separate_files)
|
||||
SignalsFileBase.set_signal_selection(self, signal_name, val)
|
||||
|
||||
def file_was_modified_on_disk(self):
|
||||
for signal_file in self.signals_files:
|
||||
if signal_file.file_was_modified_on_disk():
|
||||
return True
|
||||
return False
|
||||
|
||||
def show_files_separately(self, val):
|
||||
self.separate_files = val
|
||||
for signal in self.signals.values():
|
||||
if signal.selected:
|
||||
if val:
|
||||
signal.set_dash("4 4")
|
||||
else:
|
||||
signal.set_dash("")
|
||||
for signal_file in self.signals_files:
|
||||
try:
|
||||
if val:
|
||||
signal_file.set_signal_selection(signal.name, signal.selected)
|
||||
else:
|
||||
signal_file.set_signal_selection(signal.name, False)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class RunType(enum.Enum):
|
||||
SINGLE_FOLDER_SINGLE_FILE = 1
|
||||
SINGLE_FOLDER_MULTIPLE_FILES = 2
|
||||
MULTIPLE_FOLDERS_SINGLE_FILES = 3
|
||||
MULTIPLE_FOLDERS_MULTIPLE_FILES = 4
|
||||
UNKNOWN = 0
|
||||
|
||||
|
||||
class FolderType(enum.Enum):
|
||||
SINGLE_FILE = 1
|
||||
MULTIPLE_FILES = 2
|
||||
MULTIPLE_FOLDERS = 3
|
||||
EMPTY = 4
|
||||
|
||||
dialog = DialogApp()
|
||||
|
||||
# read data
|
||||
patches = {}
|
||||
signals_files = {}
|
||||
selected_file = None
|
||||
x_axis = 'Episode #'
|
||||
x_axis_options = ['Episode #', 'Total steps', 'Wall-Clock Time']
|
||||
current_color = 0
|
||||
|
||||
# spinner
|
||||
root_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
with open(os.path.join(root_dir, 'spinner.css'), 'r') as f:
|
||||
spinner_style = """<style>{}</style>""".format(f.read())
|
||||
spinner_html = """<ul class="spinner"><li></li><li></li><li></li><li></li></ul>"""
|
||||
spinner = bw.Div(text="""""")
|
||||
|
||||
# file refresh time placeholder
|
||||
refresh_info = bw.Div(text="""""", width=210)
|
||||
|
||||
# create figures
|
||||
plot = bp.figure(plot_width=1200, plot_height=800,
|
||||
tools='pan,box_zoom,wheel_zoom,crosshair,undo,redo,reset,save',
|
||||
toolbar_location='above', x_axis_label='Episodes',
|
||||
x_range=bm.Range1d(0, 10000), y_range=bm.Range1d(0, 100000))
|
||||
plot.extra_y_ranges = {"secondary": bm.Range1d(start=-100, end=200)}
|
||||
plot.add_layout(bm.LinearAxis(y_range_name="secondary"), 'right')
|
||||
|
||||
# legend
|
||||
div = bw.Div(text="""""")
|
||||
legend = bl.widgetbox([div])
|
||||
|
||||
bokeh_legend = bm.Legend(
|
||||
items=[("12345678901234567890123456789012345678901234567890", [])], # 50 letters
|
||||
# items=[(" ", [])], # 50 letters
|
||||
location=(-20, 0), orientation="vertical",
|
||||
border_line_color="black",
|
||||
label_text_font_size={'value': '9pt'},
|
||||
margin=30
|
||||
)
|
||||
plot.add_layout(bokeh_legend, "right")
|
||||
|
||||
|
||||
def update_axis_range(name, range_placeholder):
|
||||
max_val = -float('inf')
|
||||
min_val = float('inf')
|
||||
selected_signal = None
|
||||
if name in x_axis_options:
|
||||
selected_signal = name
|
||||
for signals_file in signals_files.values():
|
||||
curr_min_val, curr_max_val = signals_file.get_range_of_selected_signals_on_axis(name, selected_signal)
|
||||
max_val = max(max_val, curr_max_val)
|
||||
min_val = min(min_val, curr_min_val)
|
||||
if min_val != float('inf'):
|
||||
range = max_val - min_val
|
||||
range_placeholder.start = min_val - 0.1 * range
|
||||
range_placeholder.end = max_val + 0.1 * range
|
||||
|
||||
|
||||
# update axes ranges
|
||||
def update_ranges():
|
||||
update_axis_range('default', plot.y_range)
|
||||
update_axis_range('secondary', plot.extra_y_ranges['secondary'])
|
||||
|
||||
|
||||
def get_all_selected_signals():
|
||||
signals = []
|
||||
for signals_file in signals_files.values():
|
||||
signals += signals_file.get_selected_signals()
|
||||
return signals
|
||||
|
||||
|
||||
# update legend using the legend text dictionary
|
||||
def update_legend():
|
||||
legend_text = """<div></div>"""
|
||||
selected_signals = get_all_selected_signals()
|
||||
items = []
|
||||
for signal in selected_signals:
|
||||
side_sign = "<" if signal.axis == 'default' else ">"
|
||||
legend_text += """<div style='color: {}'><b>{} {}</b></div>"""\
|
||||
.format(signal.color, side_sign, signal.full_name)
|
||||
items.append((signal.full_name, [signal.line]))
|
||||
div.text = legend_text
|
||||
# the visible=false => visible=true is a hack to make the legend render again
|
||||
bokeh_legend.visible = False
|
||||
bokeh_legend.items = items
|
||||
bokeh_legend.visible = True
|
||||
|
||||
|
||||
# select lines to display
|
||||
def select_data(args, old, new):
|
||||
if selected_file is None:
|
||||
return
|
||||
show_spinner()
|
||||
selected_signals = new
|
||||
for signal_name in selected_file.signals.keys():
|
||||
is_selected = signal_name in selected_signals
|
||||
selected_file.set_signal_selection(signal_name, is_selected)
|
||||
|
||||
# update axes ranges
|
||||
update_ranges()
|
||||
update_axis_range(x_axis, plot.x_range)
|
||||
|
||||
# update the legend
|
||||
update_legend()
|
||||
|
||||
hide_spinner()
|
||||
|
||||
|
||||
# add new lines to the plot
|
||||
def plot_signals(signals_file, signals):
|
||||
for idx, signal in enumerate(signals):
|
||||
signal.line = plot.line('index', signal.name, source=signals_file.bokeh_source,
|
||||
line_color=signal.color, line_width=2)
|
||||
|
||||
|
||||
def open_file_dialog():
|
||||
return dialog.getFileDialog()
|
||||
|
||||
|
||||
def open_directory_dialog():
|
||||
return dialog.getDirDialog()
|
||||
|
||||
|
||||
def show_spinner():
|
||||
spinner.text = spinner_style + spinner_html
|
||||
|
||||
|
||||
def hide_spinner():
|
||||
spinner.text = ""
|
||||
|
||||
|
||||
# will create a group from the files
|
||||
def create_files_group_signal(files):
|
||||
global selected_file
|
||||
signals_file = SignalsFilesGroup(files)
|
||||
signals_files[signals_file.filename] = signals_file
|
||||
|
||||
filenames = [signals_file.filename]
|
||||
files_selector.options += filenames
|
||||
files_selector.value = filenames[0]
|
||||
selected_file = signals_file
|
||||
|
||||
|
||||
# load files from disk as a group
|
||||
def load_files_group():
|
||||
show_spinner()
|
||||
files = open_file_dialog()
|
||||
# no files selected
|
||||
if not files or not files[0]:
|
||||
hide_spinner()
|
||||
return
|
||||
|
||||
change_displayed_doc()
|
||||
|
||||
if len(files) == 1:
|
||||
create_files_signal(files)
|
||||
else:
|
||||
create_files_group_signal(files)
|
||||
|
||||
change_selected_signals_in_data_selector([""])
|
||||
hide_spinner()
|
||||
|
||||
|
||||
# classify the folder as containing a single file, multiple files or only folders
|
||||
def classify_folder(dir_path):
|
||||
files = [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f)) and f.endswith('.csv')]
|
||||
folders = [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))]
|
||||
if len(files) == 1:
|
||||
return FolderType.SINGLE_FILE
|
||||
elif len(files) > 1:
|
||||
return FolderType.MULTIPLE_FILES
|
||||
elif len(folders) >= 1:
|
||||
return FolderType.MULTIPLE_FOLDERS
|
||||
else:
|
||||
return FolderType.EMPTY
|
||||
|
||||
|
||||
# finds if this is single-threaded or multi-threaded
|
||||
def get_run_type(dir_path):
|
||||
folder_type = classify_folder(dir_path)
|
||||
if folder_type == FolderType.SINGLE_FILE:
|
||||
return RunType.SINGLE_FOLDER_SINGLE_FILE
|
||||
|
||||
elif folder_type == FolderType.MULTIPLE_FILES:
|
||||
return RunType.SINGLE_FOLDER_MULTIPLE_FILES
|
||||
|
||||
elif folder_type == FolderType.MULTIPLE_FOLDERS:
|
||||
# folder contains sub dirs -> we assume we can classify the folder using only the first sub dir
|
||||
sub_dirs = [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))]
|
||||
|
||||
# checking only the first folder in the root dir for its type, since we assume that all sub dirs will share the
|
||||
# same structure (i.e. if one is a result of multi-threaded run, so will all the other).
|
||||
folder_type = classify_folder(os.path.join(dir_path, sub_dirs[0]))
|
||||
if folder_type == FolderType.SINGLE_FILE:
|
||||
folder_type = RunType.MULTIPLE_FOLDERS_SINGLE_FILES
|
||||
elif folder_type == FolderType.MULTIPLE_FILES:
|
||||
folder_type = RunType.MULTIPLE_FOLDERS_MULTIPLE_FILES
|
||||
return folder_type
|
||||
|
||||
|
||||
# takes path to dir and recursively adds all it's files to paths
|
||||
def add_directory_csv_files(dir_path, paths=None):
|
||||
if not paths:
|
||||
paths = []
|
||||
|
||||
for p in os.listdir(dir_path):
|
||||
path = os.path.join(dir_path, p)
|
||||
if os.path.isdir(path):
|
||||
# call recursively for each dir
|
||||
paths = add_directory_csv_files(path, paths)
|
||||
elif os.path.isfile(path) and path.endswith('.csv'):
|
||||
# add every file to the list
|
||||
paths.append(path)
|
||||
|
||||
return paths
|
||||
|
||||
|
||||
# create a signal file from the directory path according to the directory underlying structure
|
||||
def handle_dir(dir_path, run_type):
|
||||
paths = add_directory_csv_files(dir_path)
|
||||
if run_type == RunType.SINGLE_FOLDER_SINGLE_FILE:
|
||||
create_files_signal(paths)
|
||||
elif run_type == RunType.SINGLE_FOLDER_MULTIPLE_FILES:
|
||||
create_files_group_signal(paths)
|
||||
elif run_type == RunType.MULTIPLE_FOLDERS_SINGLE_FILES:
|
||||
create_files_group_signal(paths)
|
||||
elif run_type == RunType.MULTIPLE_FOLDERS_MULTIPLE_FILES:
|
||||
sub_dirs = [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))]
|
||||
# for d in sub_dirs:
|
||||
# paths = add_directory_csv_files(os.path.join(dir_path, d))
|
||||
# create_files_group_signal(paths)
|
||||
create_files_group_signal([os.path.join(dir_path, d) for d in sub_dirs])
|
||||
|
||||
|
||||
# load directory from disk as a group
|
||||
def load_directory_group():
|
||||
global selected_file
|
||||
show_spinner()
|
||||
directory = open_directory_dialog()
|
||||
# no files selected
|
||||
if not directory:
|
||||
hide_spinner()
|
||||
return
|
||||
|
||||
change_displayed_doc()
|
||||
|
||||
handle_dir(directory, get_run_type(directory))
|
||||
|
||||
change_selected_signals_in_data_selector([""])
|
||||
hide_spinner()
|
||||
|
||||
|
||||
def create_files_signal(files):
    global selected_file
    new_signal_files = []
    for idx, file_path in enumerate(files):
        signals_file = SignalsFile(str(file_path))
        signals_files[signals_file.filename] = signals_file
        new_signal_files.append(signals_file)

    filenames = [f.filename for f in new_signal_files]

    files_selector.options += filenames
    files_selector.value = filenames[0]
    selected_file = new_signal_files[0]

# load files from disk
def load_files():
    show_spinner()
    files = open_file_dialog()

    # no files selected
    if not files or not files[0]:
        hide_spinner()
        return

    create_files_signal(files)
    hide_spinner()

    change_selected_signals_in_data_selector([""])

def unload_file():
    global selected_file
    global signals_files
    if selected_file is None:
        return
    selected_file.hide_all_signals()
    del signals_files[selected_file.filename]
    data_selector.options = [""]
    filenames = itertools.cycle(files_selector.options)
    files_selector.options.remove(selected_file.filename)
    if len(files_selector.options) > 0:
        files_selector.value = next(filenames)
    else:
        files_selector.value = None
    update_legend()
    refresh_info.text = ""

# reload all the loaded csv files that were modified on disk (or all of them, when forced)
def reload_all_files(force=False):
    for file_to_load in signals_files.values():
        if force or file_to_load.file_was_modified_on_disk():
            file_to_load.load()
    refresh_info.text = "last update: " + str(datetime.datetime.now()).split(".")[0]

# unselect the currently selected signals and then select the requested signals in the data selector
def change_selected_signals_in_data_selector(selected_signals):
    # the default bokeh way is not working due to a bug since Bokeh 0.12.6 (https://github.com/bokeh/bokeh/issues/6501)
    # this will currently cause the signals to change color
    for value in list(data_selector.value):
        if value in data_selector.options:
            index = data_selector.options.index(value)
            data_selector.options.remove(value)
            data_selector.value.remove(value)
            data_selector.options.insert(index, value)
    data_selector.value = selected_signals

# change the data options according to the selected file
def change_data_selector(args, old, new):
    global selected_file
    if new is None:
        selected_file = None
        return
    show_spinner()
    selected_file = signals_files[new]
    data_selector.options = sorted(list(selected_file.signals.keys()))
    selected_signal_names = [s.name for s in selected_file.signals.values() if s.selected]
    if not selected_signal_names:
        selected_signal_names = [""]
    change_selected_signals_in_data_selector(selected_signal_names)
    averaging_slider.value = selected_file.signals_averaging_window
    group_cb.active = [0 if selected_file.show_bollinger_bands else None]
    group_cb.active += [1 if selected_file.separate_files else None]
    hide_spinner()

# smooth all the signals of the selected file
def update_averaging(args, old, new):
    show_spinner()
    selected_file.change_averaging_window(new)
    hide_spinner()

def change_x_axis(val):
    global x_axis
    show_spinner()
    x_axis = x_axis_options[val]
    plot.xaxis.axis_label = x_axis
    reload_all_files(force=True)
    update_axis_range(x_axis, plot.x_range)
    hide_spinner()

# move the selected signals between the main and secondary Y axes
def toggle_second_axis():
    show_spinner()
    signals = selected_file.get_selected_signals()
    for signal in signals:
        signal.toggle_axis()

    update_ranges()
    update_legend()

    # this is just for redrawing the signals
    selected_file.reload_data([signal.name for signal in signals])

    hide_spinner()

def toggle_group_property(new):
    # toggle show / hide Bollinger bands
    selected_file.change_bollinger_bands_state(0 in new)

    # show a separate signal for each file in a group
    selected_file.show_files_separately(1 in new)

def change_displayed_doc():
    if doc.roots[0] == landing_page:
        doc.remove_root(landing_page)
        doc.add_root(layout)

# Color selection - most of these functions are taken from bokeh examples (plotting/color_sliders.py)
def select_color(attr, old, new):
    show_spinner()
    signals = selected_file.get_selected_signals()
    for signal in signals:
        signal.set_color(rgb_to_hex(crRGBs[new['1d']['indices'][0]]))
    hide_spinner()


def generate_color_range(N, I):
    HSV_tuples = [(x*1.0/N, 0.5, I) for x in range(N)]
    RGB_tuples = map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples)
    for_conversion = []
    for RGB_tuple in RGB_tuples:
        for_conversion.append((int(RGB_tuple[0]*255), int(RGB_tuple[1]*255), int(RGB_tuple[2]*255)))
    hex_colors = [rgb_to_hex(RGB_tuple) for RGB_tuple in for_conversion]
    return hex_colors, for_conversion

# convert RGB tuple to hexadecimal code
def rgb_to_hex(rgb):
    return '#%02x%02x%02x' % rgb


# convert hexadecimal to RGB tuple
def hex_to_dec(hex):
    red = ''.join(hex.strip('#')[0:2])
    green = ''.join(hex.strip('#')[2:4])
    blue = ''.join(hex.strip('#')[4:6])
    return int(red, 16), int(green, 16), int(blue, 16)

color_resolution = 1000
brightness = 0.75  # change to have brighter/darker colors
crx = list(range(1, color_resolution+1))  # the resolution is 1000 colors
cry = [5 for i in range(len(crx))]
crcolor, crRGBs = generate_color_range(color_resolution, brightness)  # produce spectrum

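# --- Illustrative sketch (not part of the original file) ---
# rgb_to_hex() and hex_to_dec() are inverses of each other, e.g.
#   rgb_to_hex((255, 0, 128)) -> '#ff0080'   and   hex_to_dec('#ff0080') -> (255, 0, 128).
# crcolor holds the hex spectrum drawn by the color picker below, while crRGBs holds the
# matching RGB tuples that select_color() maps back onto the selected signals.
def _example_color_round_trip(rgb=(255, 0, 128)):
    """Illustrative only; not called by the dashboard."""
    assert hex_to_dec(rgb_to_hex(rgb)) == rgb
    return rgb_to_hex(rgb)
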
# ---------------- Build Website Layout -------------------

# select file
file_selection_button = bw.Button(label="Select Files", button_type="success", width=120)
file_selection_button.on_click(load_files_group)

files_selector_spacer = bl.Spacer(width=10)

group_selection_button = bw.Button(label="Select Directory", button_type="primary", width=140)
group_selection_button.on_click(load_directory_group)

unload_file_button = bw.Button(label="Unload", button_type="danger", width=50)
unload_file_button.on_click(unload_file)

# files selection box
files_selector = bw.Select(title="Files:", options=[], width=200)
files_selector.on_change('value', change_data_selector)

# data selection box
data_selector = bw.MultiSelect(title="Data:", options=[], size=12)
data_selector.on_change('value', select_data)

# x axis selection box
x_axis_selector_title = bw.Div(text="""X Axis:""")
x_axis_selector = bw.RadioButtonGroup(labels=x_axis_options, active=0)
x_axis_selector.on_click(change_x_axis)

# toggle second axis button
toggle_second_axis_button = bw.Button(label="Toggle Second Axis", button_type="success")
toggle_second_axis_button.on_click(toggle_second_axis)

# averaging slider
averaging_slider = bw.Slider(title="Averaging window", start=1, end=101, step=10)
averaging_slider.on_change('value', update_averaging)

# group properties checkbox
group_cb = bw.CheckboxGroup(labels=["Show statistics bands", "Ungroup signals"], active=[])
group_cb.on_click(toggle_group_property)

# color selector
color_selector_title = bw.Div(text="""Select Color:""")
crsource = bm.ColumnDataSource(data=dict(x=crx, y=cry, crcolor=crcolor, RGBs=crRGBs))
color_selector = bp.figure(x_range=(0, color_resolution), y_range=(0, 10),
                           plot_width=300, plot_height=40,
                           tools='tap')
color_selector.axis.visible = False
color_range = color_selector.rect(x='x', y='y', width=1, height=10,
                                  color='crcolor', source=crsource)
crsource.on_change('selected', select_color)
color_range.nonselection_glyph = color_range.glyph
color_selector.toolbar.logo = None
color_selector.toolbar_location = None

# title
title = bw.Div(text="""<h1>Coach Dashboard</h1>""")

# landing page
landing_page_description = bw.Div(text="""<h3>Start by selecting an experiment file or directory to open:</h3>""")
center = bw.Div(text="""<style>html { text-align: center; } </style>""")
center_buttons = bw.Div(text="""<style>.bk-grid-row .bk-layout-fixed { margin: 0 auto; }</style>""", width=0)
landing_page = bl.column(center,
                         title,
                         landing_page_description,
                         bl.row(center_buttons),
                         bl.row(file_selection_button, sizing_mode='scale_width'),
                         bl.row(group_selection_button, sizing_mode='scale_width'),
                         sizing_mode='scale_width')

# main layout of the document
layout = bl.row(file_selection_button, files_selector_spacer, group_selection_button, width=300)
layout = bl.column(layout, files_selector)
layout = bl.column(layout, bl.row(refresh_info, unload_file_button))
layout = bl.column(layout, data_selector)
layout = bl.column(layout, color_selector_title)
layout = bl.column(layout, color_selector)
layout = bl.column(layout, x_axis_selector_title)
layout = bl.column(layout, x_axis_selector)
layout = bl.column(layout, group_cb)
layout = bl.column(layout, toggle_second_axis_button)
layout = bl.column(layout, averaging_slider)
# layout = bl.column(layout, legend)
layout = bl.row(layout, plot)
layout = bl.column(title, layout)
layout = bl.column(layout, spinner)

doc = bp.curdoc()
doc.add_root(landing_page)

doc.add_periodic_callback(reload_all_files, 20000)
plot.y_range = bm.Range1d(0, 100)
plot.extra_y_ranges['secondary'] = bm.Range1d(0, 100)

# show the load file dialog immediately on start
# doc.add_timeout_callback(load_files, 1000)

if __name__ == "__main__":
    # find an open port and run the server
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    port = 12345
    while True:
        try:
            s.bind(("127.0.0.1", port))
            break
        except socket.error as e:
            if e.errno == 98:
                port += 1
    s.close()
    os.system('bokeh serve --show dashboard.py --port {}'.format(port))
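# --- Illustrative sketch (not part of the original file) ---
# An alternative to scanning upwards from port 12345 would be to let the OS hand out a free
# ephemeral port by binding to port 0. This is a different technique from the loop above and
# is only a sketch; the dashboard does not use it.
def _example_pick_free_port():
    """Illustrative only; returns a free local port chosen by the OS."""
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("127.0.0.1", 0))
    port = s.getsockname()[1]
    s.close()
    return port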
@@ -1,49 +0,0 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import matplotlib.pyplot as plt
import numpy as np


def show_observation_stack(stack, channels_last=False):
    if isinstance(stack, list):  # is a list
        stack_size = len(stack)
    elif len(stack.shape) == 3:
        stack_size = stack.shape[0]  # is a numpy array
    elif len(stack.shape) == 4:
        stack_size = stack.shape[1]  # ignore the batch dimension
        stack = stack[0]
    else:
        assert False, "unsupported observation stack format"

    if channels_last:
        stack = np.transpose(stack, (2, 0, 1))
        stack_size = stack.shape[0]

    for i in range(stack_size):
        plt.subplot(1, stack_size, i + 1)
        plt.imshow(stack[i], cmap='gray')

    plt.show()


def show_diff_between_two_observations(observation1, observation2):
    plt.imshow(observation1 - observation2, cmap='gray')
    plt.show()


def plot_grayscale_observation(observation):
    plt.imshow(observation, cmap='gray')
    plt.show()
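# --- Illustrative sketch (not part of the original file) ---
# show_observation_stack() accepts a list of 2D frames, a (stack, height, width) array,
# or a (batch, stack, height, width) array; with channels_last=True it also handles
# (height, width, stack) layouts. The random data below is purely hypothetical.
def _example_show_random_stack():
    """Illustrative only; opens a matplotlib window with 4 random grayscale frames."""
    stack = np.random.rand(4, 84, 84)  # 4 stacked 84x84 frames, e.g. Atari-style
    show_observation_stack(stack)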
install.sh
@@ -1,205 +0,0 @@
#!/bin/bash -e

prompt () {
    # prints a yes / no question to the user and returns the answer
    # first argument is the prompt question
    # second argument is the default answer - Y / N
    local default_answer

    # set the default value
    case "${2}" in
        y|Y ) default_answer=1; options="[Y/n]";;
        n|N ) default_answer=0; options="[y/N]";;
        "" ) default_answer=; options="[y/n]";;
        * ) echo "invalid default value"; exit;;
    esac

    while true; do
        # read the user choice
        read -p "${1} ${options} " choice

        # return the choice or the default value if an enter was pressed
        case "${choice}" in
            y|Y ) retval=1; return;;
            n|N ) retval=0; return;;
            "" ) if [ ! -z "${default_answer}" ]; then retval=${default_answer}; return; fi;;
        esac
    done
}

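# --- Illustrative sketch (not part of the original script) ---
# prompt() reports its result through the global ${retval} variable rather than an exit
# code, e.g. (hypothetical question, not one asked by this script):
#
#   prompt "Install extra presets?" N
#   INSTALL_EXTRA_PRESETS=${retval}
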
add_to_bashrc () {
    # adds an env variable to the bashrc
    # first argument is the variable name
    # second argument is the variable value

    EXISTS_IN_BASHRC=`awk '/${2}/{print $1}' ~/.bashrc`
    if [ "${EXISTS_IN_BASHRC}" == "" ]; then
        echo "export ${1}=${2}" >> ~/.bashrc
    fi
}

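# --- Illustrative sketch (not part of the original script) ---
# add_to_bashrc() appends an export line to ~/.bashrc only if it is not already present,
# e.g. the call used further down for MKL-DNN:
#
#   add_to_bashrc MKLDNN_ROOT /usr/local/
#   # -> appends: export MKLDNN_ROOT=/usr/local/
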
GET_PREFERENCES_MANUALLY=1

INSTALL_COACH=0
INSTALL_DASHBOARD=0
INSTALL_GYM=0
INSTALL_NEON=0
INSTALL_VIRTUAL_ENVIRONMENT=1

# Get user preferences
TEMP=`getopt -o cpgvrmeNndh \
             --long coach,dashboard,gym,no_virtual_environment,neon,debug,help \
             -- "$@"`
eval set -- "$TEMP"
while true; do
    #for i in "$@"
    case ${1} in
        -c|--coach)
            INSTALL_COACH=1
            GET_PREFERENCES_MANUALLY=0
            shift;;
        -p|--dashboard)
            INSTALL_DASHBOARD=1
            GET_PREFERENCES_MANUALLY=0
            shift;;
        -g|--gym)
            INSTALL_GYM=1
            GET_PREFERENCES_MANUALLY=0
            shift;;
        -N|--no_virtual_environment)
            INSTALL_VIRTUAL_ENVIRONMENT=0
            GET_PREFERENCES_MANUALLY=0
            shift;;
        -ne|--neon)
            INSTALL_NEON=1
            GET_PREFERENCES_MANUALLY=0
            shift;;
        -d|--debug) set -x; shift;;
        -h|--help)
            echo "Available command line arguments:"
            echo ""
            echo " -c | --coach - Install Coach requirements"
            echo " -p | --dashboard - Install Dashboard requirements"
            echo " -g | --gym - Install Gym support"
            echo " -N | --no_virtual_environment - Do not install inside of a virtual environment"
            echo " -d | --debug - Run in debug mode"
            echo " -h | --help - Display this help message"
            echo ""
            exit;;
        --) shift; break;;
        *) break;;  # unknown option
    esac
done

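# --- Illustrative sketch (not part of the original script) ---
# Passing any of the flags above skips the interactive prompts below, e.g.:
#
#   ./install.sh -c -p -g     # install Coach, Dashboard and Gym support non-interactively
#   ./install.sh --help       # list the available flags
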
if [ ${GET_PREFERENCES_MANUALLY} -eq 1 ]; then
    prompt "Install Coach requirements?" Y
    INSTALL_COACH=${retval}

    prompt "Install Dashboard requirements?" Y
    INSTALL_DASHBOARD=${retval}

    prompt "Install Gym support?" Y
    INSTALL_GYM=${retval}

    prompt "Install neon support?" Y
    INSTALL_NEON=${retval}
fi

IN_VIRTUAL_ENV=`python3 -c 'import sys; print("%i" % hasattr(sys, "real_prefix"))'`

# basic installations
sudo -E apt-get install python3-pip cmake zlib1g-dev python3-tk python-opencv -y
pip3 install --upgrade pip

# if we are not in a virtual environment, we will create one with the appropriate python version and then activate it
# if we are already in a virtual environment, we will keep using the active one

if [ ${INSTALL_VIRTUAL_ENVIRONMENT} -eq 1 ]; then
    if [ ${IN_VIRTUAL_ENV} -eq 0 ]; then
        sudo -E pip3 install virtualenv
        virtualenv -p python3 coach_env
        . coach_env/bin/activate
    fi
fi

#------------------------------------------------
# From now on we are in a virtual environment
#------------------------------------------------

# get python local and global paths
python_version=python$(python -c "import sys; print (str(sys.version_info[0])+'.'+str(sys.version_info[1]))")
var=( $(which -a $python_version) )
get_python_lib_cmd="from distutils.sysconfig import get_python_lib; print (get_python_lib())"
lib_virtualenv_path=$(python -c "$get_python_lib_cmd")
lib_system_path=$(${var[-1]} -c "$get_python_lib_cmd")

# Boost libraries
sudo -E apt-get install libboost-all-dev -y

# Coach
if [ ${INSTALL_COACH} -eq 1 ]; then
    echo "Installing Coach requirements"
    pip3 install -r ./requirements_coach.txt
fi

# Dashboard
if [ ${INSTALL_DASHBOARD} -eq 1 ]; then
    echo "Installing Dashboard requirements"
    pip3 install -r ./requirements_dashboard.txt
    sudo -E apt-get install dpkg-dev build-essential python3.5-dev libjpeg-dev libtiff-dev libsdl1.2-dev libnotify-dev \
        freeglut3 freeglut3-dev libsm-dev libgtk2.0-dev libgtk-3-dev libwebkitgtk-dev libgtk-3-dev libwebkitgtk-3.0-dev libgstreamer-plugins-base1.0-dev -y

    sudo -E -H pip3 install -U --pre -f \
        https://wxpython.org/Phoenix/snapshot-builds/linux/gtk3/ubuntu-16.04/wxPython-4.0.0a3.dev3059+4a5c5d9-cp35-cp35m-linux_x86_64.whl wxPython

    # link the wxPython Phoenix library into the virtualenv, since it is installed system-wide and is not otherwise accessible from the virtualenv
    libs=( wx )
    for lib in ${libs[@]}
    do
        ln -sf $lib_system_path/$lib $lib_virtualenv_path/$lib
    done
fi

# Gym
if [ ${INSTALL_GYM} -eq 1 ]; then
    echo "Installing Gym support"
    sudo -E apt-get install libav-tools libsdl2-dev swig cmake -y
    pip3 install box2d  # for bipedal walker etc.
    pip3 install gym
fi

# NGraph and Neon
if [ ${INSTALL_NEON} -eq 1 ]; then
    echo "Installing neon requirements"

    # MKL
    git clone https://github.com/01org/mkl-dnn.git
    cd mkl-dnn
    cd scripts && ./prepare_mkl.sh && cd ..
    mkdir -p build && cd build && cmake .. && make -j
    sudo make install -j
    cd ../..
    export MKLDNN_ROOT=/usr/local/
    add_to_bashrc MKLDNN_ROOT ${MKLDNN_ROOT}
    export LD_LIBRARY_PATH=$MKLDNN_ROOT/lib:$LD_LIBRARY_PATH
    add_to_bashrc LD_LIBRARY_PATH ${MKLDNN_ROOT}/lib:$LD_LIBRARY_PATH

    # NGraph
    git clone https://github.com/NervanaSystems/ngraph.git
    cd ngraph
    make install -j
    cd ..

    # Neon
    sudo -E apt-get install libhdf5-dev libyaml-dev pkg-config clang virtualenv libcurl4-openssl-dev libopencv-dev libsox-dev -y
    pip3 install nervananeon
fi

if ! [ -x "$(command -v nvidia-smi)" ]; then
    # Intel Optimized TensorFlow
    #pip3 install https://anaconda.org/intel/tensorflow/1.3.0/download/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl
    pip3 install https://anaconda.org/intel/tensorflow/1.4.0/download/tensorflow-1.4.0-cp35-cp35m-linux_x86_64.whl
else
    # GPU supported TensorFlow
    pip3 install tensorflow-gpu==1.4.1
fi
@@ -7,3 +7,5 @@ pygame==1.9.3
PyOpenGL==3.1.0
scipy==0.19.0
scikit-image==0.13.0
gym
tensorflow
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import atexit
import json
import os
@@ -22,13 +23,12 @@ import subprocess
import sys
import time

import agents
import argparse
import configurations as conf
import environments
import logger
import presets
import utils
from coach import agents  # noqa
from coach import configurations as conf
from coach import environments
from coach import logger
from coach import presets
from coach import utils


if len(set(logger.failed_imports)) > 0:
@@ -20,10 +20,10 @@ import time

import tensorflow as tf

import agents
import environments
import logger
import presets
from coach import agents
from coach import environments
from coach import logger
from coach import presets

start_time = time.time()

setup.py
Executable file
@@ -0,0 +1,28 @@
#!/usr/bin/env python3
"""
Setup for the coach project
"""
from setuptools import setup, find_packages


setup(name="coach",
      version="0.7",
      description="Reinforcement Learning Coach",
      author="Caspi, Itai and Leibovich, Gal and Novik, Gal",
      author_email="gal.novik@intel.com",
      url="https://github.com/NervanaSystems/coach",
      packages=find_packages(),
      download_url="https://github.com/NervanaSystems/coach",
      keywords=["reinforcement", "machine", "learning"],
      install_requires=["annoy", "Pillow", "matplotlib", "numpy", "pandas",
                        "pygame", "PyOpenGL", "scipy", "scikit-image",
                        "tensorflow", "gym"],
      scripts=["scripts/coach"],
      classifiers=["Programming Language :: Python :: 3",
                   "Development Status :: 4 - Beta",
                   "Environment :: Console",
                   "Intended Audience :: End Users/Desktop",
                   "License :: OSI Approved :: Apache Software License",
                   "Operating System :: OS Independent",
                   "Topic :: Scientific/Engineering :: Artificial "
                   "Intelligence"])
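# --- Illustrative sketch (not part of the packaged file) ---
# With this setup.py at the repository root, the package can be installed for development
# in editable mode, which also places the 'coach' script from scripts/coach on the PATH.
# The exact commands below are a sketch under those assumptions:
#
#   pip3 install -e .
#   coach --help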