Commit dd5fe89e authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Fix test_estimator_scripts

parent 00629861
Pipeline #19689 failed with stage
in 43 minutes and 31 seconds
......@@ -19,8 +19,9 @@ def test_db_to_tfrecords():
try:
runner = CliRunner()
result = runner.invoke(db_to_tfrecords, args=(
dummy_config, '--output', output_path))
assert result.exit_code == 0, result.output
dummy_config, '--output', output_path), standalone_mode=False)
assert result.exit_code == 0, '%s\n%s\n%s' % (
result.exc_info, result.output, result.exception)
# TODO: test if the generated tfrecords file is equal with a reference
# file
......@@ -29,18 +30,18 @@ def test_db_to_tfrecords():
shutil.rmtree(test_dir)
# def test_db_to_tfrecords_size_estimate():
# test_dir = tempfile.mkdtemp(prefix='bobtest_')
# output_path = os.path.join(test_dir, 'dev.tfrecords')
#
# try:
# runner = CliRunner()
# args = (dummy_config, '--size-estimate', '--output', output_path)
# print(' '.join(args))
# result = runner.invoke(db_to_tfrecords, args=args,)
# assert result.exit_code == 0, '%s\n%s\n%s' % (
# result.exc_info, result.output, result.exception)
# assert '2.0 M bytes' in result.output, result.output
#
# finally:
# shutil.rmtree(test_dir)
def test_db_to_tfrecords_size_estimate():
test_dir = tempfile.mkdtemp(prefix='bobtest_')
output_path = os.path.join(test_dir, 'dev.tfrecords')
try:
args = (dummy_config, '--size-estimate', '--output', output_path)
runner = CliRunner()
result = runner.invoke(
db_to_tfrecords, args=args, standalone_mode=False)
assert result.exit_code == 0, '%s\n%s\n%s' % (
result.exc_info, result.output, result.exception)
assert '2.0 M bytes' in result.output, result.output
finally:
shutil.rmtree(test_dir)
from __future__ import print_function
import os
from tempfile import mkdtemp
import shutil
from tempfile import mkdtemp
from click.testing import CliRunner
from bob.io.base.test_utils import datafile
from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords
from bob.learn.tensorflow.script.train_generic import main as train_generic
from bob.learn.tensorflow.script.eval_generic import main as eval_generic
from bob.learn.tensorflow.script.train_and_evaluate import main as train_and_evaluate
from bob.learn.tensorflow.script.db_to_tfrecords import db_to_tfrecords
from bob.learn.tensorflow.script.train import train
from bob.learn.tensorflow.script.eval import eval as eval_script
from bob.learn.tensorflow.script.train_and_evaluate import train_and_evaluate
dummy_tfrecord_config = datafile('dummy_verify_config.py', __name__)
CONFIG = '''
......@@ -98,8 +99,13 @@ def _create_tfrecord(test_dir):
config_path = os.path.join(test_dir, 'tfrecordconfig.py')
with open(dummy_tfrecord_config) as f, open(config_path, 'w') as f2:
f2.write(f.read().replace('TEST_DIR', test_dir))
tfrecords([config_path])
return os.path.join(test_dir, 'dev.tfrecords')
output = os.path.join(test_dir, 'dev.tfrecords')
runner = CliRunner()
result = runner.invoke(db_to_tfrecords, args=[
dummy_tfrecord_config, '--output', output])
assert result.exit_code == 0, '%s\n%s\n%s' % (
result.exc_info, result.output, result.exception)
return output
def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord):
......@@ -110,7 +116,10 @@ def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord):
config_path = os.path.join(tmpdir, 'train_config.py')
with open(config_path, 'w') as f:
f.write(config)
train_generic([config_path])
runner = CliRunner()
result = runner.invoke(train, args=[config_path])
assert result.exit_code == 0, '%s\n%s\n%s' % (
result.exc_info, result.output, result.exception)
def _eval(tmpdir, model_dir, dummy_tfrecord):
......@@ -121,7 +130,10 @@ def _eval(tmpdir, model_dir, dummy_tfrecord):
config_path = os.path.join(tmpdir, 'eval_config.py')
with open(config_path, 'w') as f:
f.write(config)
eval_generic([config_path])
runner = CliRunner()
result = runner.invoke(eval_script, args=[config_path])
assert result.exit_code == 0, '%s\n%s\n%s' % (
result.exc_info, result.output, result.exception)
def _train_and_evaluate(tmpdir, model_dir, dummy_tfrecord):
......@@ -132,7 +144,8 @@ def _train_and_evaluate(tmpdir, model_dir, dummy_tfrecord):
config_path = os.path.join(tmpdir, 'train_config.py')
with open(config_path, 'w') as f:
f.write(config)
train_and_evaluate([config_path])
runner = CliRunner()
runner.invoke(train_and_evaluate, args=[config_path])
def test_eval():
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment