Until now, most modules imported all of another module's objects (variables, classes, functions, other imports) into their own namespace, which could lead to (and did cause) unintentional use of indirectly imported classes or methods. With this patch, all star imports are replaced with imports of the top-level module that provides the desired class or function. In addition, imports are sorted (where possible) as PEP 8[1] suggests: first imports from the standard library, then third-party imports (numpy, tensorflow, etc.), and finally coach modules, with the sections separated by one empty line. [1] https://www.python.org/dev/peps/pep-0008/#imports
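A minimal sketch of the resulting import layout, using only modules that already appear in this file:

    # standard library
    import collections
    import random
    import time

    # third party
    import numpy as np
    import scipy.misc

    # coach modules
    import configurations
    import utils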
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import collections
import random
import time

import logger
try:
    import matplotlib.pyplot as plt
except ImportError:
    logger.failed_imports.append("matplotlib")

import numpy as np
from pandas.io import pickle
from six.moves import range
import scipy.misc

from architectures.tensorflow_components import shared_variables as sv
import configurations
import exploration_policies as ep
import memories
from memories import memory
import renderer
import utils


class Agent(object):
    def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0):
        """
        :param env: An environment instance
        :type env: EnvironmentWrapper
        :param tuning_parameters: A Preset class instance with all the running parameters
        :type tuning_parameters: Preset
        :param replicated_device: A tensorflow device for distributed training (optional)
        :type replicated_device: instancemethod
        :param task_id: The current task id
        :type task_id: int
        """

        logger.screen.log_title("Creating agent {}".format(task_id))
        self.task_id = task_id
        self.sess = tuning_parameters.sess
        self.env = tuning_parameters.env_instance = env
        self.imitation = False

        # i/o dimensions
        if not tuning_parameters.env.desired_observation_width or not tuning_parameters.env.desired_observation_height:
            tuning_parameters.env.desired_observation_width = self.env.width
            tuning_parameters.env.desired_observation_height = self.env.height
        self.action_space_size = tuning_parameters.env.action_space_size = self.env.action_space_size
        self.measurements_size = tuning_parameters.env.measurements_size = self.env.measurements_size
        if tuning_parameters.agent.use_accumulated_reward_as_measurement:
            self.measurements_size = tuning_parameters.env.measurements_size = (self.measurements_size[0] + 1,)

        # modules
        if tuning_parameters.agent.load_memory_from_file_path:
            logger.screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
                                    .format(tuning_parameters.agent.load_memory_from_file_path))
            self.memory = pickle.read_pickle(tuning_parameters.agent.load_memory_from_file_path)
        else:
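            # instantiate the replay memory class whose name is given as a string in
            # the preset (tuning_parameters.memory); the exploration policies below
            # are created the same way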
            self.memory = eval('memories.' + tuning_parameters.memory + '(tuning_parameters)')

        self.has_global = replicated_device is not None
        self.replicated_device = replicated_device
        self.worker_device = "/job:worker/task:{}/cpu:0".format(task_id) if replicated_device is not None else "/gpu:0"

        self.exploration_policy = eval('ep.' + tuning_parameters.exploration.policy + '(tuning_parameters)')
        self.evaluation_exploration_policy = eval('ep.' + tuning_parameters.exploration.evaluation_policy
                                                  + '(tuning_parameters)')
        self.evaluation_exploration_policy.change_phase(utils.RunPhase.TEST)

        # initialize all internal variables
        self.tp = tuning_parameters
        self.in_heatup = False
        self.total_reward_in_current_episode = 0
        self.total_steps_counter = 0
        self.running_reward = None
        self.training_iteration = 0
        self.current_episode = self.tp.current_episode = 0
        self.curr_state = {}
        self.current_episode_steps_counter = 0
        self.episode_running_info = {}
        self.last_episode_evaluation_ran = 0
        self.running_observations = []
        logger.logger.set_current_time(self.current_episode)
        self.main_network = None
        self.networks = []
        self.last_episode_images = []
        self.renderer = renderer.Renderer()

        # signals
        self.signals = []
        self.loss = utils.Signal('Loss')
        self.signals.append(self.loss)
        self.curr_learning_rate = utils.Signal('Learning Rate')
        self.signals.append(self.curr_learning_rate)

        if self.tp.env.normalize_observation and not self.env.is_state_type_image:
            if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
                self.running_observation_stats = utils.RunningStat((self.tp.env.desired_observation_width,))
                self.running_reward_stats = utils.RunningStat(())
            else:
                self.running_observation_stats = sv.SharedRunningStats(self.tp, replicated_device,
                                                                       shape=(self.tp.env.desired_observation_width,),
                                                                       name='observation_stats')
                self.running_reward_stats = sv.SharedRunningStats(self.tp, replicated_device,
                                                                  shape=(),
                                                                  name='reward_stats')

        # env is already reset at this point. Otherwise we're getting an error where you cannot
        # reset an env which is not done
        self.reset_game(do_not_reset_env=True)

        # use seed
        if self.tp.seed is not None:
            random.seed(self.tp.seed)
            np.random.seed(self.tp.seed)

    def log_to_screen(self, phase):
        # log to screen
        if self.current_episode >= 0:
            if phase == utils.RunPhase.TRAIN:
                exploration = self.exploration_policy.get_control_param()
            else:
                exploration = self.evaluation_exploration_policy.get_control_param()

            logger.screen.log_dict(
                collections.OrderedDict([
                    ("Worker", self.task_id),
                    ("Episode", self.current_episode),
                    ("total reward", self.total_reward_in_current_episode),
                    ("exploration", exploration),
                    ("steps", self.total_steps_counter),
                    ("training iteration", self.training_iteration)
                ]),
                prefix=phase
            )

    def update_log(self, phase=utils.RunPhase.TRAIN):
        """
        Updates the log file with all the current signal values.
        :return: None
        """
        # log all the signals to file
        logger.logger.set_current_time(self.current_episode)
        logger.logger.create_signal_value('Training Iter', self.training_iteration)
        logger.logger.create_signal_value('In Heatup', int(phase == utils.RunPhase.HEATUP))
        logger.logger.create_signal_value('ER #Transitions', self.memory.num_transitions())
        logger.logger.create_signal_value('ER #Episodes', self.memory.length())
        logger.logger.create_signal_value('Episode Length', self.current_episode_steps_counter)
        logger.logger.create_signal_value('Total steps', self.total_steps_counter)
        logger.logger.create_signal_value("Epsilon", self.exploration_policy.get_control_param())
        logger.logger.create_signal_value("Training Reward", self.total_reward_in_current_episode
                                          if phase == utils.RunPhase.TRAIN else np.nan)
        logger.logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode
                                          if phase == utils.RunPhase.TEST else np.nan)
        logger.logger.create_signal_value('Update Target Network', 0, overwrite=False)
        logger.logger.update_wall_clock_time(self.current_episode)

        for signal in self.signals:
            logger.logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
            logger.logger.create_signal_value("{}/Stdev".format(signal.name), signal.get_stdev())
            logger.logger.create_signal_value("{}/Max".format(signal.name), signal.get_max())
            logger.logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())

        # dump
        if self.current_episode % self.tp.visualization.dump_signals_to_csv_every_x_episodes == 0 \
                and self.current_episode > 0:
            logger.logger.dump_output_csv()

    def reset_game(self, do_not_reset_env=False):
        """
        Resets all the episodic parameters and starts a new environment episode.
        :param do_not_reset_env: A boolean that allows skipping the environment reset
        :return: None
        """

        for signal in self.signals:
            signal.reset()
        self.total_reward_in_current_episode = 0
        self.curr_state = {}
        self.last_episode_images = []
        self.current_episode_steps_counter = 0
        self.episode_running_info = {}
        if not do_not_reset_env:
            self.env.reset()
        self.exploration_policy.reset()

        # required for online plotting
        if self.tp.visualization.plot_action_values_online:
            if hasattr(self, 'episode_running_info') and hasattr(self.env, 'actions_description'):
                for action in self.env.actions_description:
                    self.episode_running_info[action] = []
                plt.clf()

        if self.tp.agent.middleware_type == configurations.MiddlewareTypes.LSTM:
            for network in self.networks:
                network.online_network.curr_rnn_c_in = network.online_network.middleware_embedder.c_init
                network.online_network.curr_rnn_h_in = network.online_network.middleware_embedder.h_init

        self.prepare_initial_state()

    def preprocess_observation(self, observation):
        """
        Preprocesses the given observation.
        For images - convert to grayscale, resize and convert to int.
        For vector observations - normalize by a running mean and std.
        :param observation: The agent's observation
        :return: A processed version of the observation
        """

        if self.env.is_state_type_image:
            # rescale
            observation = scipy.misc.imresize(observation,
                                              (self.tp.env.desired_observation_height,
                                               self.tp.env.desired_observation_width),
                                              interp=self.tp.rescaling_interpolation_type)
            # rgb to y
            if len(observation.shape) > 2 and observation.shape[2] > 1:
                r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
                observation = 0.2989 * r + 0.5870 * g + 0.1140 * b

            # Render the processed observation which is how the agent will see it
            # Warning: this cannot currently be done in parallel to rendering the environment
            if self.tp.visualization.render_observation:
                if not self.renderer.is_open:
                    self.renderer.create_screen(observation.shape[0], observation.shape[1])
                self.renderer.render_image(observation)

            return observation.astype('uint8')
        else:
            if self.tp.env.normalize_observation:
                # standardize the input observation using a running mean and std
                if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
                    self.running_observation_stats.push(observation)
                observation = (observation - self.running_observation_stats.mean) / \
                              (self.running_observation_stats.std + 1e-15)
                observation = np.clip(observation, -5.0, 5.0)
            return observation

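    # Note: the 0.2989 / 0.5870 / 0.1140 weights used above are the standard
    # ITU-R BT.601 luma coefficients; non-image observations are standardized to
    # roughly zero mean and unit variance and then clipped to [-5, 5].
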
    def learn_from_batch(self, batch):
        """
        Given a batch of transitions, calculates their target values and updates the network.
        :param batch: A list of transitions
        :return: The loss of the training
        """
        pass

    def train(self):
        """
        A single training iteration. Sample a batch, train on it and update target networks.
        :return: The training loss.
        """
        batch = self.memory.sample(self.tp.batch_size)
        loss = self.learn_from_batch(batch)

        if self.tp.learning_rate_decay_rate != 0:
            self.curr_learning_rate.add_sample(self.tp.sess.run(self.tp.learning_rate))
        else:
            self.curr_learning_rate.add_sample(self.tp.learning_rate)

        # update the target network of every network that has a target network
        if self.total_steps_counter % self.tp.agent.num_steps_between_copying_online_weights_to_target == 0:
            for network in self.networks:
                network.update_target_network(self.tp.agent.rate_for_copying_weights_to_target)
            logger.logger.create_signal_value('Update Target Network', 1)
        else:
            logger.logger.create_signal_value('Update Target Network', 0, overwrite=False)

        return loss

    def extract_batch(self, batch):
        """
        Extracts a single numpy array for each object in a batch of transitions (state, action, etc.)
        :param batch: An array of transitions
        :return: For each transition element, returns a numpy array of all the transitions in the batch
        """
        current_states = {}
        next_states = {}
        current_states['observation'] = np.array([np.array(transition.state['observation']) for transition in batch])
        next_states['observation'] = np.array([np.array(transition.next_state['observation']) for transition in batch])
        actions = np.array([transition.action for transition in batch])
        rewards = np.array([transition.reward for transition in batch])
        game_overs = np.array([transition.game_over for transition in batch])
        total_return = np.array([transition.total_return for transition in batch])

        # get the entire state including measurements if available
        if self.tp.agent.use_measurements:
            current_states['measurements'] = np.array([transition.state['measurements'] for transition in batch])
            next_states['measurements'] = np.array([transition.next_state['measurements'] for transition in batch])

        return current_states, next_states, actions, rewards, game_overs, total_return

    def plot_action_values_online(self):
        """
        Plot an animated graph of the value of each possible action during the episode
        :return: None
        """

        plt.clf()
        for key, data_list in self.episode_running_info.items():
            plt.plot(data_list, label=key)
        plt.legend()
        plt.pause(0.00000001)

    def choose_action(self, curr_state, phase=utils.RunPhase.TRAIN):
        """
        Choose an action to act with during the current episode. Different behavior may be
        exhibited when training or testing.

        :param curr_state: the current state to act upon.
        :param phase: the current phase: training or testing.
        :return: chosen action, some action value describing the action (q-value, probability, etc)
        """
        pass

    def preprocess_reward(self, reward):
        if self.tp.env.reward_scaling:
            reward /= float(self.tp.env.reward_scaling)
        if self.tp.env.reward_clipping_max:
            reward = min(reward, self.tp.env.reward_clipping_max)
        if self.tp.env.reward_clipping_min:
            reward = max(reward, self.tp.env.reward_clipping_min)
        return reward

    def tf_input_state(self, curr_state):
        """
        Convert curr_state into the input tensors that TensorFlow expects.
        """
        # add batch axis with length 1 onto each value
        # extract values from the state based on agent.input_types
        input_state = {}
        for input_name in self.tp.agent.input_types.keys():
            input_state[input_name] = np.expand_dims(np.array(curr_state[input_name]), 0)
        return input_state

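    # Example (illustrative shapes): tf_input_state adds a leading batch axis of
    # length 1 to every input listed in agent.input_types, so an 'observation' of
    # shape (84, 84, 4) is fed to the network as an array of shape (1, 84, 84, 4).
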
    def prepare_initial_state(self):
        """
        Create an initial state when starting a new episode
        :return: None
        """
        observation = self.preprocess_observation(self.env.state['observation'])
        self.curr_stack = collections.deque([observation] * self.tp.env.observation_stack_size,
                                            maxlen=self.tp.env.observation_stack_size)
        observation = utils.LazyStack(self.curr_stack, -1)

        self.curr_state = {
            'observation': observation
        }
        if self.tp.agent.use_measurements:
            self.curr_state['measurements'] = self.env.measurements
            if self.tp.agent.use_accumulated_reward_as_measurement:
                self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0)

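    # Example (illustrative): with observation_stack_size = 4, the deque above holds
    # the 4 most recent preprocessed frames, and utils.LazyStack exposes them as a
    # single observation stacked along the last axis (e.g. four 84x84 frames become
    # one 84x84x4 array) only when the stacked array is actually needed.
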
    def act(self, phase=utils.RunPhase.TRAIN):
        """
        Take one step in the environment according to the network prediction and store the transition in memory
        :param phase: Either Train or Test to specify if greedy actions should be used and if transitions should be stored
        :return: A boolean value that signals an episode termination
        """

        if phase != utils.RunPhase.TEST:
            self.total_steps_counter += 1
        self.current_episode_steps_counter += 1

        # get new action
        action_info = {"action_probability": 1.0 / self.env.action_space_size, "action_value": 0, "max_action_value": 0}

        if phase == utils.RunPhase.HEATUP and not self.tp.heatup_using_network_decisions:
            action = self.env.get_random_action()
        else:
            action, action_info = self.choose_action(self.curr_state, phase=phase)

        # perform action
        if isinstance(action, np.ndarray):
            action = action.squeeze()
        result = self.env.step(action)

        shaped_reward = self.preprocess_reward(result['reward'])
        if 'action_intrinsic_reward' in action_info.keys():
            shaped_reward += action_info['action_intrinsic_reward']
        # TODO: should total_reward_in_current_episode include shaped_reward?
        self.total_reward_in_current_episode += result['reward']
        next_state = result['state']
        next_state['observation'] = self.preprocess_observation(next_state['observation'])

        # plot action values online
        if self.tp.visualization.plot_action_values_online and phase != utils.RunPhase.HEATUP:
            self.plot_action_values_online()

        # initialize the next state
        # TODO: provide option to stack more than just the observation
        self.curr_stack.append(next_state['observation'])
        observation = utils.LazyStack(self.curr_stack, -1)

        next_state['observation'] = observation
        if self.tp.agent.use_measurements and 'measurements' in result.keys():
            next_state['measurements'] = result['state']['measurements']
            if self.tp.agent.use_accumulated_reward_as_measurement:
                next_state['measurements'] = np.append(next_state['measurements'], self.total_reward_in_current_episode)

        # store the transition only if we are training
        if phase == utils.RunPhase.TRAIN or phase == utils.RunPhase.HEATUP:
            transition = memory.Transition(self.curr_state, result['action'], shaped_reward, next_state, result['done'])
            for key in action_info.keys():
                transition.info[key] = action_info[key]
            if self.tp.agent.add_a_normalized_timestep_to_the_observation:
                transition.info['timestep'] = float(self.current_episode_steps_counter) / self.env.timestep_limit
            self.memory.store(transition)
        elif phase == utils.RunPhase.TEST and self.tp.visualization.dump_gifs:
            # store the rendered images only for saving gifs
            self.last_episode_images.append(self.env.get_rendered_image())

        # update the current state for the next step
        self.curr_state = next_state

        # deal with episode termination
        if result['done']:
            if self.tp.visualization.dump_csv:
                self.update_log(phase=phase)
            self.log_to_screen(phase=phase)

            if phase == utils.RunPhase.TRAIN or phase == utils.RunPhase.HEATUP:
                self.reset_game()

            self.current_episode += 1
            self.tp.current_episode = self.current_episode

        # return whether the episode has really ended
        return result['done']

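    # Example (illustrative): a stored transition packs the stacked state dict, the
    # executed action, the shaped (scaled/clipped) reward, the next state dict and
    # the done flag, plus any extra keys from action_info (e.g. 'action_probability')
    # copied into transition.info.
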
    def evaluate(self, num_episodes, keep_networks_synced=False):
        """
        Run in an evaluation mode for several episodes. Actions will be chosen greedily.
        :param keep_networks_synced: keep the online network in sync with the global network after every episode
        :param num_episodes: The number of episodes to evaluate on
        :return: None
        """

        max_reward_achieved = -float('inf')
        average_evaluation_reward = 0
        logger.screen.log_title("Running evaluation")
        self.env.change_phase(utils.RunPhase.TEST)
        for i in range(num_episodes):
            # keep the online network in sync with the global network
            if keep_networks_synced:
                for network in self.networks:
                    network.sync()

            episode_ended = False
            while not episode_ended:
                episode_ended = self.act(phase=utils.RunPhase.TEST)

                if keep_networks_synced \
                        and self.total_steps_counter % \
                        self.tp.agent.update_evaluation_agent_network_after_every_num_steps == 0:
                    for network in self.networks:
                        network.sync()

            if self.total_reward_in_current_episode > max_reward_achieved:
                max_reward_achieved = self.total_reward_in_current_episode
                frame_skipping = int(5 / self.tp.env.frame_skip)
                if self.tp.visualization.dump_gifs:
                    logger.logger.create_gif(self.last_episode_images[::frame_skipping],
                                             name='score-{}'.format(max_reward_achieved), fps=10)

            average_evaluation_reward += self.total_reward_in_current_episode
            self.reset_game()

        average_evaluation_reward /= float(num_episodes)

        self.env.change_phase(utils.RunPhase.TRAIN)
        logger.screen.log_title("Evaluation done. Average reward = {}.".format(average_evaluation_reward))

    def post_training_commands(self):
        pass

    def improve(self):
        """
        Training algorithms wrapper. Heatup >> [ Evaluate >> Play >> Train >> Save checkpoint ]

        :return: None
        """

        # synchronize the online network weights with the global network
        for network in self.networks:
            network.sync()

        # heatup phase
        if self.tp.num_heatup_steps != 0:
            self.in_heatup = True
            logger.screen.log_title("Starting heatup {}".format(self.task_id))
            num_steps_required_for_one_training_batch = self.tp.batch_size * self.tp.env.observation_stack_size
            for step in range(max(self.tp.num_heatup_steps, num_steps_required_for_one_training_batch)):
                self.act(phase=utils.RunPhase.HEATUP)

        # training phase
        self.in_heatup = False
        logger.screen.log_title("Starting training {}".format(self.task_id))
        self.exploration_policy.change_phase(utils.RunPhase.TRAIN)
        training_start_time = time.time()
        model_snapshots_periods_passed = -1
        self.reset_game()

        while self.training_iteration < self.tp.num_training_iterations:
            # evaluate
            evaluate_agent = (self.last_episode_evaluation_ran != self.current_episode) and \
                             (self.current_episode % self.tp.evaluate_every_x_episodes == 0)
            evaluate_agent = evaluate_agent or \
                             (self.imitation and self.training_iteration > 0 and
                              self.training_iteration % self.tp.evaluate_every_x_training_iterations == 0)

            if evaluate_agent:
                self.env.reset(force_environment_reset=True)
                self.last_episode_evaluation_ran = self.current_episode
                self.evaluate(self.tp.evaluation_episodes)

            # snapshot model
            if self.tp.save_model_sec and self.tp.save_model_sec > 0 and not self.tp.distributed:
                total_training_time = time.time() - training_start_time
                current_snapshot_period = (int(total_training_time) // self.tp.save_model_sec)
                if current_snapshot_period > model_snapshots_periods_passed:
                    model_snapshots_periods_passed = current_snapshot_period
                    self.save_model(model_snapshots_periods_passed)

            # play and record in replay buffer
            if self.tp.agent.collect_new_data:
                if self.tp.agent.step_until_collecting_full_episodes:
                    step = 0
                    while step < self.tp.agent.num_consecutive_playing_steps or self.memory.get_episode(-1).length() != 0:
                        self.act()
                        step += 1
                else:
                    for step in range(self.tp.agent.num_consecutive_playing_steps):
                        self.act()

            # train
            if self.tp.train:
                for step in range(self.tp.agent.num_consecutive_training_steps):
                    loss = self.train()
                    self.loss.add_sample(loss)
                    self.training_iteration += 1
                    if self.imitation:
                        self.log_to_screen(utils.RunPhase.TRAIN)
                self.post_training_commands()

    def save_model(self, model_id):
        self.main_network.save_model(model_id)
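
# Illustrative usage sketch: concrete agents subclass Agent and implement
# choose_action() and learn_from_batch(); a preset supplies the tuning parameters
# and the environment instance. The names MyPreset, make_environment and
# SomeAgentSubclass below are placeholders for illustration only.
#
#     tuning_parameters = MyPreset()                 # a Preset instance
#     env = make_environment(tuning_parameters)      # an EnvironmentWrapper instance
#     agent = SomeAgentSubclass(env, tuning_parameters, task_id=0)
#     agent.improve()   # Heatup >> [ Evaluate >> Play >> Train >> Save checkpoint ]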