Commit 4add4621 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Merge branch 'predict' into 'master'

Use bob.extension to load config files

See merge request !36
parents 0e818979 f2cc091a
Pipeline #14575 failed with stages
in 20 minutes and 55 seconds
......@@ -41,7 +41,8 @@ class BioGenerator(object):
"""
def __init__(self, database, biofiles, load_data=None,
biofile_to_label=None, multiple_samples=False):
biofile_to_label=None, multiple_samples=False, **kwargs):
super(BioGenerator, self).__init__(**kwargs)
if load_data is None:
def load_data(database, biofile):
data = read_original_data(
......@@ -95,6 +96,9 @@ class BioGenerator(object):
def output_shapes(self):
return self._output_shapes
def __len__(self):
return len(self.biofiles)
def __call__(self):
"""A generator function that when called will return the samples.
......
......@@ -93,7 +93,7 @@ import random
import pkg_resources
import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.bio.base.utils import read_config_file
from bob.extension.config import load as read_config_file
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.core.log import setup, set_verbosity_level
......
......@@ -43,7 +43,7 @@ import os
import time
import six
import tensorflow as tf
from bob.bio.base.utils import read_config_file
from bob.extension.config import load as read_config_file
from ..utils.eval import get_global_step
......
......@@ -112,7 +112,8 @@ from multiprocessing import Pool
from collections import defaultdict
import numpy as np
from bob.io.base import create_directories_safe
from bob.bio.base.utils import read_config_file, save
from bob.bio.base.utils import save
from bob.extension.config import load as read_config_file
from bob.bio.base.tools.grid import indices
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
......@@ -231,7 +232,8 @@ def main(argv=None):
checkpoint_path=checkpoint_path,
)
logger.info("Saving the predictions in %s", output_dir)
logger.info("Saving the predictions of %d files in %s", len(generator),
output_dir)
pool = Pool()
try:
......
......@@ -53,7 +53,8 @@ from multiprocessing import Pool
from collections import defaultdict
import numpy as np
from bob.io.base import create_directories_safe
from bob.bio.base.utils import read_config_file, save
from bob.extension.config import load as read_config_file
from bob.bio.base.utils import save
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.core.log import setup, set_verbosity_level
......
......@@ -39,7 +39,7 @@ from __future__ import division
from __future__ import print_function
# import pkg_resources so that bob imports work properly:
import pkg_resources
from bob.bio.base.utils import read_config_file
from bob.extension.config import load as read_config_file
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.core.log import setup, set_verbosity_level
......
......@@ -2,8 +2,6 @@ from __future__ import print_function
import os
from tempfile import mkdtemp
import shutil
import logging
logging.getLogger("tensorflow").setLevel(logging.WARNING)
from bob.io.base.test_utils import datafile
from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords
......@@ -13,7 +11,6 @@ from bob.learn.tensorflow.script.eval_generic import main as eval_generic
dummy_tfrecord_config = datafile('dummy_verify_config.py', __name__)
CONFIG = '''
import tensorflow as tf
from bob.learn.tensorflow.utils.reproducible import run_config
from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, \
batch_data_and_labels
......@@ -88,8 +85,7 @@ def model_fn(features, labels, mode, params, config):
eval_metric_ops=metrics)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
config=run_config)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)
'''
......
......@@ -11,7 +11,7 @@ def get_from_config_or_commandline(config, keyword, args, defaults,
Parameters
----------
config : object
config : :any:`module`
The loaded config files.
keyword : str
The keyword to load from the config file or through command line.
......@@ -30,7 +30,7 @@ def get_from_config_or_commandline(config, keyword, args, defaults,
Example
-------
>>> from bob.bio.base.utils import read_config_file
>>> from bob.extension.config import load as read_config_file
>>> defaults = docopt(docs, argv=[""])
>>> args = docopt(docs, argv=argv)
>>> config_files = args['<config_files>']
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment