mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
TD3 (#338)
This commit is contained in:
@@ -38,3 +38,10 @@ def get_activation_function(activation_function_string: str):
|
||||
"Activation function must be one of the following {}. instead it was: {}" \
|
||||
.format(activation_functions.keys(), activation_function_string)
|
||||
return activation_functions[activation_function_string]
|
||||
|
||||
|
||||
def squeeze_tensor(tensor):
|
||||
if tensor.shape[0] == 1:
|
||||
return tensor[0]
|
||||
else:
|
||||
return tensor
|
||||
Reference in New Issue
Block a user