mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
fix for intel optimized tensorflow on distributed runs + adding coach_env to .gitignore
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -14,3 +14,4 @@ roboschool
|
|||||||
*.doc
|
*.doc
|
||||||
*.orig
|
*.orig
|
||||||
docs/site
|
docs/site
|
||||||
|
coach_env
|
||||||
|
|||||||
@@ -343,8 +343,8 @@ class Agent(object):
|
|||||||
:param to_type: can be 'channels_first' or 'channels_last'
|
:param to_type: can be 'channels_first' or 'channels_last'
|
||||||
:return: a new observation with the requested axes order
|
:return: a new observation with the requested axes order
|
||||||
"""
|
"""
|
||||||
if from_type == to_type:
|
if from_type == to_type or len(observation.shape) == 1:
|
||||||
return
|
return observation
|
||||||
assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image'
|
assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image'
|
||||||
assert type(observation) == np.ndarray, 'observation must be a numpy array'
|
assert type(observation) == np.ndarray, 'observation must be a numpy array'
|
||||||
if len(observation.shape) == 3:
|
if len(observation.shape) == 3:
|
||||||
|
|||||||
3
coach.py
3
coach.py
@@ -37,7 +37,6 @@ time_started = datetime.datetime.now()
|
|||||||
cur_time = time_started.time()
|
cur_time = time_started.time()
|
||||||
cur_date = time_started.date()
|
cur_date = time_started.date()
|
||||||
|
|
||||||
|
|
||||||
def get_experiment_path(general_experiments_path):
|
def get_experiment_path(general_experiments_path):
|
||||||
if not os.path.exists(general_experiments_path):
|
if not os.path.exists(general_experiments_path):
|
||||||
os.makedirs(general_experiments_path)
|
os.makedirs(general_experiments_path)
|
||||||
@@ -265,7 +264,7 @@ if __name__ == "__main__":
|
|||||||
# Multi-threaded runs
|
# Multi-threaded runs
|
||||||
else:
|
else:
|
||||||
assert args.framework.lower() == 'tensorflow', "Distributed training works only with TensorFlow"
|
assert args.framework.lower() == 'tensorflow', "Distributed training works only with TensorFlow"
|
||||||
|
os.environ["OMP_NUM_THREADS"]="1"
|
||||||
# set parameter server and workers addresses
|
# set parameter server and workers addresses
|
||||||
ps_hosts = "localhost:{}".format(get_open_port())
|
ps_hosts = "localhost:{}".format(get_open_port())
|
||||||
worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(run_dict['num_threads'] + 1)])
|
worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(run_dict['num_threads'] + 1)])
|
||||||
|
|||||||
Reference in New Issue
Block a user