mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
update nec and value optimization agents to work with recurrent middleware
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -30,16 +30,19 @@ except ImportError:
|
||||
|
||||
|
||||
class NetworkWrapper(object):
|
||||
"""
|
||||
Contains multiple networks and managers syncing and gradient updates
|
||||
between them.
|
||||
"""
|
||||
def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None):
|
||||
"""
|
||||
|
||||
:param tuning_parameters:
|
||||
:param tuning_parameters:
|
||||
:type tuning_parameters: Preset
|
||||
:param has_target:
|
||||
:param has_global:
|
||||
:param name:
|
||||
:param replicated_device:
|
||||
:param worker_device:
|
||||
:param has_target:
|
||||
:param has_global:
|
||||
:param name:
|
||||
:param replicated_device:
|
||||
:param worker_device:
|
||||
"""
|
||||
self.tp = tuning_parameters
|
||||
self.has_target = has_target
|
||||
@@ -87,7 +90,7 @@ class NetworkWrapper(object):
|
||||
def sync(self):
|
||||
"""
|
||||
Initializes the weights of the networks to match each other
|
||||
:return:
|
||||
:return:
|
||||
"""
|
||||
self.update_online_network()
|
||||
self.update_target_network()
|
||||
@@ -111,14 +114,14 @@ class NetworkWrapper(object):
|
||||
def apply_gradients_to_global_network(self):
|
||||
"""
|
||||
Apply gradients from the online network on the global network
|
||||
:return:
|
||||
:return:
|
||||
"""
|
||||
self.global_network.apply_gradients(self.online_network.accumulated_gradients)
|
||||
|
||||
def apply_gradients_to_online_network(self):
|
||||
"""
|
||||
Apply gradients from the online network on itself
|
||||
:return:
|
||||
:return:
|
||||
"""
|
||||
self.online_network.apply_gradients(self.online_network.accumulated_gradients)
|
||||
|
||||
@@ -135,7 +138,7 @@ class NetworkWrapper(object):
|
||||
|
||||
def apply_gradients_and_sync_networks(self):
|
||||
"""
|
||||
Applies the gradients accumulated in the online network to the global network or to itself and syncs the
|
||||
Applies the gradients accumulated in the online network to the global network or to itself and syncs the
|
||||
networks if necessary
|
||||
"""
|
||||
if self.global_network:
|
||||
|
||||
Reference in New Issue
Block a user