1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

tests: Removed mxnet from functional tests + minor fix on rewards (#362)

* ci: change workflow

* changed timeout

* fix function reach reward

* print logs

* removing mxnet

* res'
This commit is contained in:
anabwan
2019-06-27 18:52:29 +03:00
committed by GitHub
parent 30c64d0656
commit a576ab5659
4 changed files with 28 additions and 24 deletions

View File

@@ -117,32 +117,33 @@ def is_reward_reached(csv_path, p_valid_params, start_time, time_limit):
last_num_episodes = 0
csv = None
reward_reached = False
reward_str = 'Evaluation Reward'
while csv is None or (csv['Episode #'].values[-1]
< p_valid_params.max_episodes_to_achieve_reward and
time.time() - start_time < time_limit):
csv = pd.read_csv(csv_path)
if 'Evaluation Reward' not in csv.keys():
while csv is None or (csv[csv.columns[0]].values[
-1] < p_valid_params.max_episodes_to_achieve_reward and time.time() - start_time < time_limit):
try:
csv = pd.read_csv(csv_path)
except:
# sometimes the csv is being written at the same time we are
# trying to read it. no problem -> try again
continue
rewards = csv['Evaluation Reward'].values
if reward_str not in csv.keys():
continue
rewards = csv[reward_str].values
rewards = rewards[~np.isnan(rewards)]
if len(rewards) >= 1:
averaged_rewards = np.convolve(rewards, np.ones(
min(len(rewards), win_size)) / win_size, mode='valid')
if len(rewards) >= 1:
averaged_rewards = np.convolve(rewards, np.ones(min(len(rewards), win_size)) / win_size, mode='valid')
else:
# May be in heat-up steps
time.sleep(1)
continue
if csv['Episode #'].shape[0] - last_num_episodes <= 0:
if csv[csv.columns[0]].shape[0] - last_num_episodes <= 0:
continue
last_num_episodes = csv['Episode #'].values[-1]
last_num_episodes = csv[csv.columns[0]].values[-1]
# check if reward is enough
if np.any(averaged_rewards >= p_valid_params.min_reward_threshold):
@@ -408,6 +409,7 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
csv_path = get_csv_path(clres=clres)
assert len(csv_path) > 0, \
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
time.sleep(5)
get_reward = is_reward_reached(csv_path=csv_path[0],
p_valid_params=p_valid_params,