1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00
Files
coach/rl_coach/tests/test_global_variable_saver.py

80 lines
2.2 KiB
Python

import random
import pickle
import pytest
import tensorflow as tf
import numpy as np
from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
def random_name():
    """Return a random 32-character, zero-padded lowercase hexadecimal string."""
    # Draw an integer in [0, 16 ** 32), i.e. at most 32 hex digits wide.
    value = random.randrange(16 ** 32)
    hex_digits = "%032x" % value
    return hex_digits
@pytest.fixture
def name():
    """Fixture: a freshly generated random name (see ``random_name``)."""
    generated = random_name()
    return generated
@pytest.fixture
def variable(shape, name):
    """Fixture: a zero-initialised ``tf.Variable`` of ``shape`` called ``name``.

    The default graph is reset before the variable is created.

    :param shape: shape of the variable (``shape`` fixture)
    :param name: name given to the variable (``name`` fixture)
    :return: the newly created ``tf.Variable``
    """
    # NOTE(review): presumably the reset keeps variables from earlier tests
    # out of the graph the saver inspects -- confirm against GlobalVariableSaver.
    tf.reset_default_graph()
    initial_value = tf.zeros(shape)
    return tf.Variable(initial_value, name=name)
@pytest.fixture
def shape():
    """Fixture: the (rows, columns) shape used for the test variable."""
    rows, columns = 3, 5
    return (rows, columns)
def assert_arrays_ones_shape(arrays, shape, name):
    """Assert that ``arrays`` has exactly one entry, ``name``, equal to ones of ``shape``.

    The earlier version compared only ``arrays[name][0]`` (the first slice
    along axis 0) with ``np.ones(shape)`` and relied on broadcasting, so a
    value of the wrong shape, or one with non-one entries outside the first
    slice, still passed. The complete value and its shape are checked here.

    :param arrays: mapping from variable name to array-like value
    :param shape: expected shape of the single stored value
    :param name: expected (only) key in ``arrays``
    :raises AssertionError: if the keys, the shape or any element differ
    """
    keys = list(arrays)
    # Equality with a one-element list also guarantees there is exactly one entry.
    assert keys == [name], "expected only key %r, got %r" % (name, keys)

    expected = np.ones(shape)
    actual = np.asarray(arrays[name])
    # Compare shapes explicitly: an element-wise ``==`` would silently broadcast.
    assert actual.shape == expected.shape, (
        "expected shape %r, got %r" % (expected.shape, actual.shape)
    )
    assert np.array_equal(actual, expected), "not every element equals 1"
@pytest.mark.unit_test
def test_global_variable_saver_to_arrays(variable, name, shape):
    """``to_arrays`` should export the variable, set to ones, under its name."""
    with tf.Session() as session:
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        # Overwrite the zero-initialised variable with ones before exporting.
        fill_with_ones = variable.assign(tf.ones(shape))
        session.run(fill_with_ones)

        # NOTE(review): the saver receives the literal "name", not the ``name``
        # fixture -- presumably the saver's own name is irrelevant here; confirm.
        saver = GlobalVariableSaver("name")
        exported = saver.to_arrays(session)
        assert_arrays_ones_shape(exported, shape, name)
@pytest.mark.unit_test
def test_global_variable_saver_from_arrays(variable, name, shape):
    """``from_arrays`` should load ones into the variable; verified via ``to_arrays``."""
    with tf.Session() as session:
        init_op = tf.global_variables_initializer()
        session.run(init_op)

        # NOTE(review): the saver receives the literal "name", not the ``name``
        # fixture -- presumably the saver's own name is irrelevant here; confirm.
        saver = GlobalVariableSaver("name")
        replacement = {name: np.ones(shape)}
        saver.from_arrays(session, replacement)

        # Read the values back through the saver to confirm they were applied.
        restored = saver.to_arrays(session)
        assert_arrays_ones_shape(restored, shape, name)
@pytest.mark.unit_test
def test_global_variable_saver_to_string(variable, name, shape):
    """``to_string`` output should unpickle to the variable's values keyed by name."""
    with tf.Session() as session:
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        # Overwrite the zero-initialised variable with ones before serialising.
        fill_with_ones = variable.assign(tf.ones(shape))
        session.run(fill_with_ones)

        # NOTE(review): the saver receives the literal "name", not the ``name``
        # fixture -- presumably the saver's own name is irrelevant here; confirm.
        saver = GlobalVariableSaver("name")
        serialized = saver.to_string(session)
        # The payload was produced by the saver just above, so unpickling is safe.
        decoded = pickle.loads(serialized)
        assert_arrays_ones_shape(decoded, shape, name)
@pytest.mark.unit_test
def test_global_variable_saver_from_string(variable, name, shape):
    """``from_string`` should load pickled ones into the variable; verified via ``to_arrays``."""
    with tf.Session() as session:
        init_op = tf.global_variables_initializer()
        session.run(init_op)

        # NOTE(review): the saver receives the literal "name", not the ``name``
        # fixture -- presumably the saver's own name is irrelevant here; confirm.
        saver = GlobalVariableSaver("name")
        # protocol=-1 selects the highest pickle protocol available.
        payload = pickle.dumps({name: np.ones(shape)}, protocol=-1)
        saver.from_string(session, payload)

        # Read the values back through the saver to confirm they were applied.
        restored = saver.to_arrays(session)
        assert_arrays_ones_shape(restored, shape, name)