
Create a dataset using an agent (#306)

Generate a dataset using an agent (with an option to select between an agent-generated and a random dataset)
Author: Gal Leibovich
Date: 2019-05-28 09:34:49 +03:00 (committed by GitHub)
Parent: 342b7184bc
Commit: 9e9c4fd332
26 changed files with 351 additions and 111 deletions
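For context on the feature itself (the diff below shows only one supporting change, to the e-greedy exploration policy), the commit lets a dataset be generated either by rolling out an agent or by sampling random actions. The following is a hypothetical, minimal sketch of that choice under a gym-style interface; `env`, `agent`, and `use_agent` are illustrative names, not rl-coach's actual API.

# Hypothetical sketch of the feature in the commit message: build a dataset of
# transitions either from an agent's policy or from random actions.
# `env`, `agent`, and `use_agent` are illustrative names, not rl-coach's API.
def collect_dataset(env, agent=None, num_steps=10_000, use_agent=False):
    dataset = []
    state = env.reset()
    for _ in range(num_steps):
        if use_agent and agent is not None:
            action = agent.act(state)            # agent-generated dataset
        else:
            action = env.action_space.sample()   # random dataset
        next_state, reward, done, _ = env.step(action)
        dataset.append((state, action, reward, next_state, done))
        state = env.reset() if done else next_state
    return dataset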

rl_coach/exploration_policies/e_greedy.py

@@ -20,8 +20,7 @@ import numpy as np
 from rl_coach.core_types import RunPhase, ActionType
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
-from rl_coach.exploration_policies.exploration_policy import ExplorationParameters
-from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy
+from rl_coach.exploration_policies.exploration_policy import ExplorationParameters, ExplorationPolicy
 from rl_coach.schedules import Schedule, LinearSchedule
 from rl_coach.spaces import ActionSpace, DiscreteActionSpace, BoxActionSpace
 from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
@@ -82,26 +81,32 @@ class EGreedy(ExplorationPolicy):
         epsilon = self.evaluation_epsilon if self.phase == RunPhase.TEST else self.epsilon_schedule.current_value
         return self.current_random_value >= epsilon
 
-    def get_action(self, action_values: List[ActionType]) -> ActionType:
+    def get_action(self, action_values: List[ActionType]) -> (ActionType, List[float]):
         epsilon = self.evaluation_epsilon if self.phase == RunPhase.TEST else self.epsilon_schedule.current_value
         if isinstance(self.action_space, DiscreteActionSpace):
-            top_action = np.argmax(action_values)
             if self.current_random_value < epsilon:
                 chosen_action = self.action_space.sample()
+                probabilities = np.full(len(self.action_space.actions),
+                                        1. / (self.action_space.high[0] - self.action_space.low[0] + 1))
             else:
-                chosen_action = top_action
+                chosen_action = np.argmax(action_values)
+
+                # one-hot probabilities vector
+                probabilities = np.zeros(len(self.action_space.actions))
+                probabilities[chosen_action] = 1
+
+            self.step_epsilon()
+
+            return chosen_action, probabilities
         else:
             if self.current_random_value < epsilon and self.phase == RunPhase.TRAIN:
                 chosen_action = self.action_space.sample()
             else:
                 chosen_action = self.continuous_exploration_policy.get_action(action_values)
 
-        # step the epsilon schedule and generate a new random value for next time
-        if self.phase == RunPhase.TRAIN:
-            self.epsilon_schedule.step()
-        self.current_random_value = np.random.rand()
-
-        return chosen_action
+            self.step_epsilon()
+
+            return chosen_action
 
     def get_control_param(self):
         if isinstance(self.action_space, DiscreteActionSpace):
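The new return value can be read directly off the hunk above: when exploring, the policy reports a uniform distribution over the discrete actions (`1. / (high - low + 1)` equals `1/n` for `n` actions numbered `low..high`); when exploiting, it reports a one-hot vector on the greedy action, presumably so a generated dataset can record how likely each logged action was. A standalone numpy sketch of that behavior, stripped of rl-coach's classes:

import numpy as np

def e_greedy_action(action_values, epsilon, rng=np.random):
    # Mirrors the behavior added above: return both the chosen action and the
    # probability vector the choice was reported with (plain-numpy sketch).
    n = len(action_values)
    if rng.rand() < epsilon:
        # explore: uniform random action, uniform probabilities
        action = rng.randint(n)
        probabilities = np.full(n, 1. / n)
    else:
        # exploit: greedy action, one-hot probabilities vector
        action = int(np.argmax(action_values))
        probabilities = np.zeros(n)
        probabilities[action] = 1.
    return action, probabilities

# e.g. q-values for 4 actions, 10% exploration:
action, probs = e_greedy_action([0.1, 0.5, 0.2, 0.2], epsilon=0.1)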
@@ -113,3 +118,9 @@ class EGreedy(ExplorationPolicy):
         super().change_phase(phase)
         if isinstance(self.action_space, BoxActionSpace):
             self.continuous_exploration_policy.change_phase(phase)
+
+    def step_epsilon(self):
+        # step the epsilon schedule and generate a new random value for next time
+        if self.phase == RunPhase.TRAIN:
+            self.epsilon_schedule.step()
+        self.current_random_value = np.random.rand()
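The extracted `step_epsilon` helper keeps the bookkeeping in one place: the schedule advances only during training, but a fresh random value is drawn on every call, so evaluation still explores at the fixed `evaluation_epsilon`. To illustrate the schedule side of this, here is a generic linear decay with the same `current_value`/`step()` shape; it is a sketch, not `rl_coach.schedules.LinearSchedule` itself.

import numpy as np

class LinearDecay:
    # Generic linear decay from `start` to `end` over `steps` calls to step();
    # a stand-in with the same current_value/step() shape as a schedule object.
    def __init__(self, start, end, steps):
        self.start, self.end, self.steps = start, end, steps
        self.t = 0

    @property
    def current_value(self):
        frac = min(self.t / self.steps, 1.)
        return self.start + frac * (self.end - self.start)

    def step(self):
        self.t += 1

# usage mirroring step_epsilon(): advance only while training,
# but redraw the comparison value every step
epsilon_schedule = LinearDecay(start=1., end=0.01, steps=10_000)
training = True
if training:
    epsilon_schedule.step()
current_random_value = np.random.rand()
will_explore = current_random_value < epsilon_schedule.current_value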