1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Cleanup and refactoring (#171)

This commit is contained in:
Zach Dwiel
2019-01-15 03:04:53 -05:00
committed by Gal Leibovich
parent cd812b0d25
commit fedb4cbd7c
7 changed files with 45 additions and 33 deletions

View File

@@ -125,22 +125,25 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
super().__init__(agent_parameters, spaces, name, global_network,
network_is_local, network_is_trainable)
def fill_return_types():
ret_dict = {}
for cls in get_all_subclasses(PredictionType):
ret_dict[cls] = []
components = self.input_embedders + [self.middleware] + self.output_heads
for component in components:
if not hasattr(component, 'return_type'):
raise ValueError("{} has no return_type attribute. This should not happen.")
if component.return_type is not None:
ret_dict[component.return_type].append(component)
return ret_dict
self.available_return_types = fill_return_types()
self.available_return_types = self._available_return_types()
self.is_training = None
def _available_return_types(self):
ret_dict = {cls: [] for cls in get_all_subclasses(PredictionType)}
components = self.input_embedders + [self.middleware] + self.output_heads
for component in components:
if not hasattr(component, 'return_type'):
raise ValueError((
"{} has no return_type attribute. Without this, it is "
"unclear how this component should be used."
).format(component))
if component.return_type is not None:
ret_dict[component.return_type].append(component)
return ret_dict
def predict_with_prediction_type(self, states: Dict[str, np.ndarray],
prediction_type: PredictionType) -> Dict[str, np.ndarray]:
"""