mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Moved coach to its top level module.
This commit is contained in:
@@ -13,28 +13,28 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from agents.actor_critic_agent import ActorCriticAgent
|
from coach.agents.actor_critic_agent import ActorCriticAgent
|
||||||
from agents.agent import Agent
|
from coach.agents.agent import Agent
|
||||||
from agents.bc_agent import BCAgent
|
from coach.agents.bc_agent import BCAgent
|
||||||
from agents.bootstrapped_dqn_agent import BootstrappedDQNAgent
|
from coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgent
|
||||||
from agents.categorical_dqn_agent import CategoricalDQNAgent
|
from coach.agents.categorical_dqn_agent import CategoricalDQNAgent
|
||||||
from agents.clipped_ppo_agent import ClippedPPOAgent
|
from coach.agents.clipped_ppo_agent import ClippedPPOAgent
|
||||||
from agents.ddpg_agent import DDPGAgent
|
from coach.agents.ddpg_agent import DDPGAgent
|
||||||
from agents.ddqn_agent import DDQNAgent
|
from coach.agents.ddqn_agent import DDQNAgent
|
||||||
from agents.dfp_agent import DFPAgent
|
from coach.agents.dfp_agent import DFPAgent
|
||||||
from agents.dqn_agent import DQNAgent
|
from coach.agents.dqn_agent import DQNAgent
|
||||||
from agents.human_agent import HumanAgent
|
from coach.agents.human_agent import HumanAgent
|
||||||
from agents.imitation_agent import ImitationAgent
|
from coach.agents.imitation_agent import ImitationAgent
|
||||||
from agents.mmc_agent import MixedMonteCarloAgent
|
from coach.agents.mmc_agent import MixedMonteCarloAgent
|
||||||
from agents.n_step_q_agent import NStepQAgent
|
from coach.agents.n_step_q_agent import NStepQAgent
|
||||||
from agents.naf_agent import NAFAgent
|
from coach.agents.naf_agent import NAFAgent
|
||||||
from agents.nec_agent import NECAgent
|
from coach.agents.nec_agent import NECAgent
|
||||||
from agents.pal_agent import PALAgent
|
from coach.agents.pal_agent import PALAgent
|
||||||
from agents.policy_gradients_agent import PolicyGradientsAgent
|
from coach.agents.policy_gradients_agent import PolicyGradientsAgent
|
||||||
from agents.policy_optimization_agent import PolicyOptimizationAgent
|
from coach.agents.policy_optimization_agent import PolicyOptimizationAgent
|
||||||
from agents.ppo_agent import PPOAgent
|
from coach.agents.ppo_agent import PPOAgent
|
||||||
from agents.qr_dqn_agent import QuantileRegressionDQNAgent
|
from coach.agents.qr_dqn_agent import QuantileRegressionDQNAgent
|
||||||
from agents.value_optimization_agent import ValueOptimizationAgent
|
from coach.agents.value_optimization_agent import ValueOptimizationAgent
|
||||||
|
|
||||||
__all__ = [ActorCriticAgent,
|
__all__ = [ActorCriticAgent,
|
||||||
Agent,
|
Agent,
|
||||||
@@ -16,9 +16,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
|
|
||||||
from agents import policy_optimization_agent as poa
|
from coach.agents import policy_optimization_agent as poa
|
||||||
import utils
|
from coach import utils
|
||||||
import logger
|
from coach import logger
|
||||||
|
|
||||||
|
|
||||||
# Actor Critic - https://arxiv.org/abs/1602.01783
|
# Actor Critic - https://arxiv.org/abs/1602.01783
|
||||||
@@ -18,7 +18,7 @@ import copy
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import logger
|
from coach import logger
|
||||||
try:
|
try:
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -29,13 +29,12 @@ from pandas.io import pickle
|
|||||||
from six.moves import range
|
from six.moves import range
|
||||||
import scipy
|
import scipy
|
||||||
|
|
||||||
from architectures.tensorflow_components import shared_variables as sv
|
from coach.architectures.tensorflow_components import shared_variables as sv
|
||||||
import configurations
|
from coach import configurations
|
||||||
import exploration_policies as ep # noqa, used in eval()
|
from coach import exploration_policies as ep # noqa, used in eval()
|
||||||
import memories # noqa, used in eval()
|
from coach import memories # noqa, used in eval()
|
||||||
from memories import memory
|
from coach.memories import memory
|
||||||
import renderer
|
from coach import utils
|
||||||
import utils
|
|
||||||
|
|
||||||
|
|
||||||
class Agent(object):
|
class Agent(object):
|
||||||
@@ -100,7 +99,6 @@ class Agent(object):
|
|||||||
self.main_network = None
|
self.main_network = None
|
||||||
self.networks = []
|
self.networks = []
|
||||||
self.last_episode_images = []
|
self.last_episode_images = []
|
||||||
self.renderer = renderer.Renderer()
|
|
||||||
|
|
||||||
# signals
|
# signals
|
||||||
self.signals = []
|
self.signals = []
|
||||||
@@ -234,13 +232,6 @@ class Agent(object):
|
|||||||
r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
|
r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
|
||||||
observation = 0.2989 * r + 0.5870 * g + 0.1140 * b
|
observation = 0.2989 * r + 0.5870 * g + 0.1140 * b
|
||||||
|
|
||||||
# Render the processed observation which is how the agent will see it
|
|
||||||
# Warning: this cannot currently be done in parallel to rendering the environment
|
|
||||||
if self.tp.visualization.render_observation:
|
|
||||||
if not self.renderer.is_open:
|
|
||||||
self.renderer.create_screen(observation.shape[0], observation.shape[1])
|
|
||||||
self.renderer.render_image(observation)
|
|
||||||
|
|
||||||
return observation.astype('uint8')
|
return observation.astype('uint8')
|
||||||
else:
|
else:
|
||||||
if self.tp.env.normalize_observation:
|
if self.tp.env.normalize_observation:
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import imitation_agent
|
from coach.agents import imitation_agent
|
||||||
|
|
||||||
|
|
||||||
# Behavioral Cloning Agent
|
# Behavioral Cloning Agent
|
||||||
@@ -15,8 +15,9 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# Bootstrapped DQN - https://arxiv.org/pdf/1602.04621.pdf
|
# Bootstrapped DQN - https://arxiv.org/pdf/1602.04621.pdf
|
||||||
class BootstrappedDQNAgent(voa.ValueOptimizationAgent):
|
class BootstrappedDQNAgent(voa.ValueOptimizationAgent):
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
|
|
||||||
|
|
||||||
# Categorical Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
|
# Categorical Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
|
||||||
@@ -19,10 +19,10 @@ from random import shuffle
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import actor_critic_agent as aca
|
from coach.agents import actor_critic_agent as aca
|
||||||
from agents import policy_optimization_agent as poa
|
from coach.agents import policy_optimization_agent as poa
|
||||||
import logger
|
from coach import logger
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# Clipped Proximal Policy Optimization - https://arxiv.org/abs/1707.06347
|
# Clipped Proximal Policy Optimization - https://arxiv.org/abs/1707.06347
|
||||||
@@ -17,11 +17,11 @@ import copy
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import actor_critic_agent as aca
|
from coach.agents import actor_critic_agent as aca
|
||||||
from agents import agent
|
from coach.agents import agent
|
||||||
from architectures import network_wrapper as nw
|
from coach.architectures import network_wrapper as nw
|
||||||
import configurations as conf
|
from coach import configurations as conf
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# Deep Deterministic Policy Gradients Network - https://arxiv.org/pdf/1509.02971.pdf
|
# Deep Deterministic Policy Gradients Network - https://arxiv.org/pdf/1509.02971.pdf
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
|
|
||||||
|
|
||||||
# Double DQN - https://arxiv.org/abs/1509.06461
|
# Double DQN - https://arxiv.org/abs/1509.06461
|
||||||
@@ -15,9 +15,9 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import agent
|
from coach.agents import agent
|
||||||
from architectures import network_wrapper as nw
|
from coach.architectures import network_wrapper as nw
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# Direct Future Prediction Agent - http://vladlen.info/papers/learning-to-act.pdf
|
# Direct Future Prediction Agent - http://vladlen.info/papers/learning-to-act.pdf
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
|
|
||||||
|
|
||||||
# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
|
# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
|
|
||||||
|
|
||||||
# Deep Q Network - https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
|
# Deep Q Network - https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
|
||||||
@@ -19,9 +19,9 @@ import os
|
|||||||
import pygame
|
import pygame
|
||||||
from pandas.io import pickle
|
from pandas.io import pickle
|
||||||
|
|
||||||
from agents import agent
|
from coach.agents import agent
|
||||||
import logger
|
from coach import logger
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class HumanAgent(agent.Agent):
|
class HumanAgent(agent.Agent):
|
||||||
@@ -15,10 +15,10 @@
|
|||||||
#
|
#
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
from agents import agent
|
from coach.agents import agent
|
||||||
from architectures import network_wrapper as nw
|
from coach.architectures import network_wrapper as nw
|
||||||
import utils
|
from coach import utils
|
||||||
import logging
|
from coach import logger
|
||||||
|
|
||||||
|
|
||||||
# Imitation Agent
|
# Imitation Agent
|
||||||
@@ -55,7 +55,7 @@ class ImitationAgent(agent.Agent):
|
|||||||
# log to screen
|
# log to screen
|
||||||
if phase == utils.RunPhase.TRAIN:
|
if phase == utils.RunPhase.TRAIN:
|
||||||
# for the training phase - we log during the episode to visualize the progress in training
|
# for the training phase - we log during the episode to visualize the progress in training
|
||||||
logging.screen.log_dict(
|
logger.screen.log_dict(
|
||||||
collections.OrderedDict([
|
collections.OrderedDict([
|
||||||
("Worker", self.task_id),
|
("Worker", self.task_id),
|
||||||
("Episode", self.current_episode),
|
("Episode", self.current_episode),
|
||||||
@@ -65,5 +65,5 @@ class ImitationAgent(agent.Agent):
|
|||||||
prefix="Training"
|
prefix="Training"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# for the evaluation phase - logging as in regular RL
|
# for the evaluation phase - logger as in regular RL
|
||||||
agent.Agent.log_to_screen(self, phase)
|
agent.Agent.log_to_screen(self, phase)
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
|
|
||||||
|
|
||||||
class MixedMonteCarloAgent(voa.ValueOptimizationAgent):
|
class MixedMonteCarloAgent(voa.ValueOptimizationAgent):
|
||||||
@@ -15,10 +15,10 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
from agents import policy_optimization_agent as poa
|
from coach.agents import policy_optimization_agent as poa
|
||||||
import logger
|
from coach import logger
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# N Step Q Learning Agent - https://arxiv.org/abs/1602.01783
|
# N Step Q Learning Agent - https://arxiv.org/abs/1602.01783
|
||||||
@@ -15,8 +15,8 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents.value_optimization_agent import ValueOptimizationAgent
|
from coach.agents.value_optimization_agent import ValueOptimizationAgent
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# Normalized Advantage Functions - https://arxiv.org/pdf/1603.00748.pdf
|
# Normalized Advantage Functions - https://arxiv.org/pdf/1603.00748.pdf
|
||||||
@@ -13,9 +13,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
from logger import screen
|
from coach.logger import screen
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# Neural Episodic Control - https://arxiv.org/pdf/1703.01988.pdf
|
# Neural Episodic Control - https://arxiv.org/pdf/1703.01988.pdf
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
|
|
||||||
|
|
||||||
# Persistent Advantage Learning - https://arxiv.org/pdf/1512.04860.pdf
|
# Persistent Advantage Learning - https://arxiv.org/pdf/1512.04860.pdf
|
||||||
@@ -15,9 +15,9 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import policy_optimization_agent as poa
|
from coach.agents import policy_optimization_agent as poa
|
||||||
import logger
|
from coach import logger
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class PolicyGradientsAgent(poa.PolicyOptimizationAgent):
|
class PolicyGradientsAgent(poa.PolicyOptimizationAgent):
|
||||||
@@ -17,10 +17,10 @@ import collections
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import agent
|
from coach.agents import agent
|
||||||
from architectures import network_wrapper as nw
|
from coach.architectures import network_wrapper as nw
|
||||||
import logger
|
from coach import logger
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class PolicyGradientRescaler(utils.Enum):
|
class PolicyGradientRescaler(utils.Enum):
|
||||||
@@ -18,12 +18,12 @@ import copy
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import actor_critic_agent as aca
|
from coach.agents import actor_critic_agent as aca
|
||||||
from agents import policy_optimization_agent as poa
|
from coach.agents import policy_optimization_agent as poa
|
||||||
from architectures import network_wrapper as nw
|
from coach.architectures import network_wrapper as nw
|
||||||
import configurations
|
from coach import configurations
|
||||||
import logger
|
from coach import logger
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# Proximal Policy Optimization - https://arxiv.org/pdf/1707.06347.pdf
|
# Proximal Policy Optimization - https://arxiv.org/pdf/1707.06347.pdf
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import value_optimization_agent as voa
|
from coach.agents import value_optimization_agent as voa
|
||||||
|
|
||||||
|
|
||||||
# Quantile Regression Deep Q Network - https://arxiv.org/pdf/1710.10044v1.pdf
|
# Quantile Regression Deep Q Network - https://arxiv.org/pdf/1710.10044v1.pdf
|
||||||
@@ -15,9 +15,9 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from agents import agent
|
from coach.agents import agent
|
||||||
from architectures import network_wrapper as nw
|
from coach.architectures import network_wrapper as nw
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class ValueOptimizationAgent(agent.Agent):
|
class ValueOptimizationAgent(agent.Agent):
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import logger
|
from coach import logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from architectures.tensorflow_components import general_network as ts_gn
|
from architectures.tensorflow_components import general_network as ts_gn
|
||||||
@@ -16,8 +16,8 @@
|
|||||||
import ngraph as ng
|
import ngraph as ng
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from architectures import architecture
|
from coach.architectures import architecture
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class NeonArchitecture(architecture.Architecture):
|
class NeonArchitecture(architecture.Architecture):
|
||||||
@@ -17,11 +17,11 @@ import ngraph as ng
|
|||||||
from ngraph.frontends import neon
|
from ngraph.frontends import neon
|
||||||
from ngraph.util import names as ngraph_names
|
from ngraph.util import names as ngraph_names
|
||||||
|
|
||||||
from architectures.neon_components import architecture
|
from coach.architectures.neon_components import architecture
|
||||||
from architectures.neon_components import embedders
|
from coach.architectures.neon_components import embedders
|
||||||
from architectures.neon_components import middleware
|
from coach.architectures.neon_components import middleware
|
||||||
from architectures.neon_components import heads
|
from coach.architectures.neon_components import heads
|
||||||
import configurations as conf
|
from coach import configurations as conf
|
||||||
|
|
||||||
|
|
||||||
class GeneralNeonNetwork(architecture.NeonArchitecture):
|
class GeneralNeonNetwork(architecture.NeonArchitecture):
|
||||||
@@ -17,8 +17,8 @@ import ngraph as ng
|
|||||||
from ngraph.frontends import neon
|
from ngraph.frontends import neon
|
||||||
from ngraph.util import names as ngraph_names
|
from ngraph.util import names as ngraph_names
|
||||||
|
|
||||||
import utils
|
from coach import utils
|
||||||
from architectures.neon_components import losses
|
from coach.architectures.neon_components import losses
|
||||||
|
|
||||||
|
|
||||||
class Head(object):
|
class Head(object):
|
||||||
@@ -16,16 +16,16 @@
|
|||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
import configurations as conf
|
from coach import configurations as conf
|
||||||
import logger
|
from coach import logger
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from architectures.tensorflow_components import general_network as tf_net #import GeneralTensorFlowNetwork
|
from coach.architectures.tensorflow_components import general_network as tf_net
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.failed_imports.append("TensorFlow")
|
logger.failed_imports.append("TensorFlow")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from architectures.neon_components import general_network as neon_net
|
from coach.architectures.neon_components import general_network as neon_net
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.failed_imports.append("Neon")
|
logger.failed_imports.append("Neon")
|
||||||
|
|
||||||
@@ -17,9 +17,10 @@ import time
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from architectures import architecture
|
from coach.architectures import architecture
|
||||||
import configurations as conf
|
from coach import configurations as conf
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
def variable_summaries(var):
|
def variable_summaries(var):
|
||||||
"""Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
|
"""Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
|
||||||
@@ -36,6 +37,7 @@ def variable_summaries(var):
|
|||||||
tf.summary.scalar('min', tf.reduce_min(var))
|
tf.summary.scalar('min', tf.reduce_min(var))
|
||||||
tf.summary.histogram('histogram', var)
|
tf.summary.histogram('histogram', var)
|
||||||
|
|
||||||
|
|
||||||
class TensorFlowArchitecture(architecture.Architecture):
|
class TensorFlowArchitecture(architecture.Architecture):
|
||||||
def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
|
def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
|
||||||
"""
|
"""
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from configurations import EmbedderComplexity
|
from coach.configurations import EmbedderComplexity
|
||||||
|
|
||||||
|
|
||||||
class InputEmbedder(object):
|
class InputEmbedder(object):
|
||||||
@@ -15,11 +15,11 @@
|
|||||||
#
|
#
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from architectures.tensorflow_components import architecture
|
from coach.architectures.tensorflow_components import architecture
|
||||||
from architectures.tensorflow_components import embedders
|
from coach.architectures.tensorflow_components import embedders
|
||||||
from architectures.tensorflow_components import middleware
|
from coach.architectures.tensorflow_components import middleware
|
||||||
from architectures.tensorflow_components import heads
|
from coach.architectures.tensorflow_components import heads
|
||||||
import configurations as conf
|
from coach import configurations as conf
|
||||||
|
|
||||||
|
|
||||||
class GeneralTensorFlowNetwork(architecture.TensorFlowArchitecture):
|
class GeneralTensorFlowNetwork(architecture.TensorFlowArchitecture):
|
||||||
@@ -16,7 +16,7 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# Used to initialize weights for policy and value output layers
|
# Used to initialize weights for policy and value output layers
|
||||||
@@ -14,9 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import types
|
import types
|
||||||
import utils
|
|
||||||
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class Frameworks(utils.Enum):
|
class Frameworks(utils.Enum):
|
||||||
@@ -13,10 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from environments.gym_environment_wrapper import GymEnvironmentWrapper
|
from coach.environments.gym_environment_wrapper import GymEnvironmentWrapper
|
||||||
from environments.doom_environment_wrapper import DoomEnvironmentWrapper
|
from coach.environments.doom_environment_wrapper import DoomEnvironmentWrapper
|
||||||
from environments.carla_environment_wrapper import CarlaEnvironmentWrapper
|
from coach.environments.carla_environment_wrapper import CarlaEnvironmentWrapper
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class EnvTypes(utils.Enum):
|
class EnvTypes(utils.Enum):
|
||||||
@@ -6,7 +6,7 @@ import sys
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import logger
|
from coach import logger
|
||||||
try:
|
try:
|
||||||
if 'CARLA_ROOT' in os.environ:
|
if 'CARLA_ROOT' in os.environ:
|
||||||
sys.path.append(os.path.join(os.environ.get('CARLA_ROOT'),
|
sys.path.append(os.path.join(os.environ.get('CARLA_ROOT'),
|
||||||
@@ -16,8 +16,8 @@ try:
|
|||||||
from carla import sensor as carla_sensor
|
from carla import sensor as carla_sensor
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.failed_imports.append("CARLA")
|
logger.failed_imports.append("CARLA")
|
||||||
from environments import environment_wrapper as ew
|
from coach.environments import environment_wrapper as ew
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# enum of the available levels and their path
|
# enum of the available levels and their path
|
||||||
@@ -13,19 +13,18 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import enum
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import logger
|
from coach import logger
|
||||||
try:
|
try:
|
||||||
import vizdoom
|
import vizdoom
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.failed_imports.append("ViZDoom")
|
logger.failed_imports.append("ViZDoom")
|
||||||
|
|
||||||
from environments import environment_wrapper as ew
|
from coach.environments import environment_wrapper as ew
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
# enum of the available levels and their path
|
# enum of the available levels and their path
|
||||||
@@ -18,8 +18,7 @@ import time
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import renderer
|
from coach import utils
|
||||||
import utils
|
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentWrapper(object):
|
class EnvironmentWrapper(object):
|
||||||
@@ -62,7 +61,6 @@ class EnvironmentWrapper(object):
|
|||||||
self.wait_for_explicit_human_action = False
|
self.wait_for_explicit_human_action = False
|
||||||
self.is_rendered = self.is_rendered or self.human_control
|
self.is_rendered = self.is_rendered or self.human_control
|
||||||
self.game_is_open = True
|
self.game_is_open = True
|
||||||
self.renderer = renderer.Renderer()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def measurements(self):
|
def measurements(self):
|
||||||
@@ -106,26 +104,6 @@ class EnvironmentWrapper(object):
|
|||||||
Get an action from the user keyboard
|
Get an action from the user keyboard
|
||||||
:return: action index
|
:return: action index
|
||||||
"""
|
"""
|
||||||
if self.wait_for_explicit_human_action:
|
|
||||||
while len(self.renderer.pressed_keys) == 0:
|
|
||||||
self.renderer.get_events()
|
|
||||||
|
|
||||||
if self.key_to_action == {}:
|
|
||||||
# the keys are the numbers on the keyboard corresponding to the action index
|
|
||||||
if len(self.renderer.pressed_keys) > 0:
|
|
||||||
action_idx = self.renderer.pressed_keys[0] - ord("1")
|
|
||||||
if 0 <= action_idx < self.action_space_size:
|
|
||||||
return action_idx
|
|
||||||
else:
|
|
||||||
# the keys are mapped through the environment to more intuitive keyboard keys
|
|
||||||
# key = tuple(self.renderer.pressed_keys)
|
|
||||||
# for key in self.renderer.pressed_keys:
|
|
||||||
for env_keys in self.key_to_action.keys():
|
|
||||||
if set(env_keys) == set(self.renderer.pressed_keys):
|
|
||||||
return self.key_to_action[env_keys]
|
|
||||||
|
|
||||||
# return the default action 0 so that the environment will continue running
|
|
||||||
return self.default_action
|
|
||||||
|
|
||||||
def step(self, action_idx):
|
def step(self, action_idx):
|
||||||
"""
|
"""
|
||||||
@@ -18,8 +18,8 @@ import random
|
|||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments import environment_wrapper as ew
|
from coach.environments import environment_wrapper as ew
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class GymEnvironmentWrapper(ew.EnvironmentWrapper):
|
class GymEnvironmentWrapper(ew.EnvironmentWrapper):
|
||||||
@@ -13,18 +13,18 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from exploration_policies.additive_noise import AdditiveNoise
|
from coach.exploration_policies.additive_noise import AdditiveNoise
|
||||||
from exploration_policies.approximated_thompson_sampling_using_dropout import ApproximatedThompsonSamplingUsingDropout
|
from coach.exploration_policies.approximated_thompson_sampling_using_dropout import ApproximatedThompsonSamplingUsingDropout
|
||||||
from exploration_policies.bayesian import Bayesian
|
from coach.exploration_policies.bayesian import Bayesian
|
||||||
from exploration_policies.boltzmann import Boltzmann
|
from coach.exploration_policies.boltzmann import Boltzmann
|
||||||
from exploration_policies.bootstrapped import Bootstrapped
|
from coach.exploration_policies.bootstrapped import Bootstrapped
|
||||||
from exploration_policies.categorical import Categorical
|
from coach.exploration_policies.categorical import Categorical
|
||||||
from exploration_policies.continuous_entropy import ContinuousEntropy
|
from coach.exploration_policies.continuous_entropy import ContinuousEntropy
|
||||||
from exploration_policies.e_greedy import EGreedy
|
from coach.exploration_policies.e_greedy import EGreedy
|
||||||
from exploration_policies.exploration_policy import ExplorationPolicy
|
from coach.exploration_policies.exploration_policy import ExplorationPolicy
|
||||||
from exploration_policies.greedy import Greedy
|
from coach.exploration_policies.greedy import Greedy
|
||||||
from exploration_policies.ou_process import OUProcess
|
from coach.exploration_policies.ou_process import OUProcess
|
||||||
from exploration_policies.thompson_sampling import ThompsonSampling
|
from coach.exploration_policies.thompson_sampling import ThompsonSampling
|
||||||
|
|
||||||
|
|
||||||
__all__ = [AdditiveNoise,
|
__all__ = [AdditiveNoise,
|
||||||
@@ -15,8 +15,8 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class AdditiveNoise(exploration_policy.ExplorationPolicy):
|
class AdditiveNoise(exploration_policy.ExplorationPolicy):
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
|
|
||||||
|
|
||||||
class ApproximatedThompsonSamplingUsingDropout(exploration_policy.ExplorationPolicy):
|
class ApproximatedThompsonSamplingUsingDropout(exploration_policy.ExplorationPolicy):
|
||||||
@@ -15,8 +15,8 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class Bayesian(exploration_policy.ExplorationPolicy):
|
class Bayesian(exploration_policy.ExplorationPolicy):
|
||||||
@@ -15,8 +15,9 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class Boltzmann(exploration_policy.ExplorationPolicy):
|
class Boltzmann(exploration_policy.ExplorationPolicy):
|
||||||
def __init__(self, tuning_parameters):
|
def __init__(self, tuning_parameters):
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import e_greedy
|
from coach.exploration_policies import e_greedy
|
||||||
|
|
||||||
|
|
||||||
class Bootstrapped(e_greedy.EGreedy):
|
class Bootstrapped(e_greedy.EGreedy):
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
|
|
||||||
|
|
||||||
class Categorical(exploration_policy.ExplorationPolicy):
|
class Categorical(exploration_policy.ExplorationPolicy):
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
|
|
||||||
|
|
||||||
class ContinuousEntropy(exploration_policy.ExplorationPolicy):
|
class ContinuousEntropy(exploration_policy.ExplorationPolicy):
|
||||||
@@ -15,8 +15,8 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class EGreedy(exploration_policy.ExplorationPolicy):
|
class EGreedy(exploration_policy.ExplorationPolicy):
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import utils
|
from coach import utils
|
||||||
|
|
||||||
|
|
||||||
class ExplorationPolicy(object):
|
class ExplorationPolicy(object):
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
|
|
||||||
|
|
||||||
class Greedy(exploration_policy.ExplorationPolicy):
|
class Greedy(exploration_policy.ExplorationPolicy):
|
||||||
@@ -15,11 +15,12 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
|
|
||||||
# Based on the description in:
|
# Based on the description in:
|
||||||
# https://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
|
# https://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
|
||||||
|
|
||||||
|
|
||||||
# Ornstein-Uhlenbeck process
|
# Ornstein-Uhlenbeck process
|
||||||
class OUProcess(exploration_policy.ExplorationPolicy):
|
class OUProcess(exploration_policy.ExplorationPolicy):
|
||||||
def __init__(self, tuning_parameters):
|
def __init__(self, tuning_parameters):
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from exploration_policies import exploration_policy
|
from coach.exploration_policies import exploration_policy
|
||||||
|
|
||||||
|
|
||||||
class ThompsonSampling(exploration_policy.ExplorationPolicy):
|
class ThompsonSampling(exploration_policy.ExplorationPolicy):
|
||||||
@@ -13,13 +13,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from memories.differentiable_neural_dictionary import AnnoyDictionary
|
from coach.memories.differentiable_neural_dictionary import AnnoyDictionary
|
||||||
from memories.differentiable_neural_dictionary import AnnoyIndex
|
from coach.memories.differentiable_neural_dictionary import AnnoyIndex
|
||||||
from memories.differentiable_neural_dictionary import QDND
|
from coach.memories.differentiable_neural_dictionary import QDND
|
||||||
from memories.episodic_experience_replay import EpisodicExperienceReplay
|
from coach.memories.episodic_experience_replay import EpisodicExperienceReplay
|
||||||
from memories.memory import Episode
|
from coach.memories.memory import Episode
|
||||||
from memories.memory import Memory
|
from coach.memories.memory import Memory
|
||||||
from memories.memory import Transition
|
from coach.memories.memory import Transition
|
||||||
|
|
||||||
__all__ = [AnnoyDictionary,
|
__all__ = [AnnoyDictionary,
|
||||||
AnnoyIndex,
|
AnnoyIndex,
|
||||||
@@ -17,7 +17,7 @@ import typing
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from memories import memory
|
from coach.memories import memory
|
||||||
|
|
||||||
|
|
||||||
class EpisodicExperienceReplay(memory.Memory):
|
class EpisodicExperienceReplay(memory.Memory):
|
||||||
@@ -17,11 +17,11 @@ import ast
|
|||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import agents
|
from coach import agents
|
||||||
import configurations as conf
|
from coach import configurations as conf
|
||||||
import environments as env
|
from coach import environments as env
|
||||||
import exploration_policies as ep
|
from coach import exploration_policies as ep
|
||||||
import presets
|
from coach import presets
|
||||||
|
|
||||||
|
|
||||||
def json_to_preset(json_path):
|
def json_to_preset(json_path):
|
||||||
972
dashboard.py
972
dashboard.py
@@ -1,972 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2017 Intel Corporation
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
|
|
||||||
"""
|
|
||||||
To run Coach Dashboard, run the following command:
|
|
||||||
python3 dashboard.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import colorsys
|
|
||||||
import datetime
|
|
||||||
import enum
|
|
||||||
import itertools
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
|
|
||||||
from bokeh import palettes
|
|
||||||
from bokeh import layouts as bl
|
|
||||||
from bokeh import models as bm
|
|
||||||
from bokeh.models import widgets as bw
|
|
||||||
from bokeh import plotting as bp
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from pandas.io import pandas_common
|
|
||||||
import wx
|
|
||||||
|
|
||||||
import utils
|
|
||||||
|
|
||||||
|
|
||||||
class DialogApp(wx.App):
    """wx application object used to show native file / directory pickers."""

    def getFileDialog(self):
        """Let the user pick one or more CSV files.

        :return: the list of chosen paths, or None when the dialog is cancelled
        """
        with wx.FileDialog(None, "Open CSV file", wildcard="CSV files (*.csv)|*.csv",
                           style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_CHANGE_DIR | wx.FD_MULTIPLE) as fileDialog:
            if fileDialog.ShowModal() == wx.ID_CANCEL:
                return None  # the user changed their mind
            # Proceed loading the files chosen by the user
            return fileDialog.GetPaths()

    def getDirDialog(self):
        """Let the user pick a single input directory.

        :return: the chosen directory path, or None when the dialog is cancelled
        """
        # wx.DirDialog has its own DD_* style flags; the FD_* flags belong to
        # wx.FileDialog and map to unrelated bits when given to a DirDialog.
        with wx.DirDialog(None, "Choose input directory", "",
                          style=wx.DD_DEFAULT_STYLE | wx.DD_DIR_MUST_EXIST | wx.DD_CHANGE_DIR) as dirDialog:
            if dirDialog.ShowModal() == wx.ID_CANCEL:
                return None  # the user changed their mind
            # Proceed loading the dir chosen by the user
            return dirDialog.GetPath()
|
|
||||||
class Signal:
    """One plottable signal of a signals file.

    A signal is backed either by a single data-source column called ``name`` or
    by several ``name/<suffix>`` columns (Mean / Stdev / Min / Max). The line and
    the optional mean +/- stdev band are created on the module-level ``plot``
    the first time the signal is selected.
    """

    def __init__(self, name, parent):
        """
        :param name: the signal name (column name without its last '/' component)
        :param parent: the owning signals file; its ``filename`` and
                       ``bokeh_source`` attributes are read here
        """
        self.name = name
        self.full_name = "{}/{}".format(parent.filename, self.name)
        self.selected = False
        self.color = random.choice(palettes.Dark2[8])
        self.line = None
        self.bands = None
        self.bokeh_source = parent.bokeh_source
        self.min_val = 0
        self.max_val = 0
        self.axis = 'default'

        # a column belongs to this signal if it is the signal itself or "<signal>/<suffix>"
        self.sub_signals = [
            column for column in self.bokeh_source.data.keys()
            if (len(column.split('/')) == 1 and column == self.name)
            or '/'.join(column.split('/')[:-1]) == self.name
        ]

        if len(self.sub_signals) > 1:
            self.mean_signal = self._sub_signals_tagged('Mean')
            self.stdev_signal = self._sub_signals_tagged('Stdev')
            self.min_signal = self._sub_signals_tagged('Min')
            self.max_signal = self._sub_signals_tagged('Max')
        else:
            self.mean_signal = utils.squeeze_list(self.name)
            self.stdev_signal = None
            self.min_signal = None
            self.max_signal = None

        # bands can only be drawn when all four statistics are available
        self.has_bollinger_bands = bool(self.mean_signal and self.stdev_signal
                                        and self.min_signal and self.max_signal)
        self.show_bollinger_bands = False
        self.bollinger_bands_source = None
        self.update_range()

    def _sub_signals_tagged(self, tag):
        """Return the (squeezed) sub signals whose last '/' component contains ``tag``."""
        return utils.squeeze_list([column for column in self.sub_signals if tag in column.split('/')[-1]])

    def set_color(self, color):
        """Store ``color`` and apply it to the line / band glyphs that already exist."""
        self.color = color
        if self.line:
            self.line.glyph.line_color = color
        if self.bands:
            self.bands.glyph.fill_color = color

    def set_selected(self, val):
        """Select or deselect the signal, plotting it on first selection."""
        global current_color
        if self.selected == val:
            return
        self.selected = val

        if self.line:
            # already plotted - only the visibility has to follow the selection
            self.line.visible = self.selected
            if self.bands:
                self.bands.visible = self.selected and self.show_bollinger_bands
            return

        if not self.selected:
            return

        # lazy plotting - plot only when selected for the first time
        show_spinner()
        self.set_color(palettes.Dark2[8][current_color])
        current_color = (current_color + 1) % len(palettes.Dark2[8])
        if self.has_bollinger_bands:
            self.set_bands_source()
            self.create_bands()
        self.line = plot.line('index', self.mean_signal, source=self.bokeh_source,
                              line_color=self.color, line_width=2)
        self.line.visible = True
        hide_spinner()

    def set_dash(self, dash):
        """Set the dash pattern of the plotted line."""
        self.line.glyph.line_dash = dash

    def create_bands(self):
        """Create the band patch from ``bollinger_bands_source``."""
        self.bands = plot.patch(x='band_x', y='band_y', source=self.bollinger_bands_source,
                                color=self.color, fill_alpha=0.4, alpha=0.1, line_width=0)
        self.bands.visible = self.show_bollinger_bands

    def set_bands_source(self):
        """(Re)build the band polygon: mean - stdev forwards, then mean + stdev backwards."""
        data = self.bokeh_source.data
        x_ticks = data['index']
        mean_values = data[self.mean_signal]
        stdev_values = data[self.stdev_signal]
        source_data = {
            'band_x': np.append(x_ticks, x_ticks[::-1]),
            'band_y': np.append(mean_values - stdev_values, mean_values[::-1] + stdev_values[::-1]),
        }
        if not self.bollinger_bands_source:
            self.bollinger_bands_source = bm.ColumnDataSource(source_data)
        else:
            self.bollinger_bands_source.data = source_data

    def change_bollinger_bands_state(self, new_state):
        """Remember whether bands should be shown and apply it if the signal is visible."""
        self.show_bollinger_bands = new_state
        if self.bands and self.selected:
            self.bands.visible = new_state

    def update_range(self):
        """Refresh ``min_val`` / ``max_val`` from the mean column of the data source."""
        values = self.bokeh_source.data[self.mean_signal]
        self.min_val = np.min(values)
        self.max_val = np.max(values)

    def set_axis(self, axis):
        """Attach the plotted line to the y range called ``axis``."""
        self.axis = axis
        self.line.y_range_name = axis

    def toggle_axis(self):
        """Switch the signal between the 'default' and the 'secondary' y axis."""
        self.set_axis('secondary' if self.axis == 'default' else 'default')
|
|
||||||
|
|
||||||
|
|
||||||
class SignalsFileBase:
    """Common state and behaviour of a single signals file and of a group of files."""

    def __init__(self):
        self.full_csv_path = ""
        self.dir = ""
        self.filename = ""
        self.signals_averaging_window = 1
        self.show_bollinger_bands = False
        self.csv = None
        self.bokeh_source = None       # the (possibly smoothed) data that is plotted
        self.bokeh_source_orig = None  # the unsmoothed data, rebuilt from self.csv on every update
        self.last_modified = None
        self.signals = {}
        self.separate_files = False

    def load_csv(self):
        """Hook for subclasses: fill ``self.csv``."""
        pass

    def update_source_and_signals(self):
        """Rebuild the bokeh data sources from ``self.csv`` and create the signals once."""
        # create bokeh data sources
        self.bokeh_source_orig = bm.ColumnDataSource(self.csv)
        self.bokeh_source_orig.data['index'] = self.bokeh_source_orig.data[x_axis]

        if self.bokeh_source is not None:
            # smooth the data if necessary
            self.change_averaging_window(self.signals_averaging_window, force=True)
        else:
            self.bokeh_source = bm.ColumnDataSource(self.csv)

        # create all the signals (only the first time)
        if len(self.signals.keys()) != 0:
            return
        # "a/b/Mean" belongs to signal "a/b"; a column without '/' is its own signal
        signal_names = [column.rsplit('/', 1)[0] for column in self.csv.columns]
        self.signals = {}
        for signal_name in list(set(signal_names)):
            self.signals[signal_name] = Signal(signal_name, self)

    def load(self):
        """Load the csv and refresh the data sources and signals."""
        self.load_csv()
        self.update_source_and_signals()

    def reload_data(self, signals):
        """Force bokeh to redraw by changing the averaging window and changing it back.

        This is a workaround: if the data doesn't change, bokeh does not refresh the line.
        """
        self.change_averaging_window(self.signals_averaging_window + 1, force=True)
        self.change_averaging_window(self.signals_averaging_window - 1, force=True)

    def change_averaging_window(self, new_size, force=False, signals=None):
        """Smooth the plotted data with a moving average of ``new_size`` samples.

        :param new_size: the averaging window length
        :param force: recompute even if the window size did not change
        :param signals: when given, columns that are not a substring of any of
                        these names are left unsmoothed
        """
        if not force and self.signals_averaging_window == new_size:
            return

        self.signals_averaging_window = new_size
        win = np.ones(new_size) / new_size
        temp_data = self.bokeh_source_orig.data.copy()
        for col in self.bokeh_source.data.keys():
            keep_raw = col == 'index' or col in x_axis_options \
                or (signals and not any(col in signal for signal in signals))
            if keep_raw:
                # x axis columns and unrelated signals are only trimmed
                temp_data[col] = temp_data[col][:-new_size]
            else:
                temp_data[col] = np.convolve(self.bokeh_source_orig.data[col], win, mode='same')[:-new_size]
        self.bokeh_source.data = temp_data

        # smooth bollinger bands
        for signal in self.signals.values():
            if signal.has_bollinger_bands:
                signal.set_bands_source()

    def hide_all_signals(self):
        """Deselect every signal of this file."""
        for signal_name in self.signals:
            self.set_signal_selection(signal_name, False)

    def set_signal_selection(self, signal_name, val):
        """Select (``val`` true) or deselect the signal called ``signal_name``."""
        self.signals[signal_name].set_selected(val)

    def change_bollinger_bands_state(self, new_state):
        """Propagate the bands on/off state to all the signals."""
        self.show_bollinger_bands = new_state
        for signal in self.signals.values():
            signal.change_bollinger_bands_state(new_state)

    def file_was_modified_on_disk(self):
        """Hook for subclasses: tell whether the underlying data changed on disk."""
        pass

    def get_range_of_selected_signals_on_axis(self, axis, selected_signal=None):
        """Return (min, max) over the signals selected on ``axis``.

        A signal named ``selected_signal`` is always included. When nothing
        matches, the result is (inf, -inf).
        """
        min_val, max_val = float('inf'), -float('inf')
        for signal in self.signals.values():
            is_named = selected_signal and signal.name == selected_signal
            is_shown_on_axis = signal.selected and signal.axis == axis
            if not (is_named or is_shown_on_axis):
                continue
            max_val = max(max_val, signal.max_val)
            min_val = min(min_val, signal.min_val)
        return min_val, max_val

    def get_selected_signals(self):
        """Return the list of currently selected signals."""
        return [signal for signal in self.signals.values() if signal.selected]

    def show_files_separately(self, val):
        """Hook for subclasses that aggregate several files."""
        pass
|
|
||||||
|
|
||||||
|
|
||||||
class SignalsFile(SignalsFileBase):
    """The signals stored in a single csv file."""

    def __init__(self, csv_path, load=True):
        """
        :param csv_path: path of the csv file
        :param load: when False, only the paths are recorded and no data is read
        """
        SignalsFileBase.__init__(self)
        self.full_csv_path = csv_path
        self.dir, self.filename, _ = utils.break_file_path(csv_path)
        if load:
            self.load()
            # this helps set the correct x axis
            self.change_averaging_window(1, force=True)

    def load_csv(self):
        """Read the csv into ``self.csv``, fill in sparse data and record the file mtime.

        The csv can be in the middle of being written, so an empty-data error
        is retried until the file has content.
        """
        import time  # local import: only needed for the retry delay below

        self.csv = None
        while self.csv is None:
            try:
                self.csv = pd.read_csv(self.full_csv_path)
            except pandas_common.EmptyDataError:
                # back off briefly instead of spinning at full CPU while the
                # writer has not produced any content yet
                time.sleep(0.1)
        self.csv = self.csv.interpolate()
        self.csv.fillna(value=0, inplace=True)

        self.last_modified = os.path.getmtime(self.full_csv_path)

    def file_was_modified_on_disk(self):
        """Return True if the file mtime differs from the one recorded at load time."""
        return self.last_modified != os.path.getmtime(self.full_csv_path)
|
|
||||||
|
|
||||||
|
|
||||||
class SignalsFilesGroup(SignalsFileBase):
    """A set of signals files shown as one file holding per-index statistics.

    For every column of the member files the group exposes ``<column>/Mean``,
    ``/Stdev``, ``/Min`` and ``/Max`` computed across the members.
    """

    def __init__(self, csv_paths):
        """
        :param csv_paths: list of csv file paths and / or directories; a
                          directory is expanded with add_directory_csv_files
        """
        SignalsFileBase.__init__(self)
        self.full_csv_paths = csv_paths
        self.signals_files = []
        if len(csv_paths) == 1 and os.path.isdir(csv_paths[0]):
            self.signals_files = [SignalsFile(str(file), load=False)
                                  for file in add_directory_csv_files(csv_paths[0])]
        else:
            for csv_path in csv_paths:
                if os.path.isdir(csv_path):
                    self.signals_files.append(SignalsFilesGroup(add_directory_csv_files(csv_path)))
                else:
                    self.signals_files.append(SignalsFile(str(csv_path), load=False))
        if len(csv_paths) == 1:
            # get the parent directory name (since the current directory is the timestamp directory)
            self.dir = os.path.abspath(os.path.join(os.path.dirname(csv_paths[0]), '..'))
        else:
            # get the common directory for all the experiments
            self.dir = os.path.dirname(os.path.commonprefix(csv_paths))
        self.filename = '{} - Group({})'.format(os.path.basename(self.dir), len(self.signals_files))
        self.load()

        # this helps set the correct x axis
        self.change_averaging_window(1, force=True)

    def load_csv(self):
        """Load every member file and build the aggregated statistics csv."""
        # Keep only the files that contain all the x axis columns. The list is
        # rebuilt instead of deleting by index while walking forward, because
        # each deletion shifts the positions of the entries that follow it.
        valid_files = []
        for signal_file in self.signals_files:
            signal_file.load_csv()
            if all(option in signal_file.csv.keys() for option in x_axis_options):
                valid_files.append(signal_file)
            else:
                print("Warning: {} file seems to be corrupted and does not contain the necessary columns "
                      "and will not be rendered".format(signal_file.filename))
        self.signals_files[:] = valid_files

        # get the stats of all the columns
        csv_group = pd.concat([signals_file.csv for signals_file in self.signals_files])
        # statistics already stored in the member files are dropped; only the
        # remaining columns are aggregated across the files
        columns_to_remove = [s for s in csv_group.columns
                             if any(tag in s for tag in ('/Stdev', '/Min', '/Max'))]
        for col in columns_to_remove:
            del csv_group[col]
        csv_group = csv_group.groupby(csv_group.index)
        self.csv_mean = csv_group.mean()
        self.csv_mean.columns = [s + '/Mean' for s in self.csv_mean.columns]
        self.csv_stdev = csv_group.std()
        self.csv_stdev.columns = [s + '/Stdev' for s in self.csv_stdev.columns]
        self.csv_min = csv_group.min()
        self.csv_min.columns = [s + '/Min' for s in self.csv_min.columns]
        self.csv_max = csv_group.max()
        self.csv_max.columns = [s + '/Max' for s in self.csv_max.columns]

        # get the indices from the file with the least number of indices and which is not an evaluation worker
        file_with_min_indices = self.signals_files[0]
        for signals_file in self.signals_files:
            if signals_file.csv.shape[0] < file_with_min_indices.csv.shape[0] and \
                    'Training reward' in signals_file.csv.keys():
                file_with_min_indices = signals_file
        self.index_columns = file_with_min_indices.csv[x_axis_options]

        # concat the stats and the indices columns
        num_rows = file_with_min_indices.csv.shape[0]
        self.csv = pd.concat([self.index_columns, self.csv_mean.head(num_rows), self.csv_stdev.head(num_rows),
                              self.csv_min.head(num_rows), self.csv_max.head(num_rows)], axis=1)

        # remove the stat columns for the indices columns
        columns_to_remove = [s + suffix
                             for suffix in ('/Mean', '/Stdev', '/Min', '/Max')
                             for s in x_axis_options]
        for col in columns_to_remove:
            del self.csv[col]

        # remove NaNs
        self.csv.fillna(value=0, inplace=True)  # removing this line will make bollinger bands fail
        for key in self.csv.keys():
            if 'Stdev' in key and 'Evaluation' not in key:
                self.csv[key] = self.csv[key].fillna(value=0)

        for signal_file in self.signals_files:
            signal_file.update_source_and_signals()

    def change_averaging_window(self, new_size, force=False, signals=None):
        """Apply the averaging window to every member file and to the aggregated data."""
        for signal_file in self.signals_files:
            signal_file.change_averaging_window(new_size, force, signals)
        SignalsFileBase.change_averaging_window(self, new_size, force, signals)

    def set_signal_selection(self, signal_name, val):
        """Re-apply the separate-files display state, then select / deselect the group signal."""
        self.show_files_separately(self.separate_files)
        SignalsFileBase.set_signal_selection(self, signal_name, val)

    def file_was_modified_on_disk(self):
        """Return True if any member file reports a modification on disk."""
        return any(signal_file.file_was_modified_on_disk() for signal_file in self.signals_files)

    def show_files_separately(self, val):
        """Show (``val`` true) or hide the member-file lines of every selected group signal."""
        self.separate_files = val
        for signal in self.signals.values():
            if not signal.selected:
                continue
            # the aggregated line is dashed while the member lines are shown
            signal.set_dash("4 4" if val else "")
            for signal_file in self.signals_files:
                try:
                    signal_file.set_signal_selection(signal.name, signal.selected if val else False)
                except Exception:
                    # best effort: a member file that cannot show this signal is skipped
                    pass
|
|
||||||
|
|
||||||
|
|
||||||
class RunType(enum.Enum):
    """Layout of an experiment directory, as reported by get_run_type."""

    # the directory itself holds the csv file(s)
    SINGLE_FOLDER_SINGLE_FILE = 1
    SINGLE_FOLDER_MULTIPLE_FILES = 2

    # the directory holds sub directories which hold the csv file(s)
    MULTIPLE_FOLDERS_SINGLE_FILES = 3
    MULTIPLE_FOLDERS_MULTIPLE_FILES = 4

    UNKNOWN = 0
|
|
||||||
|
|
||||||
|
|
||||||
class FolderType(enum.Enum):
    """What a single directory contains, as reported by classify_folder."""

    SINGLE_FILE = 1       # exactly one csv file
    MULTIPLE_FILES = 2    # more than one csv file
    MULTIPLE_FOLDERS = 3  # no csv files, at least one sub directory
    EMPTY = 4             # neither csv files nor sub directories
|
|
||||||
|
|
||||||
# wx application behind the native file / directory pickers
dialog = DialogApp()

# dashboard state shared by the callbacks
patches = {}
signals_files = {}      # loaded files, keyed by their filename
selected_file = None
x_axis = 'Episode #'
x_axis_options = ['Episode #', 'Total steps', 'Wall-Clock Time']
current_color = 0       # index of the next Dark2[8] color handed to a newly plotted signal

# spinner - its css is expected next to this module
root_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(root_dir, 'spinner.css'), 'r') as f:
    spinner_style = "<style>{}</style>".format(f.read())
spinner_html = '<ul class="spinner"><li></li><li></li><li></li><li></li></ul>'
spinner = bw.Div(text="")

# file refresh time placeholder
refresh_info = bw.Div(text="", width=210)

# create figures
plot = bp.figure(
    plot_width=1200,
    plot_height=800,
    tools='pan,box_zoom,wheel_zoom,crosshair,undo,redo,reset,save',
    toolbar_location='above',
    x_axis_label='Episodes',
    x_range=bm.Range1d(0, 10000),
    y_range=bm.Range1d(0, 100000),
)
# a second y axis, drawn on the right, for signals moved to the 'secondary' range
plot.extra_y_ranges = {"secondary": bm.Range1d(start=-100, end=200)}
plot.add_layout(bm.LinearAxis(y_range_name="secondary"), 'right')

# legend
div = bw.Div(text="")
legend = bl.widgetbox([div])

bokeh_legend = bm.Legend(
    # placeholder entry, 50 letters long
    items=[("12345678901234567890123456789012345678901234567890", [])],
    location=(-20, 0),
    orientation="vertical",
    border_line_color="black",
    label_text_font_size={'value': '9pt'},
    margin=30,
)
plot.add_layout(bokeh_legend, "right")
|
|
||||||
|
|
||||||
|
|
||||||
def update_axis_range(name, range_placeholder):
    """Fit ``range_placeholder`` to the signals shown on axis ``name``, with a 10% margin.

    When ``name`` is one of the x axis options, the signal of that name is
    used regardless of its selection state. The range is left untouched when
    no signal contributes.
    """
    selected_signal = name if name in x_axis_options else None

    min_val, max_val = float('inf'), -float('inf')
    for signals_file in signals_files.values():
        file_min, file_max = signals_file.get_range_of_selected_signals_on_axis(name, selected_signal)
        max_val = max(max_val, file_max)
        min_val = min(min_val, file_min)

    if min_val == float('inf'):
        return

    margin = 0.1 * (max_val - min_val)
    range_placeholder.start = min_val - margin
    range_placeholder.end = max_val + margin
|
|
||||||
|
|
||||||
|
|
||||||
# update axes ranges
|
|
||||||
def update_ranges():
    """Refit both y axes (default and secondary) to the selected signals."""
    y_ranges = (
        ('default', plot.y_range),
        ('secondary', plot.extra_y_ranges['secondary']),
    )
    for axis_name, axis_range in y_ranges:
        update_axis_range(axis_name, axis_range)
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_selected_signals():
    """Return the selected signals of all the loaded files as one flat list."""
    selected = []
    for loaded_file in signals_files.values():
        selected.extend(loaded_file.get_selected_signals())
    return selected
|
|
||||||
|
|
||||||
|
|
||||||
# update legend using the legend text dictionary
|
|
||||||
def update_legend():
    """Rebuild the html legend and the bokeh legend from the selected signals."""
    selected_signals = get_all_selected_signals()

    # "<" marks a signal on the default (left) axis, ">" one on the secondary (right) axis
    entries = [
        """<div style='color: {}'><b>{} {}</b></div>"""
        .format(signal.color, "<" if signal.axis == 'default' else ">", signal.full_name)
        for signal in selected_signals
    ]
    div.text = """<div></div>""" + "".join(entries)

    # the visible=false => visible=true is a hack to make the legend render again
    bokeh_legend.visible = False
    bokeh_legend.items = [(signal.full_name, [signal.line]) for signal in selected_signals]
    bokeh_legend.visible = True
|
|
||||||
|
|
||||||
|
|
||||||
# select lines to display
|
|
||||||
def select_data(args, old, new):
    """Selection callback: show exactly the signals listed in ``new`` for the selected file.

    Does nothing when no file is selected. ``args`` and ``old`` are accepted
    but not used.
    """
    if selected_file is None:
        return

    show_spinner()
    for signal_name in selected_file.signals.keys():
        selected_file.set_signal_selection(signal_name, signal_name in new)

    # update axes ranges
    update_ranges()
    update_axis_range(x_axis, plot.x_range)

    # update the legend
    update_legend()

    hide_spinner()
|
|
||||||
|
|
||||||
|
|
||||||
# add new lines to the plot
|
|
||||||
def plot_signals(signals_file, signals):
    """Add a line to the plot for each of ``signals``, reading from the file's data source."""
    source = signals_file.bokeh_source
    for signal in signals:
        signal.line = plot.line('index', signal.name, source=source,
                                line_color=signal.color, line_width=2)
|
|
||||||
|
|
||||||
|
|
||||||
def open_file_dialog():
    """Show the csv file picker and return whatever it returns."""
    chosen_files = dialog.getFileDialog()
    return chosen_files
|
|
||||||
|
|
||||||
|
|
||||||
def open_directory_dialog():
    """Show the directory picker and return whatever it returns."""
    chosen_dir = dialog.getDirDialog()
    return chosen_dir
|
|
||||||
|
|
||||||
|
|
||||||
def show_spinner():
    """Display the spinner by putting its style and markup into the spinner div."""
    spinner.text = "".join((spinner_style, spinner_html))
|
|
||||||
|
|
||||||
|
|
||||||
def hide_spinner():
    """Hide the spinner by emptying the spinner div."""
    spinner.text = str()
|
|
||||||
|
|
||||||
|
|
||||||
# will create a group from the files
|
|
||||||
def create_files_group_signal(files):
    """Create a files group from ``files``, register it and make it the selected file."""
    global selected_file
    group = SignalsFilesGroup(files)
    group_name = group.filename
    signals_files[group_name] = group

    # expose the group in the files selector and switch to it
    files_selector.options += [group_name]
    files_selector.value = group_name
    selected_file = group
|
|
||||||
|
|
||||||
|
|
||||||
# load files from disk as a group
|
|
||||||
def load_files_group():
    """Ask the user for csv files and load them: one file as a file, several as a group."""
    show_spinner()
    files = open_file_dialog()

    # nothing to do when no files were selected
    if files and files[0]:
        change_displayed_doc()

        if len(files) != 1:
            create_files_group_signal(files)
        else:
            create_files_signal(files)

        change_selected_signals_in_data_selector([""])

    hide_spinner()
|
|
||||||
|
|
||||||
|
|
||||||
# classify the folder as containing a single file, multiple files or only folders
|
|
||||||
def classify_folder(dir_path):
    """Classify ``dir_path`` by looking at its direct children only.

    Returns FolderType.SINGLE_FILE for exactly one csv file,
    FolderType.MULTIPLE_FILES for more than one csv file,
    FolderType.MULTIPLE_FOLDERS when there is no csv file but at least one
    sub directory, and FolderType.EMPTY otherwise.
    """
    csv_count = 0
    for entry in os.listdir(dir_path):
        if os.path.isfile(os.path.join(dir_path, entry)) and entry.endswith('.csv'):
            csv_count += 1

    dir_count = 0
    for entry in os.listdir(dir_path):
        if os.path.isdir(os.path.join(dir_path, entry)):
            dir_count += 1

    # csv files take precedence over sub directories
    if csv_count == 1:
        return FolderType.SINGLE_FILE
    if csv_count > 1:
        return FolderType.MULTIPLE_FILES
    if dir_count >= 1:
        return FolderType.MULTIPLE_FOLDERS
    return FolderType.EMPTY
|
|
||||||
|
|
||||||
|
|
||||||
# finds if this is single-threaded or multi-threaded
|
|
||||||
def get_run_type(dir_path):
    """Work out which RunType layout the experiment directory follows.

    A directory holding csv files directly is a single-folder run. A directory
    holding only sub directories is classified by its first listed sub
    directory, on the assumption that all sub directories share one structure.

    :param dir_path: path of the directory to inspect
    :return: the matching RunType member, or None when no csv files are found
             in the directory or in its first sub directory
    """
    folder_type = classify_folder(dir_path)

    if folder_type == FolderType.SINGLE_FILE:
        return RunType.SINGLE_FOLDER_SINGLE_FILE
    if folder_type == FolderType.MULTIPLE_FILES:
        return RunType.SINGLE_FOLDER_MULTIPLE_FILES
    if folder_type != FolderType.MULTIPLE_FOLDERS:
        # nothing to load here; never hand a FolderType back to callers that
        # compare the result against RunType members
        return None

    sub_dirs = [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))]

    # only the first sub directory is checked (see the assumption above)
    sub_folder_type = classify_folder(os.path.join(dir_path, sub_dirs[0]))
    if sub_folder_type == FolderType.SINGLE_FILE:
        return RunType.MULTIPLE_FOLDERS_SINGLE_FILES
    if sub_folder_type == FolderType.MULTIPLE_FILES:
        return RunType.MULTIPLE_FOLDERS_MULTIPLE_FILES
    return None
|
|
||||||
|
|
||||||
|
|
||||||
# takes a path to a directory and recursively adds all of its csv files to paths
|
|
||||||
def add_directory_csv_files(dir_path, paths=None):
    """Recursively collect the paths of all csv files below ``dir_path``.

    :param dir_path: directory to scan
    :param paths: optional list to append to; a new list is created only when
                  this is None, so a caller-supplied empty list is filled in
                  place as well
    :return: the list holding every csv path found (``paths`` itself when given)
    """
    if paths is None:
        paths = []

    for name in os.listdir(dir_path):
        path = os.path.join(dir_path, name)
        if os.path.isdir(path):
            # descend; the nested call appends to the same list
            add_directory_csv_files(path, paths)
        elif os.path.isfile(path) and path.endswith('.csv'):
            paths.append(path)

    return paths
|
|
||||||
|
|
||||||
|
|
||||||
# create a signal file from the directory path according to the directory underlying structure
|
|
||||||
def handle_dir(dir_path, run_type):
    """Load the directory as signals according to its ``run_type``.

    A single csv file becomes a plain signals file. Several csv files, whether
    in one folder or one per sub folder, become one group. For sub folders
    that each hold several files, the group is built from the sub directory
    paths themselves. Any other ``run_type`` loads nothing.
    """
    paths = add_directory_csv_files(dir_path)

    if run_type == RunType.SINGLE_FOLDER_SINGLE_FILE:
        create_files_signal(paths)
        return

    if run_type == RunType.SINGLE_FOLDER_MULTIPLE_FILES:
        create_files_group_signal(paths)
        return

    if run_type == RunType.MULTIPLE_FOLDERS_SINGLE_FILES:
        create_files_group_signal(paths)
        return

    if run_type == RunType.MULTIPLE_FOLDERS_MULTIPLE_FILES:
        sub_dirs = [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))]
        sub_dir_paths = [os.path.join(dir_path, d) for d in sub_dirs]
        create_files_group_signal(sub_dir_paths)
|
|
||||||
|
|
||||||
|
|
||||||
# load directory from disk as a group
|
|
||||||
def load_directory_group():
    """Let the user pick a directory and load its csv files into the dashboard.

    When the dialog yields nothing, only the spinner is reset.
    """
    show_spinner()
    directory = open_directory_dialog()

    if directory:
        change_displayed_doc()
        handle_dir(directory, get_run_type(directory))
        change_selected_signals_in_data_selector([""])

    hide_spinner()
|
|
||||||
|
|
||||||
|
|
||||||
def create_files_signal(files):
    """Load every path in ``files`` as its own SignalsFile and select the first.

    Each file is registered in ``signals_files`` under its filename and added
    to the files selector. The first loaded file becomes the selector value
    and the module-level ``selected_file``.
    """
    global selected_file

    loaded = []
    for file_path in files:
        entry = SignalsFile(str(file_path))
        signals_files[entry.filename] = entry
        loaded.append(entry)

    names = [entry.filename for entry in loaded]

    files_selector.options += names
    files_selector.value = names[0]
    selected_file = loaded[0]
|
|
||||||
|
|
||||||
|
|
||||||
# load files from disk
|
|
||||||
def load_files():
    """Let the user pick files and load each one as a separate signals file."""
    show_spinner()
    files = open_file_dialog()

    # empty result, or a first entry that is falsy: nothing was selected
    if not (files and files[0]):
        hide_spinner()
        return

    create_files_signal(files)
    hide_spinner()

    change_selected_signals_in_data_selector([""])
|
|
||||||
|
|
||||||
|
|
||||||
def unload_file():
    """Remove the currently selected file from the dashboard.

    Hides its signals, drops it from ``signals_files`` and from the files
    selector, selects another file if any remain, then refreshes the legend
    and clears the refresh info text. Does nothing when no file is selected.
    """
    global selected_file, signals_files

    if selected_file is None:
        return

    selected_file.hide_all_signals()
    del signals_files[selected_file.filename]
    data_selector.options = [""]

    # the iterator is created before the removal but only advanced after it
    remaining = itertools.cycle(files_selector.options)
    files_selector.options.remove(selected_file.filename)

    files_selector.value = next(remaining) if len(files_selector.options) > 0 else None

    update_legend()
    refresh_info.text = ""
|
|
||||||
|
|
||||||
|
|
||||||
# reload every loaded csv file that was modified on disk (or all of them when forced)
|
|
||||||
def reload_all_files(force=False):
    """Reload the loaded signal files and stamp the refresh info text.

    A file is reloaded when ``force`` is true or when it reports that it was
    modified on disk.
    """
    # NOTE(review): the timestamp update is kept inside the loop; confirm
    # against the original indentation
    for loaded_file in signals_files.values():
        needs_reload = force or loaded_file.file_was_modified_on_disk()
        if needs_reload:
            loaded_file.load()
        timestamp = str(datetime.datetime.now()).split(".")[0]
        refresh_info.text = "last update: " + timestamp
|
|
||||||
|
|
||||||
|
|
||||||
# unselect the currently selected signals and then select the requested signals in the data selector
|
|
||||||
def change_selected_signals_in_data_selector(selected_signals):
    """Deselect everything in the data selector, then select ``selected_signals``.

    Assigning a new value directly does not work since Bokeh 0.12.6
    (https://github.com/bokeh/bokeh/issues/6501), so each selected entry is
    removed from the options and the value and re-inserted at its old
    position. This currently causes the signals to change color.
    """
    for name in list(data_selector.value):
        if value_missing := name not in data_selector.options:
            continue
        position = data_selector.options.index(name)
        data_selector.options.remove(name)
        data_selector.value.remove(name)
        data_selector.options.insert(position, name)

    data_selector.value = selected_signals
|
|
||||||
|
|
||||||
|
|
||||||
# change data options according to the selected file
|
|
||||||
def change_data_selector(args, old, new):
    """React to a change of the files selector value.

    ``new`` is the filename now selected; None clears ``selected_file``.
    Otherwise the data selector, averaging slider and group checkboxes are
    updated to show the state of the newly selected file.
    """
    global selected_file

    if new is None:
        selected_file = None
        return

    show_spinner()
    selected_file = signals_files[new]

    data_selector.options = sorted(selected_file.signals.keys())
    names = [s.name for s in selected_file.signals.values() if s.selected] or [""]
    change_selected_signals_in_data_selector(names)

    averaging_slider.value = selected_file.signals_averaging_window

    # index 0 is the bands checkbox, index 1 the ungroup checkbox; None marks "off"
    group_cb.active = [0 if selected_file.show_bollinger_bands else None]
    group_cb.active += [1 if selected_file.separate_files else None]
    hide_spinner()
|
|
||||||
|
|
||||||
|
|
||||||
# smooth all the signals of the selected file
|
|
||||||
def update_averaging(args, old, new):
    """Apply the averaging window ``new`` to the selected file.

    ``args`` and ``old`` are accepted to fit the on_change callback signature
    and are not used.
    """
    show_spinner()
    window = new
    selected_file.change_averaging_window(window)
    hide_spinner()
|
|
||||||
|
|
||||||
|
|
||||||
def change_x_axis(val):
    """Switch the x axis to ``x_axis_options[val]`` and redraw everything."""
    global x_axis

    show_spinner()

    x_axis = x_axis_options[val]
    plot.xaxis.axis_label = x_axis

    # every file is reloaded unconditionally after the axis change
    reload_all_files(force=True)
    update_axis_range(x_axis, plot.x_range)

    hide_spinner()
|
|
||||||
|
|
||||||
|
|
||||||
# move the signal between the main and secondary Y axes
|
|
||||||
def toggle_second_axis():
    """Move the selected signals between the main and the secondary y axis."""
    show_spinner()

    signals = selected_file.get_selected_signals()
    for signal in signals:
        signal.toggle_axis()

    update_ranges()
    update_legend()

    # reloading the data is only done to force a redraw of the moved signals
    moved_names = [signal.name for signal in signals]
    selected_file.reload_data(moved_names)

    hide_spinner()
|
|
||||||
|
|
||||||
|
|
||||||
def toggle_group_property(new):
    """Apply the group checkbox state ``new`` to the selected file.

    ``new`` holds the indices of the active checkboxes: 0 toggles the
    Bollinger bands, 1 toggles showing each file of a group separately.
    """
    bands_on = 0 in new
    selected_file.change_bollinger_bands_state(bands_on)

    separate_on = 1 in new
    selected_file.show_files_separately(separate_on)
|
|
||||||
|
|
||||||
|
|
||||||
def change_displayed_doc():
    """Replace the landing page with the main layout if the landing page is shown."""
    if doc.roots[0] != landing_page:
        return
    doc.remove_root(landing_page)
    doc.add_root(layout)
|
|
||||||
|
|
||||||
|
|
||||||
# Color selection - most of these functions are taken from bokeh examples (plotting/color_sliders.py)
|
|
||||||
def select_color(attr, old, new):
    """Color the selected signals with the color picked in the color strip.

    ``new`` is the selection of the color data source; the first selected
    index is looked up in ``crRGBs``.
    """
    show_spinner()
    for signal in selected_file.get_selected_signals():
        picked_index = new['1d']['indices'][0]
        picked_rgb = crRGBs[picked_index]
        signal.set_color(rgb_to_hex(picked_rgb))
    hide_spinner()
|
|
||||||
|
|
||||||
|
|
||||||
def generate_color_range(N, I):
    """Build a spectrum of ``N`` colors with saturation 0.5 and value ``I``.

    The hue is spread evenly over [0, 1). Returns ``(hex_colors, rgb_tuples)``:
    the hex codes and the matching (r, g, b) integer tuples in 0..255.
    """
    rgb_tuples = []
    for step in range(N):
        red, green, blue = colorsys.hsv_to_rgb(step*1.0/N, 0.5, I)
        rgb_tuples.append((int(red*255), int(green*255), int(blue*255)))

    hex_colors = [rgb_to_hex(rgb) for rgb in rgb_tuples]
    return hex_colors, rgb_tuples
|
|
||||||
|
|
||||||
|
|
||||||
# convert RGB tuple to hexadecimal code
|
|
||||||
def rgb_to_hex(rgb):
    """Return the '#rrggbb' hexadecimal code for an (r, g, b) tuple of ints."""
    template = '#%02x%02x%02x'
    return template % rgb
|
|
||||||
|
|
||||||
|
|
||||||
# convert hexadecimal to RGB tuple
|
|
||||||
def hex_to_dec(hex):
    """Return the (r, g, b) integer tuple for a hexadecimal color code.

    Leading and trailing '#' characters are ignored.
    """
    digits = hex.strip('#')
    red, green, blue = digits[0:2], digits[2:4], digits[4:6]
    return int(red, 16), int(green, 16), int(blue, 16)
|
|
||||||
|
|
||||||
# ---------------- Color spectrum used by the color picker ----------------
color_resolution = 1000  # number of colors in the spectrum
brightness = 0.75  # change to have brighter/darker colors
crx = list(range(1, color_resolution+1))  # one x position per color
cry = [5] * len(crx)  # all rectangles are centered at the same height
crcolor, crRGBs = generate_color_range(color_resolution, brightness)  # produce spectrum

# ---------------- Build Website Layout -------------------

# file / directory selection and unloading
file_selection_button = bw.Button(label="Select Files", button_type="success", width=120)
file_selection_button.on_click(load_files_group)

files_selector_spacer = bl.Spacer(width=10)

group_selection_button = bw.Button(label="Select Directory", button_type="primary", width=140)
group_selection_button.on_click(load_directory_group)

unload_file_button = bw.Button(label="Unload", button_type="danger", width=50)
unload_file_button.on_click(unload_file)

# drop-down holding the loaded files
files_selector = bw.Select(title="Files:", options=[], width=200)
files_selector.on_change('value', change_data_selector)

# list of the signals of the selected file
data_selector = bw.MultiSelect(title="Data:", options=[], size=12)
data_selector.on_change('value', select_data)

# choice of the x axis
x_axis_selector_title = bw.Div(text="""X Axis:""")
x_axis_selector = bw.RadioButtonGroup(labels=x_axis_options, active=0)
x_axis_selector.on_click(change_x_axis)

# button moving signals between the two y axes
toggle_second_axis_button = bw.Button(label="Toggle Second Axis", button_type="success")
toggle_second_axis_button.on_click(toggle_second_axis)

# smoothing control
averaging_slider = bw.Slider(title="Averaging window", start=1, end=101, step=10)
averaging_slider.on_change('value', update_averaging)

# per-group display options
group_cb = bw.CheckboxGroup(labels=["Show statistics bands", "Ungroup signals"], active=[])
group_cb.on_click(toggle_group_property)

# color picker: a strip of thin rectangles, one per spectrum color, selected by tapping
color_selector_title = bw.Div(text="""Select Color:""")
crsource = bm.ColumnDataSource(data=dict(x=crx, y=cry, crcolor=crcolor, RGBs=crRGBs))
color_selector = bp.figure(
    x_range=(0, color_resolution),
    y_range=(0, 10),
    plot_width=300,
    plot_height=40,
    tools='tap',
)
color_selector.axis.visible = False
color_range = color_selector.rect(
    x='x', y='y', width=1, height=10, color='crcolor', source=crsource,
)
crsource.on_change('selected', select_color)
# unselected rectangles are drawn with the same glyph as selected ones
color_range.nonselection_glyph = color_range.glyph
color_selector.toolbar.logo = None
color_selector.toolbar_location = None

# page title
title = bw.Div(text="""<h1>Coach Dashboard</h1>""")

# landing page shown until something is loaded
landing_page_description = bw.Div(text="""<h3>Start by selecting an experiment file or directory to open:</h3>""")
center = bw.Div(text="""<style>html { text-align: center; } </style>""")
center_buttons = bw.Div(text="""<style>.bk-grid-row .bk-layout-fixed { margin: 0 auto; }</style>""", width=0)
landing_page = bl.column(
    center,
    title,
    landing_page_description,
    bl.row(center_buttons),
    bl.row(file_selection_button, sizing_mode='scale_width'),
    bl.row(group_selection_button, sizing_mode='scale_width'),
    sizing_mode='scale_width',
)

# main layout of the document: controls stacked in a column next to the plot
layout = bl.row(file_selection_button, files_selector_spacer, group_selection_button, width=300)
layout = bl.column(layout, files_selector)
layout = bl.column(layout, bl.row(refresh_info, unload_file_button))
layout = bl.column(layout, data_selector)
layout = bl.column(layout, color_selector_title)
layout = bl.column(layout, color_selector)
layout = bl.column(layout, x_axis_selector_title)
layout = bl.column(layout, x_axis_selector)
layout = bl.column(layout, group_cb)
layout = bl.column(layout, toggle_second_axis_button)
layout = bl.column(layout, averaging_slider)
layout = bl.row(layout, plot)
layout = bl.column(title, layout)
layout = bl.column(layout, spinner)

# the document starts on the landing page
doc = bp.curdoc()
doc.add_root(landing_page)

# poll the loaded files for changes; the second argument is the period
doc.add_periodic_callback(reload_all_files, 20000)
plot.y_range = bm.Range1d(0, 100)
plot.extra_y_ranges['secondary'] = bm.Range1d(0, 100)
|
|
||||||
|
|
||||||
# show load file dialog immediately on start
|
|
||||||
#doc.add_timeout_callback(load_files, 1000)
|
|
||||||
|
|
||||||
def find_open_port(first_port=12345, host="127.0.0.1"):
    """Return the first port at or above ``first_port`` that can be bound on ``host``.

    Ports that are already in use are skipped. Any other bind error is raised
    to the caller instead of being retried, since retrying the same port could
    never succeed.

    :param first_port: port number to start probing from
    :param host: address to bind the probe socket to
    :return: a port number that could be bound at the time of the check
    :raises OSError: when binding fails for a reason other than "address in use"
    """
    import errno
    import socket

    port = first_port
    while True:
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            probe.bind((host, port))
            return port
        except socket.error as e:
            # portable spelling of the "address already in use" error number
            if e.errno != errno.EADDRINUSE:
                raise
            port += 1
        finally:
            probe.close()


if __name__ == "__main__":
    # find an open port and run the server
    os.system('bokeh serve --show dashboard.py --port {}'.format(find_open_port()))
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2017 Intel Corporation
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def show_observation_stack(stack, channels_last=False):
    """Plot every frame of an observation stack side by side in grayscale.

    :param stack: a list of frames, a 3 dimensional array, or a 4 dimensional
                  array whose first dimension is a batch (only its first entry
                  is shown)
    :param channels_last: when true, the stack is transposed with axes
                          (2, 0, 1) before plotting, moving the last axis first
    :raises AssertionError: when ``stack`` is an array that is neither 3 nor
                            4 dimensional
    """
    if isinstance(stack, list):
        stack_size = len(stack)
    elif len(stack.shape) == 3:
        stack_size = stack.shape[0]
    elif len(stack.shape) == 4:
        stack_size = stack.shape[1]  # ignore batch dimension
        stack = stack[0]
    else:
        # raised explicitly so the check is not stripped under `python -O`;
        # the exception type is kept as AssertionError for existing callers
        raise AssertionError(
            "stack must be a list or a 3 or 4 dimensional array, got shape {}".format(stack.shape))

    if channels_last:
        stack = np.transpose(stack, (2, 0, 1))
        stack_size = stack.shape[0]

    for i in range(stack_size):
        plt.subplot(1, stack_size, i + 1)
        plt.imshow(stack[i], cmap='gray')

    plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def show_diff_between_two_observations(observation1, observation2):
    """Display ``observation1 - observation2`` as a grayscale image."""
    difference = observation1 - observation2
    plt.imshow(difference, cmap='gray')
    plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def plot_grayscale_observation(observation):
    """Display ``observation`` as a grayscale image."""
    grayscale = 'gray'
    plt.imshow(observation, cmap=grayscale)
    plt.show()
|
|
||||||
205
install.sh
205
install.sh
@@ -1,205 +0,0 @@
|
|||||||
#!/bin/bash -e
|
|
||||||
|
|
||||||
prompt () {
    # Ask a yes / no question and repeat it until a usable answer is given.
    #   $1 - the question to print
    #   $2 - answer used when the user just presses enter: Y, N, or empty for none
    # The answer is left in the global `retval` (1 = yes, 0 = no).
    local default_answer

    # translate the default into a value and the hint shown after the question
    case "${2}" in
        y|Y ) default_answer=1; options="[Y/n]";;
        n|N ) default_answer=0; options="[y/N]";;
        ""  ) default_answer=;  options="[y/n]";;
        *   ) echo "invalid default value"; exit;;
    esac

    while true; do
        read -p "${1} ${options} " choice

        case "${choice}" in
            y|Y ) retval=1; return;;
            n|N ) retval=0; return;;
            # enter only counts as an answer when a default exists
            ""  ) if [ -n "${default_answer}" ]; then retval=${default_answer}; return; fi;;
        esac
    done
}
|
|
||||||
|
|
||||||
add_to_bashrc () {
    # Persist an environment variable by appending an export line to ~/.bashrc.
    #   $1 - the variable name
    #   $2 - the variable value
    # The line is appended only when that exact line is not there yet, so
    # running the installer repeatedly does not pile up duplicates.
    local export_line="export ${1}=${2}"

    # -F matches literally (values are paths full of regex metacharacters),
    # -x requires the whole line to match; a missing ~/.bashrc counts as "absent"
    if ! grep -qxF -- "${export_line}" ~/.bashrc 2>/dev/null; then
        echo "${export_line}" >> ~/.bashrc
    fi
}
|
|
||||||
|
|
||||||
# When no component flag is given on the command line, ask interactively
GET_PREFERENCES_MANUALLY=1

INSTALL_COACH=0
INSTALL_DASHBOARD=0
INSTALL_GYM=0
INSTALL_NEON=0
INSTALL_VIRTUAL_ENVIRONMENT=1

# Get user preferences
TEMP=$(getopt -o cpgvrmeNndh \
              --long coach,dashboard,gym,no_virtual_environment,neon,debug,help \
              -- "$@")
eval set -- "$TEMP"
while true; do
    case ${1} in
        -c|--coach)
            INSTALL_COACH=1
            GET_PREFERENCES_MANUALLY=0
            shift;;
        -p|--dashboard)
            INSTALL_DASHBOARD=1
            GET_PREFERENCES_MANUALLY=0;
            shift;;
        -g|--gym)
            INSTALL_GYM=1
            GET_PREFERENCES_MANUALLY=0;
            shift;;
        -N|--no_virtual_environment)
            INSTALL_VIRTUAL_ENVIRONMENT=0
            GET_PREFERENCES_MANUALLY=0;
            shift;;
        # getopt short options are single characters, so the short form is -n
        -n|--neon)
            INSTALL_NEON=1
            GET_PREFERENCES_MANUALLY=0;
            shift;;
        -d|--debug) set -x; shift;;
        -h|--help)
            echo "Available command line arguments:"
            echo ""
            echo " -c | --coach - Install Coach requirements"
            echo " -p | --dashboard - Install Dashboard requirements"
            echo " -g | --gym - Install Gym support"
            echo " -n | --neon - Install neon support"
            echo " -N | --no_virtual_environment - Do not install inside of a virtual environment"
            echo " -d | --debug - Run in debug mode"
            echo " -h | --help - Display this help message"
            echo ""
            exit;;
        --) shift; break;;
        *) break;; # unknown option
    esac
done

if [ ${GET_PREFERENCES_MANUALLY} -eq 1 ]; then
    prompt "Install Coach requirements?" Y
    INSTALL_COACH=${retval}

    prompt "Install Dashboard requirements?" Y
    INSTALL_DASHBOARD=${retval}

    prompt "Install Gym support?" Y
    INSTALL_GYM=${retval}

    prompt "Install neon support?" Y
    INSTALL_NEON=${retval}
fi

# 1 when python3 exposes sys.real_prefix, 0 otherwise
IN_VIRTUAL_ENV=$(python3 -c 'import sys; print("%i" % hasattr(sys, "real_prefix"))')

# basic installations
sudo -E apt-get install python3-pip cmake zlib1g-dev python3-tk python-opencv -y
pip3 install --upgrade pip

# if we are not in a virtual environment, we will create one with the appropriate python version and then activate it
# if we are already in a virtual environment, it is used as is

if [ ${INSTALL_VIRTUAL_ENVIRONMENT} -eq 1 ]; then
    if [ ${IN_VIRTUAL_ENV} -eq 0 ]; then
        sudo -E pip3 install virtualenv
        virtualenv -p python3 coach_env
        . coach_env/bin/activate
    fi
fi

#------------------------------------------------
# From now on we are in a virtual environment
#------------------------------------------------

# get python local and global paths
python_version=python$(python -c "import sys; print (str(sys.version_info[0])+'.'+str(sys.version_info[1]))")
var=( $(which -a $python_version) )
get_python_lib_cmd="from distutils.sysconfig import get_python_lib; print (get_python_lib())"
lib_virtualenv_path=$(python -c "$get_python_lib_cmd")
# the last interpreter listed by `which -a` is used for the system path
lib_system_path=$(${var[-1]} -c "$get_python_lib_cmd")

# Boost libraries
sudo -E apt-get install libboost-all-dev -y

# Coach
if [ ${INSTALL_COACH} -eq 1 ]; then
    echo "Installing Coach requirements"
    pip3 install -r ./requirements_coach.txt
fi

# Dashboard
if [ ${INSTALL_DASHBOARD} -eq 1 ]; then
    echo "Installing Dashboard requirements"
    pip3 install -r ./requirements_dashboard.txt
    sudo -E apt-get install dpkg-dev build-essential python3.5-dev libjpeg-dev libtiff-dev libsdl1.2-dev libnotify-dev \
        freeglut3 freeglut3-dev libsm-dev libgtk2.0-dev libgtk-3-dev libwebkitgtk-dev libgtk-3-dev libwebkitgtk-3.0-dev libgstreamer-plugins-base1.0-dev -y

    sudo -E -H pip3 install -U --pre -f \
        https://wxpython.org/Phoenix/snapshot-builds/linux/gtk3/ubuntu-16.04/wxPython-4.0.0a3.dev3059+4a5c5d9-cp35-cp35m-linux_x86_64.whl wxPython

    # link the wxPython Phoenix library from the system python lib path into the virtualenv one
    libs=( wx )
    for lib in ${libs[@]}
    do
        ln -sf $lib_system_path/$lib $lib_virtualenv_path/$lib
    done
fi

# Gym
if [ ${INSTALL_GYM} -eq 1 ]; then
    echo "Installing Gym support"
    sudo -E apt-get install libav-tools libsdl2-dev swig cmake -y
    pip3 install box2d # for bipedal walker etc.
    pip3 install gym
fi

# NGraph and Neon
if [ ${INSTALL_NEON} -eq 1 ]; then
    echo "Installing neon requirements"

    # MKL
    git clone https://github.com/01org/mkl-dnn.git
    cd mkl-dnn
    cd scripts && ./prepare_mkl.sh && cd ..
    mkdir -p build && cd build && cmake .. && make -j
    sudo make install -j
    cd ../..
    export MKLDNN_ROOT=/usr/local/
    add_to_bashrc MKLDNN_ROOT ${MKLDNN_ROOT}
    export LD_LIBRARY_PATH=$MKLDNN_ROOT/lib:$LD_LIBRARY_PATH
    add_to_bashrc LD_LIBRARY_PATH ${MKLDNN_ROOT}/lib:$LD_LIBRARY_PATH

    # NGraph
    git clone https://github.com/NervanaSystems/ngraph.git
    cd ngraph
    make install -j
    cd ..

    # Neon
    sudo -E apt-get install libhdf5-dev libyaml-dev pkg-config clang virtualenv libcurl4-openssl-dev libopencv-dev libsox-dev -y
    pip3 install nervananeon
fi

# pick the TensorFlow build depending on whether nvidia-smi is available
if ! [ -x "$(command -v nvidia-smi)" ]; then
    # Intel Optimized TensorFlow
    pip3 install https://anaconda.org/intel/tensorflow/1.4.0/download/tensorflow-1.4.0-cp35-cp35m-linux_x86_64.whl
else
    # GPU supported TensorFlow
    pip3 install tensorflow-gpu==1.4.1
fi
|
|
||||||
@@ -7,3 +7,5 @@ pygame==1.9.3
|
|||||||
PyOpenGL==3.1.0
|
PyOpenGL==3.1.0
|
||||||
scipy==0.19.0
|
scipy==0.19.0
|
||||||
scikit-image==0.13.0
|
scikit-image==0.13.0
|
||||||
|
gym
|
||||||
|
tensorflow
|
||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import argparse
|
||||||
import atexit
|
import atexit
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -22,13 +23,12 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import agents
|
from coach import agents # noqa
|
||||||
import argparse
|
from coach import configurations as conf
|
||||||
import configurations as conf
|
from coach import environments
|
||||||
import environments
|
from coach import logger
|
||||||
import logger
|
from coach import presets
|
||||||
import presets
|
from coach import utils
|
||||||
import utils
|
|
||||||
|
|
||||||
|
|
||||||
if len(set(logger.failed_imports)) > 0:
|
if len(set(logger.failed_imports)) > 0:
|
||||||
@@ -20,10 +20,10 @@ import time
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
import agents
|
from coach import agents
|
||||||
import environments
|
from coach import environments
|
||||||
import logger
|
from coach import logger
|
||||||
import presets
|
from coach import presets
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
28
setup.py
Executable file
28
setup.py
Executable file
@@ -0,0 +1,28 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Setup for the coach project
|
||||||
|
"""
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
|
||||||
|
# Package metadata for the coach distribution; all packages found next to
# this file are included and the `scripts/coach` launcher is installed.
setup(
    name="coach",
    version="0.7",
    description="Reinforcement Learning Coach",
    author="Caspi, Itai and Leibovich, Gal and Novik, Gal",
    author_email="gal.novik@intel.com",
    url="https://github.com/NervanaSystems/coach",
    packages=find_packages(),
    download_url="https://github.com/NervanaSystems/coach",
    keywords=["reinforcement", "machine", "learning"],
    install_requires=[
        "annoy",
        "Pillow",
        "matplotlib",
        "numpy",
        "pandas",
        "pygame",
        "PyOpenGL",
        "scipy",
        "scikit-image",
        "tensorflow",
        "gym",
    ],
    scripts=["scripts/coach"],
    classifiers=[
        "Programming Language :: Python :: 3",
        "Development Status :: 4 - Beta",
        "Environment :: Console",
        "Intended Audience :: End Users/Desktop",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
|
||||||
Reference in New Issue
Block a user