Commit 5166896e authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Fix the db to tfrecods tests

parent e2841635
Pipeline #26007 passed with stage
in 28 minutes and 11 seconds
......@@ -3,14 +3,15 @@ import shutil
import pkg_resources
import tempfile
from click.testing import CliRunner
import bob.io.base
from bob.learn.tensorflow.script.db_to_tfrecords import db_to_tfrecords, describe_tf_record
from bob.io.base import create_directories_safe
from bob.learn.tensorflow.script.db_to_tfrecords import (
db_to_tfrecords, describe_tf_record)
from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
regenerate_reference = False
dummy_config = pkg_resources.resource_filename(
'bob.learn.tensorflow', 'test/data/tfrecord_config.py')
'bob.learn.tensorflow', 'test/data/db_to_tfrecords_config.py')
def test_db_to_tfrecords():
......@@ -52,20 +53,21 @@ def test_db_to_tfrecords_size_estimate():
def test_tfrecord_counter():
tfrecord_train = "./tf-train-test/train_mnist.tfrecord"
shape = (3136,) # I'm saving the thing as float
shape = (3136,) # I'm saving the thing as float
batch_size = 1000
try:
train_data, train_labels, validation_data, validation_labels = load_mnist()
bob.io.base.create_directories_safe(os.path.dirname(tfrecord_train))
train_data, train_labels, validation_data, validation_labels = \
load_mnist()
create_directories_safe(os.path.dirname(tfrecord_train))
create_mnist_tfrecord(
tfrecord_train, train_data, train_labels, n_samples=6000)
n_samples, n_labels = describe_tf_record(os.path.dirname(tfrecord_train), shape, batch_size)
n_samples, n_labels = describe_tf_record(
os.path.dirname(tfrecord_train), shape, batch_size)
assert n_samples == 6000
assert n_labels == 10
finally:
shutil.rmtree(os.path.dirname(tfrecord_train))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment