Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 11:40:18 +01:00)
Cleanup imports.
Until now, most modules imported all of another module's objects (variables, classes, functions, and even its transitive imports) into their own namespace, which could cause, and in places did cause, unintentional use of indirectly imported classes or methods. With this patch, all star imports are replaced with an import of the top-level module that provides the desired class or function. In addition, imports are sorted (where possible) in the way PEP 8 [1] suggests: standard-library imports first, then third-party imports (numpy, tensorflow, etc.), and finally coach modules, with each group separated by a blank line.

[1] https://www.python.org/dev/peps/pep-0008/#imports
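To make the convention concrete, here is an editorial sketch (written for this page, not code taken from the commit; the stdlib/third-party names used are illustrative) of the style being removed and the style being introduced:

    # Style being removed: star imports pull names such as RunPhase into the
    # module namespace with no visible origin, e.g.
    #     from architectures import *
    #     from utils import RunPhase, Signal
    #
    # Style being introduced: one group per PEP 8 section (standard library,
    # third party, coach modules), groups separated by a blank line, names
    # accessed through the module that owns them.
    import collections          # standard library
    import os

    import numpy as np          # third party

    # coach modules would follow here, e.g. "import utils", and call sites
    # would then read utils.RunPhase.TRAIN instead of a bare RunPhase.TRAIN.

    stack = collections.deque(maxlen=4)     # stdlib accessed via its module
    print(os.getcwd(), np.zeros(2), stack)  # trivial usage so the sketch runs

The same move shows up throughout the diff below as utils.RunPhase, collections.OrderedDict, logger.screen and so on replacing the previously bare names.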
agents/agent.py | 181 lines changed
@@ -13,32 +13,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 
-import scipy.ndimage
-try:
-    import matplotlib.pyplot as plt
-except:
-    from logger import failed_imports
-    failed_imports.append("matplotlib")
-
-import copy
-from renderer import Renderer
-from configurations import Preset
-from collections import deque
-from utils import LazyStack
-from collections import OrderedDict
-from utils import RunPhase, Signal, is_empty, RunningStat
-from architectures import *
-from exploration_policies import *
-from memories import *
-from memories.memory import *
-from logger import logger, screen
-import collections
-import random
-import time
-import os
-import itertools
-from architectures.tensorflow_components.shared_variables import SharedRunningStats
-
+import logger
+try:
+    import matplotlib.pyplot as plt
+except ImportError:
+    logger.failed_imports.append("matplotlib")
+
+import numpy as np
+from pandas.io import pickle
+from six.moves import range
+import scipy
+
+from architectures.tensorflow_components import shared_variables as sv
+import configurations
+import exploration_policies as ep
+import memories
+from memories import memory
+import renderer
+import utils
 
 
 class Agent(object):
@@ -54,7 +50,7 @@ class Agent(object):
         :param thread_id: int
         """
 
-        screen.log_title("Creating agent {}".format(task_id))
+        logger.screen.log_title("Creating agent {}".format(task_id))
         self.task_id = task_id
         self.sess = tuning_parameters.sess
         self.env = tuning_parameters.env_instance = env
@@ -71,21 +67,20 @@ class Agent(object):
 
         # modules
         if tuning_parameters.agent.load_memory_from_file_path:
-            screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
+            logger.screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
                              .format(tuning_parameters.agent.load_memory_from_file_path))
-            self.memory = read_pickle(tuning_parameters.agent.load_memory_from_file_path)
+            self.memory = pickle.read_pickle(tuning_parameters.agent.load_memory_from_file_path)
         else:
-            self.memory = eval(tuning_parameters.memory + '(tuning_parameters)')
-        # self.architecture = eval(tuning_parameters.architecture)
+            self.memory = eval('memories.' + tuning_parameters.memory + '(tuning_parameters)')
 
         self.has_global = replicated_device is not None
         self.replicated_device = replicated_device
         self.worker_device = "/job:worker/task:{}/cpu:0".format(task_id) if replicated_device is not None else "/gpu:0"
 
-        self.exploration_policy = eval(tuning_parameters.exploration.policy + '(tuning_parameters)')
-        self.evaluation_exploration_policy = eval(tuning_parameters.exploration.evaluation_policy
+        self.exploration_policy = eval('ep.' + tuning_parameters.exploration.policy + '(tuning_parameters)')
+        self.evaluation_exploration_policy = eval('ep.' + tuning_parameters.exploration.evaluation_policy
                                                   + '(tuning_parameters)')
-        self.evaluation_exploration_policy.change_phase(RunPhase.TEST)
+        self.evaluation_exploration_policy.change_phase(utils.RunPhase.TEST)
 
         # initialize all internal variables
         self.tp = tuning_parameters
@@ -100,30 +95,30 @@ class Agent(object):
         self.episode_running_info = {}
         self.last_episode_evaluation_ran = 0
         self.running_observations = []
-        logger.set_current_time(self.current_episode)
+        logger.logger.set_current_time(self.current_episode)
         self.main_network = None
         self.networks = []
         self.last_episode_images = []
-        self.renderer = Renderer()
+        self.renderer = renderer.Renderer()
 
         # signals
         self.signals = []
-        self.loss = Signal('Loss')
+        self.loss = utils.Signal('Loss')
         self.signals.append(self.loss)
-        self.curr_learning_rate = Signal('Learning Rate')
+        self.curr_learning_rate = utils.Signal('Learning Rate')
         self.signals.append(self.curr_learning_rate)
 
         if self.tp.env.normalize_observation and not self.env.is_state_type_image:
             if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
-                self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
-                self.running_reward_stats = RunningStat(())
+                self.running_observation_stats = utils.RunningStat((self.tp.env.desired_observation_width,))
+                self.running_reward_stats = utils.RunningStat(())
             else:
-                self.running_observation_stats = SharedRunningStats(self.tp, replicated_device,
-                                                                    shape=(self.tp.env.desired_observation_width,),
-                                                                    name='observation_stats')
-                self.running_reward_stats = SharedRunningStats(self.tp, replicated_device,
-                                                               shape=(),
-                                                               name='reward_stats')
+                self.running_observation_stats = sv.SharedRunningStats(self.tp, replicated_device,
+                                                                       shape=(self.tp.env.desired_observation_width,),
+                                                                       name='observation_stats')
+                self.running_reward_stats = sv.SharedRunningStats(self.tp, replicated_device,
+                                                                  shape=(),
+                                                                  name='reward_stats')
 
         # env is already reset at this point. Otherwise we're getting an error where you cannot
         # reset an env which is not done
@@ -137,13 +132,13 @@ class Agent(object):
     def log_to_screen(self, phase):
         # log to screen
         if self.current_episode >= 0:
-            if phase == RunPhase.TRAIN:
+            if phase == utils.RunPhase.TRAIN:
                 exploration = self.exploration_policy.get_control_param()
             else:
                 exploration = self.evaluation_exploration_policy.get_control_param()
 
-            screen.log_dict(
-                OrderedDict([
+            logger.screen.log_dict(
+                collections.OrderedDict([
                     ("Worker", self.task_id),
                     ("Episode", self.current_episode),
                     ("total reward", self.total_reward_in_current_episode),
@@ -154,37 +149,37 @@ class Agent(object):
                 prefix=phase
             )
 
-    def update_log(self, phase=RunPhase.TRAIN):
+    def update_log(self, phase=utils.RunPhase.TRAIN):
         """
         Writes logging messages to screen and updates the log file with all the signal values.
         :return: None
         """
         # log all the signals to file
-        logger.set_current_time(self.current_episode)
-        logger.create_signal_value('Training Iter', self.training_iteration)
-        logger.create_signal_value('In Heatup', int(phase == RunPhase.HEATUP))
-        logger.create_signal_value('ER #Transitions', self.memory.num_transitions())
-        logger.create_signal_value('ER #Episodes', self.memory.length())
-        logger.create_signal_value('Episode Length', self.current_episode_steps_counter)
-        logger.create_signal_value('Total steps', self.total_steps_counter)
-        logger.create_signal_value("Epsilon", self.exploration_policy.get_control_param())
-        logger.create_signal_value("Training Reward", self.total_reward_in_current_episode
-                                   if phase == RunPhase.TRAIN else np.nan)
-        logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode
-                                   if phase == RunPhase.TEST else np.nan)
-        logger.create_signal_value('Update Target Network', 0, overwrite=False)
-        logger.update_wall_clock_time(self.current_episode)
+        logger.logger.set_current_time(self.current_episode)
+        logger.logger.create_signal_value('Training Iter', self.training_iteration)
+        logger.logger.create_signal_value('In Heatup', int(phase == utils.RunPhase.HEATUP))
+        logger.logger.create_signal_value('ER #Transitions', self.memory.num_transitions())
+        logger.logger.create_signal_value('ER #Episodes', self.memory.length())
+        logger.logger.create_signal_value('Episode Length', self.current_episode_steps_counter)
+        logger.logger.create_signal_value('Total steps', self.total_steps_counter)
+        logger.logger.create_signal_value("Epsilon", self.exploration_policy.get_control_param())
+        logger.logger.create_signal_value("Training Reward", self.total_reward_in_current_episode
+                                          if phase == utils.RunPhase.TRAIN else np.nan)
+        logger.logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode
+                                          if phase == utils.RunPhase.TEST else np.nan)
+        logger.logger.create_signal_value('Update Target Network', 0, overwrite=False)
+        logger.logger.update_wall_clock_time(self.current_episode)
 
         for signal in self.signals:
-            logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
-            logger.create_signal_value("{}/Stdev".format(signal.name), signal.get_stdev())
-            logger.create_signal_value("{}/Max".format(signal.name), signal.get_max())
-            logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())
+            logger.logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
+            logger.logger.create_signal_value("{}/Stdev".format(signal.name), signal.get_stdev())
+            logger.logger.create_signal_value("{}/Max".format(signal.name), signal.get_max())
+            logger.logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())
 
         # dump
         if self.current_episode % self.tp.visualization.dump_signals_to_csv_every_x_episodes == 0 \
                 and self.current_episode > 0:
-            logger.dump_output_csv()
+            logger.logger.dump_output_csv()
 
     def reset_game(self, do_not_reset_env=False):
         """
@@ -211,7 +206,7 @@ class Agent(object):
                 self.episode_running_info[action] = []
             plt.clf()
 
-        if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
+        if self.tp.agent.middleware_type == configurations.MiddlewareTypes.LSTM:
             for network in self.networks:
                 network.online_network.curr_rnn_c_in = network.online_network.middleware_embedder.c_init
                 network.online_network.curr_rnn_h_in = network.online_network.middleware_embedder.h_init
@@ -281,9 +276,9 @@ class Agent(object):
         if self.total_steps_counter % self.tp.agent.num_steps_between_copying_online_weights_to_target == 0:
             for network in self.networks:
                 network.update_target_network(self.tp.agent.rate_for_copying_weights_to_target)
-            logger.create_signal_value('Update Target Network', 1)
+            logger.logger.create_signal_value('Update Target Network', 1)
         else:
-            logger.create_signal_value('Update Target Network', 0, overwrite=False)
+            logger.logger.create_signal_value('Update Target Network', 0, overwrite=False)
 
         return loss
 
@@ -321,7 +316,7 @@ class Agent(object):
         plt.legend()
         plt.pause(0.00000001)
 
-    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+    def choose_action(self, curr_state, phase=utils.RunPhase.TRAIN):
         """
         choose an action to act with in the current episode being played. Different behavior might be exhibited when training
         or testing.
@@ -351,15 +346,15 @@ class Agent(object):
         for input_name in self.tp.agent.input_types.keys():
             input_state[input_name] = np.expand_dims(np.array(curr_state[input_name]), 0)
         return input_state
 
     def prepare_initial_state(self):
         """
         Create an initial state when starting a new episode
         :return: None
         """
         observation = self.preprocess_observation(self.env.state['observation'])
-        self.curr_stack = deque([observation]*self.tp.env.observation_stack_size, maxlen=self.tp.env.observation_stack_size)
-        observation = LazyStack(self.curr_stack, -1)
+        self.curr_stack = collections.deque([observation]*self.tp.env.observation_stack_size, maxlen=self.tp.env.observation_stack_size)
+        observation = utils.LazyStack(self.curr_stack, -1)
 
         self.curr_state = {
             'observation': observation
@@ -369,21 +364,21 @@ class Agent(object):
         if self.tp.agent.use_accumulated_reward_as_measurement:
             self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0)
 
-    def act(self, phase=RunPhase.TRAIN):
+    def act(self, phase=utils.RunPhase.TRAIN):
         """
         Take one step in the environment according to the network prediction and store the transition in memory
         :param phase: Either Train or Test to specify if greedy actions should be used and if transitions should be stored
         :return: A boolean value that signals an episode termination
         """
 
-        if phase != RunPhase.TEST:
+        if phase != utils.RunPhase.TEST:
             self.total_steps_counter += 1
             self.current_episode_steps_counter += 1
 
         # get new action
         action_info = {"action_probability": 1.0 / self.env.action_space_size, "action_value": 0, "max_action_value": 0}
 
-        if phase == RunPhase.HEATUP and not self.tp.heatup_using_network_decisions:
+        if phase == utils.RunPhase.HEATUP and not self.tp.heatup_using_network_decisions:
             action = self.env.get_random_action()
         else:
             action, action_info = self.choose_action(self.curr_state, phase=phase)
@@ -402,13 +397,13 @@ class Agent(object):
         next_state['observation'] = self.preprocess_observation(next_state['observation'])
 
         # plot action values online
-        if self.tp.visualization.plot_action_values_online and phase != RunPhase.HEATUP:
+        if self.tp.visualization.plot_action_values_online and phase != utils.RunPhase.HEATUP:
             self.plot_action_values_online()
 
         # initialize the next state
         # TODO: provide option to stack more than just the observation
         self.curr_stack.append(next_state['observation'])
-        observation = LazyStack(self.curr_stack, -1)
+        observation = utils.LazyStack(self.curr_stack, -1)
 
         next_state['observation'] = observation
         if self.tp.agent.use_measurements and 'measurements' in result.keys():
@@ -417,14 +412,14 @@ class Agent(object):
                 next_state['measurements'] = np.append(next_state['measurements'], self.total_reward_in_current_episode)
 
         # store the transition only if we are training
-        if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP:
-            transition = Transition(self.curr_state, result['action'], shaped_reward, next_state, result['done'])
+        if phase == utils.RunPhase.TRAIN or phase == utils.RunPhase.HEATUP:
+            transition = memory.Transition(self.curr_state, result['action'], shaped_reward, next_state, result['done'])
             for key in action_info.keys():
                 transition.info[key] = action_info[key]
             if self.tp.agent.add_a_normalized_timestep_to_the_observation:
                 transition.info['timestep'] = float(self.current_episode_steps_counter) / self.env.timestep_limit
             self.memory.store(transition)
-        elif phase == RunPhase.TEST and self.tp.visualization.dump_gifs:
+        elif phase == utils.RunPhase.TEST and self.tp.visualization.dump_gifs:
             # we store the transitions only for saving gifs
             self.last_episode_images.append(self.env.get_rendered_image())
 
@@ -437,7 +432,7 @@ class Agent(object):
             self.update_log(phase=phase)
             self.log_to_screen(phase=phase)
 
-            if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP:
+            if phase == utils.RunPhase.TRAIN or phase == utils.RunPhase.HEATUP:
                 self.reset_game()
 
             self.current_episode += 1
@@ -456,8 +451,8 @@ class Agent(object):
 
         max_reward_achieved = -float('inf')
         average_evaluation_reward = 0
-        screen.log_title("Running evaluation")
-        self.env.change_phase(RunPhase.TEST)
+        logger.screen.log_title("Running evaluation")
+        self.env.change_phase(utils.RunPhase.TEST)
         for i in range(num_episodes):
             # keep the online network in sync with the global network
             if keep_networks_synced:
@@ -466,7 +461,7 @@ class Agent(object):
 
             episode_ended = False
             while not episode_ended:
-                episode_ended = self.act(phase=RunPhase.TEST)
+                episode_ended = self.act(phase=utils.RunPhase.TEST)
 
                 if keep_networks_synced \
                         and self.total_steps_counter % self.tp.agent.update_evaluation_agent_network_after_every_num_steps:
@@ -477,7 +472,7 @@ class Agent(object):
                 max_reward_achieved = self.total_reward_in_current_episode
                 frame_skipping = int(5/self.tp.env.frame_skip)
                 if self.tp.visualization.dump_gifs:
-                    logger.create_gif(self.last_episode_images[::frame_skipping],
+                    logger.logger.create_gif(self.last_episode_images[::frame_skipping],
                                       name='score-{}'.format(max_reward_achieved), fps=10)
 
             average_evaluation_reward += self.total_reward_in_current_episode
@@ -485,8 +480,8 @@ class Agent(object):
 
         average_evaluation_reward /= float(num_episodes)
 
-        self.env.change_phase(RunPhase.TRAIN)
-        screen.log_title("Evaluation done. Average reward = {}.".format(average_evaluation_reward))
+        self.env.change_phase(utils.RunPhase.TRAIN)
+        logger.screen.log_title("Evaluation done. Average reward = {}.".format(average_evaluation_reward))
 
     def post_training_commands(self):
         pass
@@ -505,15 +500,15 @@ class Agent(object):
         # heatup phase
        if self.tp.num_heatup_steps != 0:
            self.in_heatup = True
-            screen.log_title("Starting heatup {}".format(self.task_id))
+            logger.screen.log_title("Starting heatup {}".format(self.task_id))
             num_steps_required_for_one_training_batch = self.tp.batch_size * self.tp.env.observation_stack_size
             for step in range(max(self.tp.num_heatup_steps, num_steps_required_for_one_training_batch)):
-                self.act(phase=RunPhase.HEATUP)
+                self.act(phase=utils.RunPhase.HEATUP)
 
         # training phase
         self.in_heatup = False
-        screen.log_title("Starting training {}".format(self.task_id))
-        self.exploration_policy.change_phase(RunPhase.TRAIN)
+        logger.screen.log_title("Starting training {}".format(self.task_id))
+        self.exploration_policy.change_phase(utils.RunPhase.TRAIN)
         training_start_time = time.time()
         model_snapshots_periods_passed = -1
         self.reset_game()
@@ -557,7 +552,7 @@ class Agent(object):
                     self.loss.add_sample(loss)
                     self.training_iteration += 1
                 if self.imitation:
-                    self.log_to_screen(RunPhase.TRAIN)
+                    self.log_to_screen(utils.RunPhase.TRAIN)
                 self.post_training_commands()
 
     def save_model(self, model_id):