mirror of https://github.com/gryf/coach.git
Added missing imports, correct usages
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import collections
+import copy
 import random
 import time
 
@@ -30,8 +31,8 @@ import scipy
 
 from architectures.tensorflow_components import shared_variables as sv
 import configurations
-import exploration_policies as ep
-import memories
+import exploration_policies as ep  # noqa, used in eval()
+import memories  # noqa, used in eval()
 from memories import memory
 import renderer
 import utils
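The two # noqa annotations above suppress the linter's unused-import warning: the modules are not referenced by name anywhere else in the file, but they must be importable because configuration strings are resolved through eval() at runtime. A minimal, self-contained sketch of that pattern, using the standard library instead of the project's modules (the names here are illustrative only):

import random  # noqa -- only referenced indirectly, via eval() below

# A configuration value naming the class to instantiate, stored as a string.
configured_class = 'random.Random'

# eval() looks the name up in this module's namespace, so the import above
# must stay even though static analysis sees it as unused.
cls = eval(configured_class)
rng = cls()
print(rng.randint(0, 10))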
@@ -31,14 +31,14 @@ class HumanAgent(agent.Agent):
         self.clock = pygame.time.Clock()
         self.max_fps = int(self.tp.visualization.max_fps_for_human_control)
 
-        utils.screen.log_title("Human Control Mode")
+        logger.screen.log_title("Human Control Mode")
         available_keys = self.env.get_available_keys()
         if available_keys:
-            utils.screen.log("Use keyboard keys to move. Press escape to quit. Available keys:")
-            utils.screen.log("")
+            logger.screen.log("Use keyboard keys to move. Press escape to quit. Available keys:")
+            logger.screen.log("")
             for action, key in self.env.get_available_keys():
-                utils.screen.log("\t- {}: {}".format(action, key))
-            utils.screen.separator()
+                logger.screen.log("\t- {}: {}".format(action, key))
+            logger.screen.separator()
 
     def train(self):
         return 0
@@ -58,12 +58,12 @@ class HumanAgent(agent.Agent):
         replay_buffer_path = os.path.join(logger.logger.experiments_path, 'replay_buffer.p')
         self.memory.tp = None
         pickle.to_pickle(self.memory, replay_buffer_path)
-        utils.screen.log_title("Replay buffer was stored in {}".format(replay_buffer_path))
+        logger.screen.log_title("Replay buffer was stored in {}".format(replay_buffer_path))
         exit()
 
     def log_to_screen(self, phase):
-        # log to utils.screen
-        utils.screen.log_dict(
+        # log to logger.screen
+        logger.screen.log_dict(
             collections.OrderedDict([
                 ("Episode", self.current_episode),
                 ("total reward", self.total_reward_in_current_episode),
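The replay-buffer hunk above clears self.memory.tp before serializing the buffer, a common way to drop references that cannot be pickled (live sessions, loggers, and similar) before writing to disk. A minimal sketch of the same clear-then-dump pattern with the standard pickle module, under the assumption that tp is safe to discard; the helper name below is illustrative, not the repository's:

import pickle


def dump_replay_buffer(memory, path):
    """Illustrative helper: strip unpicklable references, then write to disk."""
    memory.tp = None  # assumed safe to discard before serialization
    with open(path, 'wb') as handle:
        pickle.dump(memory, handle)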

coach.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 #
 # Copyright (c) 2017 Intel Corporation
 #
@@ -15,6 +15,7 @@
 #
 import argparse
 import os
+import sys
 import time
 
 import tensorflow as tf