mirror of https://github.com/gryf/coach.git
Create a dataset using an agent (#306)
Generate a dataset using an agent (allowing the user to select between an agent-generated dataset and a random one)
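Per the commit message, the point of having the exploration policy return action probabilities (see the diff below) is to support building a dataset from an agent's rollouts rather than purely random sampling. A rough, framework-agnostic sketch of that idea — the `env`/`agent_policy` interfaces here are hypothetical gym-style stand-ins, not Coach's actual API:

```python
import numpy as np

def collect_dataset(env, agent_policy=None, num_steps=1000):
    """Roll out either a trained agent or a random policy and record
    (state, action, action_probabilities, ...) transitions.

    `env` is assumed to follow a gym-style reset()/step() interface and
    `agent_policy` is any callable mapping a state to (action, probs).
    Both are illustrative stand-ins, not Coach classes.
    """
    dataset = []
    state = env.reset()
    for _ in range(num_steps):
        if agent_policy is not None:
            # agent-generated dataset: the exploration policy now also
            # returns the per-action probabilities
            action, probs = agent_policy(state)
        else:
            # random dataset: uniform over a discrete action space
            n = env.action_space.n
            action = np.random.randint(n)
            probs = np.full(n, 1.0 / n)
        next_state, reward, done, _ = env.step(action)
        dataset.append((state, action, probs, reward, next_state, done))
        state = env.reset() if done else next_state
    return dataset
```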
@@ -20,8 +20,7 @@ import numpy as np
 
 from rl_coach.core_types import RunPhase, ActionType
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
-from rl_coach.exploration_policies.exploration_policy import ExplorationParameters
-from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy
+from rl_coach.exploration_policies.exploration_policy import ExplorationParameters, ExplorationPolicy
 from rl_coach.schedules import Schedule, LinearSchedule
 from rl_coach.spaces import ActionSpace, DiscreteActionSpace, BoxActionSpace
 from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
@@ -82,26 +81,32 @@ class EGreedy(ExplorationPolicy):
         epsilon = self.evaluation_epsilon if self.phase == RunPhase.TEST else self.epsilon_schedule.current_value
         return self.current_random_value >= epsilon
 
-    def get_action(self, action_values: List[ActionType]) -> ActionType:
+    def get_action(self, action_values: List[ActionType]) -> (ActionType, List[float]):
         epsilon = self.evaluation_epsilon if self.phase == RunPhase.TEST else self.epsilon_schedule.current_value
 
         if isinstance(self.action_space, DiscreteActionSpace):
-            top_action = np.argmax(action_values)
             if self.current_random_value < epsilon:
                 chosen_action = self.action_space.sample()
+                probabilities = np.full(len(self.action_space.actions),
+                                        1. / (self.action_space.high[0] - self.action_space.low[0] + 1))
             else:
-                chosen_action = top_action
+                chosen_action = np.argmax(action_values)
+
+                # one-hot probabilities vector
+                probabilities = np.zeros(len(self.action_space.actions))
+                probabilities[chosen_action] = 1
+
+            self.step_epsilon()
+            return chosen_action, probabilities
+
         else:
             if self.current_random_value < epsilon and self.phase == RunPhase.TRAIN:
                 chosen_action = self.action_space.sample()
             else:
                 chosen_action = self.continuous_exploration_policy.get_action(action_values)
 
-        # step the epsilon schedule and generate a new random value for next time
-        if self.phase == RunPhase.TRAIN:
-            self.epsilon_schedule.step()
-            self.current_random_value = np.random.rand()
-        return chosen_action
+            self.step_epsilon()
+            return chosen_action
 
     def get_control_param(self):
         if isinstance(self.action_space, DiscreteActionSpace):
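In the discrete branch, `get_action` now returns the probability distribution the action was effectively sampled from alongside the action itself: uniform while exploring, one-hot while exploiting. Note that `1. / (high[0] - low[0] + 1)` equals `1 / len(actions)` for a contiguous discrete range. A minimal standalone sketch of that logic, with no Coach dependencies (`num_actions` and `rng` are illustrative stand-ins for the action-space attributes used in the diff):

```python
import numpy as np

def e_greedy_with_probabilities(action_values, epsilon, rng=np.random):
    """Epsilon-greedy over a discrete action set, also returning the
    probability vector the action was drawn from (mirrors the diff)."""
    num_actions = len(action_values)
    if rng.rand() < epsilon:
        # exploring: every action is equally likely
        chosen_action = rng.randint(num_actions)
        probabilities = np.full(num_actions, 1.0 / num_actions)
    else:
        # exploiting: greedy action, one-hot probability vector
        chosen_action = int(np.argmax(action_values))
        probabilities = np.zeros(num_actions)
        probabilities[chosen_action] = 1.0
    return chosen_action, probabilities

action, probs = e_greedy_with_probabilities([0.1, 0.7, 0.2], epsilon=0.05)
```

Recording these probabilities per transition is what makes the collected rollouts usable as a dataset with known behavior-policy likelihoods.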
@@ -113,3 +118,9 @@ class EGreedy(ExplorationPolicy):
         super().change_phase(phase)
         if isinstance(self.action_space, BoxActionSpace):
             self.continuous_exploration_policy.change_phase(phase)
+
+    def step_epsilon(self):
+        # step the epsilon schedule and generate a new random value for next time
+        if self.phase == RunPhase.TRAIN:
+            self.epsilon_schedule.step()
+            self.current_random_value = np.random.rand()
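The new `step_epsilon()` helper centralizes the schedule bookkeeping that previously sat at the bottom of `get_action()`, so both the discrete and continuous branches can return early yet still advance epsilon, and only during training. A toy sketch of how such a schedule decays per step — a hypothetical stand-in for `rl_coach.schedules.LinearSchedule`, not Coach's actual implementation:

```python
class MinimalLinearSchedule:
    """Toy linear schedule: moves current_value toward final_value over
    decay_steps calls to step(), then holds it constant."""
    def __init__(self, initial_value, final_value, decay_steps):
        self.current_value = initial_value
        self._delta = (final_value - initial_value) / decay_steps
        self._steps_left = decay_steps

    def step(self):
        if self._steps_left > 0:
            self.current_value += self._delta
            self._steps_left -= 1

schedule = MinimalLinearSchedule(initial_value=1.0, final_value=0.1, decay_steps=9)
for _ in range(12):                        # extra steps past decay_steps are no-ops
    schedule.step()
print(round(schedule.current_value, 3))    # 0.1
```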