1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

update nec and value optimization agents to work with recurrent middleware

This commit is contained in:
Zach Dwiel
2017-11-03 13:58:42 -07:00
parent 93a54c7e8e
commit 6c79a442f2
12 changed files with 138 additions and 72 deletions

View File

@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,16 +30,19 @@ except ImportError:
class NetworkWrapper(object):
"""
Contains multiple networks and managers syncing and gradient updates
between them.
"""
def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None):
"""
:param tuning_parameters:
:param tuning_parameters:
:type tuning_parameters: Preset
:param has_target:
:param has_global:
:param name:
:param replicated_device:
:param worker_device:
:param has_target:
:param has_global:
:param name:
:param replicated_device:
:param worker_device:
"""
self.tp = tuning_parameters
self.has_target = has_target
@@ -87,7 +90,7 @@ class NetworkWrapper(object):
def sync(self):
"""
Initializes the weights of the networks to match each other
:return:
:return:
"""
self.update_online_network()
self.update_target_network()
@@ -111,14 +114,14 @@ class NetworkWrapper(object):
def apply_gradients_to_global_network(self):
"""
Apply gradients from the online network on the global network
:return:
:return:
"""
self.global_network.apply_gradients(self.online_network.accumulated_gradients)
def apply_gradients_to_online_network(self):
"""
Apply gradients from the online network on itself
:return:
:return:
"""
self.online_network.apply_gradients(self.online_network.accumulated_gradients)
@@ -135,7 +138,7 @@ class NetworkWrapper(object):
def apply_gradients_and_sync_networks(self):
"""
Applies the gradients accumulated in the online network to the global network or to itself and syncs the
Applies the gradients accumulated in the online network to the global network or to itself and syncs the
networks if necessary
"""
if self.global_network: