mirror of
https://github.com/gryf/coach.git
synced 2026-04-01 17:43:32 +02:00
Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)
This commit is contained in:
committed by
Gal Leibovich
parent
559969d3dd
commit
87a7848b0a
@@ -19,7 +19,6 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats, TFSharedRunningStats
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
@@ -54,6 +53,7 @@ class ObservationNormalizationFilter(ObservationFilter):
|
||||
:return: None
|
||||
"""
|
||||
if mode == 'tf':
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
|
||||
self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False,
|
||||
pubsub_params=memory_backend_params)
|
||||
elif mode == 'numpy':
|
||||
|
||||
@@ -17,7 +17,6 @@ import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
|
||||
from rl_coach.core_types import RewardType
|
||||
from rl_coach.filters.reward.reward_filter import RewardFilter
|
||||
from rl_coach.spaces import RewardSpace
|
||||
@@ -48,6 +47,7 @@ class RewardNormalizationFilter(RewardFilter):
|
||||
"""
|
||||
|
||||
if mode == 'tf':
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
|
||||
self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False,
|
||||
pubsub_params=memory_backend_params)
|
||||
elif mode == 'numpy':
|
||||
|
||||
Reference in New Issue
Block a user