mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Cleanup and refactoring (#171)
This commit is contained in:
committed by
Gal Leibovich
parent
cd812b0d25
commit
fedb4cbd7c
@@ -125,22 +125,25 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
super().__init__(agent_parameters, spaces, name, global_network,
|
||||
network_is_local, network_is_trainable)
|
||||
|
||||
def fill_return_types():
|
||||
ret_dict = {}
|
||||
for cls in get_all_subclasses(PredictionType):
|
||||
ret_dict[cls] = []
|
||||
components = self.input_embedders + [self.middleware] + self.output_heads
|
||||
for component in components:
|
||||
if not hasattr(component, 'return_type'):
|
||||
raise ValueError("{} has no return_type attribute. This should not happen.")
|
||||
if component.return_type is not None:
|
||||
ret_dict[component.return_type].append(component)
|
||||
|
||||
return ret_dict
|
||||
|
||||
self.available_return_types = fill_return_types()
|
||||
self.available_return_types = self._available_return_types()
|
||||
self.is_training = None
|
||||
|
||||
def _available_return_types(self):
|
||||
ret_dict = {cls: [] for cls in get_all_subclasses(PredictionType)}
|
||||
|
||||
components = self.input_embedders + [self.middleware] + self.output_heads
|
||||
for component in components:
|
||||
if not hasattr(component, 'return_type'):
|
||||
raise ValueError((
|
||||
"{} has no return_type attribute. Without this, it is "
|
||||
"unclear how this component should be used."
|
||||
).format(component))
|
||||
|
||||
if component.return_type is not None:
|
||||
ret_dict[component.return_type].append(component)
|
||||
|
||||
return ret_dict
|
||||
|
||||
def predict_with_prediction_type(self, states: Dict[str, np.ndarray],
|
||||
prediction_type: PredictionType) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user