mirror of https://github.com/gryf/coach.git
Create a dataset using an agent (#306)
Generate a dataset using an agent (allowing selection between this and a random dataset)
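The description implies a switch between two ways of building the dataset: rolling out an agent, or sampling actions at random. A minimal, self-contained sketch of that selection logic (purely illustrative; none of the names or the toy environment below come from this commit or from rl_coach):

    import random
    from typing import Callable, List, Tuple

    Transition = Tuple[int, int, float]  # (state, action, reward), simplified for the sketch

    def random_policy(state: int, num_actions: int) -> int:
        # "random dataset": actions are sampled uniformly
        return random.randrange(num_actions)

    def greedy_stub_agent(state: int, num_actions: int) -> int:
        # stand-in for a trained agent's action selection
        return 0

    def collect_dataset(policy: Callable[[int, int], int], num_steps: int,
                        num_actions: int = 4) -> List[Transition]:
        # toy rollout loop with stubbed dynamics, just to show where the policy choice plugs in
        dataset, state = [], 0
        for _ in range(num_steps):
            action = policy(state, num_actions)
            dataset.append((state, action, float(action)))
            state = (state + 1) % 10
        return dataset

    use_agent = True  # the choice the commit description refers to
    data = collect_dataset(greedy_stub_agent if use_agent else random_policy, num_steps=100)

The hunks below are the accompanying changes to the exploration policy base classes.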
@@ -18,7 +18,7 @@ from typing import List
 
 from rl_coach.base_parameters import Parameters
 from rl_coach.core_types import RunPhase, ActionType
-from rl_coach.spaces import ActionSpace
+from rl_coach.spaces import ActionSpace, DiscreteActionSpace, BoxActionSpace, GoalsSpace
 
 
 class ExplorationParameters(Parameters):
@@ -54,14 +54,10 @@ class ExplorationPolicy(object):
         Given a list of values corresponding to each action,
         choose one actions according to the exploration policy
         :param action_values: A list of action values
-        :return: The chosen action
+        :return: The chosen action,
+                 The probability of the action (if available, otherwise 1 for absolute certainty in the action)
         """
-        if self.__class__ == ExplorationPolicy:
-            raise ValueError("The ExplorationPolicy class is an abstract class and should not be used directly. "
-                             "Please set the exploration parameters to point to an inheriting class like EGreedy or "
-                             "AdditiveNoise")
-        else:
-            raise ValueError("The get_action function should be overridden in the inheriting exploration class")
+        raise NotImplementedError()
 
     def change_phase(self, phase):
         """
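With this hunk, get_action becomes a pure abstract hook: subclasses such as EGreedy or AdditiveNoise (named in the old error message) override it and return the pair described in the updated docstring. A minimal sketch of such an override (the class below is illustrative, not the library's EGreedy; the import path is an assumption based on this file's apparent location):

    import numpy as np
    from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy

    class ArgmaxExploration(ExplorationPolicy):
        """Illustrative policy that always exploits the highest-valued action."""
        def get_action(self, action_values):
            chosen = int(np.argmax(action_values))
            # deterministic choice, so the probability of the chosen action is 1
            return chosen, 1.0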
@@ -82,3 +78,42 @@ class ExplorationPolicy(object):
 
     def get_control_param(self):
         return 0
+
+
+class DiscreteActionExplorationPolicy(ExplorationPolicy):
+    """
+    A discrete action exploration policy.
+    """
+    def __init__(self, action_space: ActionSpace):
+        """
+        :param action_space: the action space used by the environment
+        """
+        assert isinstance(action_space, DiscreteActionSpace)
+        super().__init__(action_space)
+
+    def get_action(self, action_values: List[ActionType]) -> (ActionType, List):
+        """
+        Given a list of values corresponding to each action,
+        choose one actions according to the exploration policy
+        :param action_values: A list of action values
+        :return: The chosen action,
+                 The probabilities of actions to select from (if not available a one-hot vector)
+        """
+        if self.__class__ == ExplorationPolicy:
+            raise ValueError("The ExplorationPolicy class is an abstract class and should not be used directly. "
+                             "Please set the exploration parameters to point to an inheriting class like EGreedy or "
+                             "AdditiveNoise")
+        else:
+            raise ValueError("The get_action function should be overridden in the inheriting exploration class")
+
+
+class ContinuousActionExplorationPolicy(ExplorationPolicy):
+    """
+    A continuous action exploration policy.
+    """
+    def __init__(self, action_space: ActionSpace):
+        """
+        :param action_space: the action space used by the environment
+        """
+        assert isinstance(action_space, BoxActionSpace) or isinstance(action_space, GoalsSpace)
+        super().__init__(action_space)
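The two new base classes only pin down which action spaces they accept; concrete policies are still expected to override get_action. A minimal sketch of a discrete policy built on DiscreteActionExplorationPolicy (the epsilon-greedy logic and class name are illustrative, not the library's EGreedy; the import path is an assumption based on this file's apparent location):

    import numpy as np
    from rl_coach.exploration_policies.exploration_policy import DiscreteActionExplorationPolicy

    class SimpleEpsilonGreedy(DiscreteActionExplorationPolicy):
        def __init__(self, action_space, epsilon: float = 0.1):
            super().__init__(action_space)  # the base class asserts a DiscreteActionSpace
            self.epsilon = epsilon

        def get_action(self, action_values):
            num_actions = len(action_values)
            if np.random.rand() < self.epsilon:
                chosen = int(np.random.randint(num_actions))
            else:
                chosen = int(np.argmax(action_values))
            # one-hot vector as the "probabilities of actions to select from",
            # matching the return contract documented in the new base class
            return chosen, np.eye(num_actions)[chosen]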