1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-01 17:43:32 +02:00

Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)

This commit is contained in:
Sina Afrooze
2018-11-22 12:52:22 -08:00
committed by Gal Leibovich
parent 559969d3dd
commit 87a7848b0a
11 changed files with 219 additions and 91 deletions

View File

@@ -19,7 +19,6 @@ from typing import List
import numpy as np
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats, TFSharedRunningStats
from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace
@@ -54,6 +53,7 @@ class ObservationNormalizationFilter(ObservationFilter):
:return: None
"""
if mode == 'tf':
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False,
pubsub_params=memory_backend_params)
elif mode == 'numpy':

View File

@@ -17,7 +17,6 @@ import os
import numpy as np
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
from rl_coach.core_types import RewardType
from rl_coach.filters.reward.reward_filter import RewardFilter
from rl_coach.spaces import RewardSpace
@@ -48,6 +47,7 @@ class RewardNormalizationFilter(RewardFilter):
"""
if mode == 'tf':
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False,
pubsub_params=memory_backend_params)
elif mode == 'numpy':