Commit a48bce10 authored by Amir MOHAMMADI

remove bob.bio.base

parent ccf1dea7
Merge request !85: Porting to TF2
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/timothycrosley/isort
rev: 4.3.21-2
hooks:
- id: isort
args: [-sl]
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.0.0
hooks:
- id: check-ast
- id: check-case-conflict
- id: trailing-whitespace
- id: end-of-file-fixer
- id: debug-statements
- id: check-added-large-files
- id: flake8
- repo: local
hooks:
- id: sphinx-build
name: sphinx build
entry: python -m sphinx.cmd.build
args: [-a, -E, -W, doc, sphinx]
language: system
files: ^doc/
types: [file]
pass_filenames: false
- id: sphinx-doctest
name: sphinx doctest
entry: python -m sphinx.cmd.build
args: [-a, -E, -b, doctest, doc, sphinx]
language: system
files: ^doc/
types: [file]
pass_filenames: false
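For illustration (not part of this commit): the two `local` hooks above call Sphinx through its Python entry point, so the same checks can be reproduced directly with `sphinx.cmd.build.main`, which accepts the argv shown in `args`:

from sphinx.cmd.build import main

# equivalent of the "sphinx build" hook; -W turns warnings into errors
main(["-a", "-E", "-W", "doc", "sphinx"])

# equivalent of the "sphinx doctest" hook; runs the doctest builder
main(["-a", "-E", "-b", "doctest", "doc", "sphinx"])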
@@ -4,11 +4,12 @@
import logging

import tensorflow as tf

from bob.learn.tensorflow.utils import compute_euclidean_distance

logger = logging.getLogger(__name__)


def triplet_loss(anchor_embedding, positive_embedding, negative_embedding, margin=5.0):
    """
...
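The body of ``triplet_loss`` is elided above. As a hedged sketch of what a function with this signature conventionally computes (not necessarily this file's exact implementation), using the ``compute_euclidean_distance`` helper the file already imports:

import tensorflow as tf

from bob.learn.tensorflow.utils import compute_euclidean_distance


def triplet_loss_sketch(anchor, positive, negative, margin=5.0):
    # distance from the anchor to the matching (positive) sample
    d_pos = compute_euclidean_distance(anchor, positive)
    # distance from the anchor to the non-matching (negative) sample
    d_neg = compute_euclidean_distance(anchor, negative)
    # hinge: push d_neg to exceed d_pos by at least `margin`
    return tf.reduce_mean(tf.maximum(d_pos - d_neg + margin, 0.0))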
@@ -59,9 +59,7 @@ $ bob tf trim -vvrK 2 ~/my_models
"""
)
@click.argument(
    "root_dirs", nargs=-1, type=click.Path(exists=True, file_okay=False, dir_okay=True),
)
@click.option(
    "--keep-last-n-models",
...
from bob.bio.base.test.dummy.database import database
from bob.bio.base.utils import read_original_data
from bob.learn.tensorflow.dataset.generator import dataset_using_generator

groups = ["dev"]
...
import tensorflow as tf

from bob.learn.tensorflow.dataset.bio import BioGenerator
from bob.learn.tensorflow.utils import to_channels_last

batch_size = 2
epochs = 2


def input_fn(mode):
    from bob.bio.base.test.dummy.database import database as db

    if mode == tf.estimator.ModeKeys.TRAIN:
        groups = "world"
    elif mode == tf.estimator.ModeKeys.EVAL:
        groups = "dev"

    files = db.objects(groups=groups)

    # construct integer labels for each identity in the database
    CLIENT_IDS = (str(f.client_id) for f in files)
    CLIENT_IDS = list(set(CLIENT_IDS))
    CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))

    def biofile_to_label(f):
        return CLIENT_IDS[str(f.client_id)]

    def load_data(database, f):
        img = f.load(database.original_directory, database.original_extension)
        # make a channels_first image (bob format) with 1 channel
        img = img.reshape(1, 112, 92)
        return img

    generator = BioGenerator(db, files, load_data, biofile_to_label)
    dataset = tf.data.Dataset.from_generator(
        generator, generator.output_types, generator.output_shapes
    )

    def transform(image, label, key):
        # convert to channels last
        image = to_channels_last(image)
        # per_image_standardization
        image = tf.image.per_image_standardization(image)
        return (image, label, key)

    dataset = dataset.map(transform)
    if mode == tf.estimator.ModeKeys.TRAIN:
        # since we are caching to memory, caching only in training makes sense.
        dataset = dataset.cache()
        dataset = dataset.repeat(epochs)
    dataset = dataset.batch(batch_size)

    data, label, key = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
    return {"data": data, "key": key}, label


def train_input_fn():
    return input_fn(tf.estimator.ModeKeys.TRAIN)


def eval_input_fn():
    return input_fn(tf.estimator.ModeKeys.EVAL)


train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=50)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
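The `train_spec`/`eval_spec` pair above is typically handed to `tf.estimator.train_and_evaluate`. A minimal sketch, assuming an `estimator` has been constructed elsewhere:

import tensorflow as tf

# `estimator` is a placeholder for any tf.estimator.Estimator built
# elsewhere, e.g. one of this package's estimators
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)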
import tensorflow as tf

from bob.bio.base.test.dummy.database import database

biofiles = database.all_files(["dev"])


def bio_predict_input_fn(generator, output_types, output_shapes):
    def input_fn():
        dataset = tf.data.Dataset.from_generator(generator, output_types, output_shapes)
        # apply all kinds of transformations here, process the data
        # even further if you want.
        dataset = dataset.prefetch(1)
        dataset = dataset.batch(10 ** 3)
        images, labels, keys = tf.compat.v1.data.make_one_shot_iterator(
            dataset
        ).get_next()
        return {"data": images, "key": keys}, labels

    return input_fn
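A hedged usage sketch for the factory above, with `load_data` taken from the earlier snippet and `estimator` a trained `tf.estimator.Estimator`; the assumption that `biofile_to_label` may be left as `None` at prediction time is not confirmed by this diff:

from bob.learn.tensorflow.dataset.bio import BioGenerator

# labels are not needed when predicting (assumption: the argument is optional)
generator = BioGenerator(database, biofiles, load_data, biofile_to_label=None)

predict_input_fn = bio_predict_input_fn(
    generator, generator.output_types, generator.output_shapes
)

for prediction in estimator.predict(predict_input_fn):
    # each prediction carries whatever outputs the estimator's model_fn defines
    print(prediction)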
@@ -5,10 +5,8 @@
import numpy
import tensorflow as tf

from bob.learn.tensorflow.loss import balanced_sigmoid_cross_entropy_loss_weights
from bob.learn.tensorflow.loss import balanced_softmax_cross_entropy_loss_weights


def test_balanced_softmax_cross_entropy_loss_weights():
...
@@ -8,6 +8,7 @@ import tensorflow as tf
from bob.learn.tensorflow.utils import compute_embedding_accuracy
from bob.learn.tensorflow.utils import compute_embedding_accuracy_tensors

"""
Some unit tests for the datashuffler
"""
...
from __future__ import division

import numpy
from keras.utils import Sequence

from bob.bio.base.preprocessor import Preprocessor

# documentation imports
from bob.dap.base.database import PadDatabase
from bob.dap.base.database import PadFile


class PadSequence(Sequence):
    """A data shuffler for bob.dap.base database interfaces.

    Attributes
    ----------
    batch_size : int
        The number of samples to return in every batch.
    files : list of :any:`PadFile`
        List of file objects for a particular group and protocol.
    labels : list of bool
        List of labels for the files. ``True`` if bona-fide, ``False`` if
        attack.
    preprocessor : :any:`Preprocessor`
        The preprocessor to be used to load and process the data.
    """

    def __init__(
        self,
        files,
        labels,
        batch_size,
        preprocessor,
        original_directory,
        original_extension,
    ):
        super(PadSequence, self).__init__()
        self.files = files
        self.labels = labels
        self.batch_size = int(batch_size)
        self.preprocessor = preprocessor
        self.original_directory = original_directory
        self.original_extension = original_extension

    def __len__(self):
        """Number of batches in the Sequence.

        Returns
        -------
        int
            The number of batches in the Sequence.
        """
        return int(numpy.ceil(len(self.files) / self.batch_size))

    def __getitem__(self, idx):
        files = self.files[idx * self.batch_size : (idx + 1) * self.batch_size]
        labels = self.labels[idx * self.batch_size : (idx + 1) * self.batch_size]
        return self.load_batch(files, labels)

    def load_batch(self, files, labels):
        """Loads a batch of files and processes them.

        Parameters
        ----------
        files : list of :any:`PadFile`
            List of files to load.
        labels : list of bool
            List of labels corresponding to the files.

        Returns
        -------
        tuple of :any:`numpy.array`
            A tuple of (x, y): the data and their targets.
        """
        data, targets = [], []
        for file_object, target in zip(files, labels):
            loaded_data = self.preprocessor.read_original_data(
                file_object, self.original_directory, self.original_extension
            )
            preprocessed_data = self.preprocessor(loaded_data)
            data.append(preprocessed_data)
            targets.append(target)
        return numpy.array(data), numpy.array(targets)

    def on_epoch_end(self):
        pass


def shuffle_data(files, labels):
    indexes = numpy.arange(len(files))
    numpy.random.shuffle(indexes)
    return [files[i] for i in indexes], [labels[i] for i in indexes]


def get_pad_files_labels(database, groups):
    """Returns the pad files and their labels.

    Parameters
    ----------
    database : :any:`PadDatabase`
        The database to be used. The database should have a proper
        ``database.protocol`` attribute.
    groups : str
        The group to be used to return the data. One of ('world', 'dev',
        'eval'). 'world' means training data and 'dev' means validation data.

    Returns
    -------
    tuple
        A tuple of (files, labels) for that particular group and protocol.
    """
    files = database.samples(groups=groups, protocol=database.protocol)
    # a file is bona-fide exactly when it carries no attack type
    labels = ((f.attack_type is None) for f in files)
    labels = numpy.fromiter(labels, bool, len(files))
    return files, labels


def get_pad_sequences(
    database,
    preprocessor,
    batch_size,
    groups=("world", "dev", "eval"),
    shuffle=False,
    limit=None,
):
    """Returns a list of :any:`Sequence` objects for the database.

    Parameters
    ----------
    database : :any:`PadDatabase`
        The database to be used. The database should have a proper
        ``database.protocol`` attribute.
    preprocessor : :any:`Preprocessor`
        The preprocessor to be used to load and process the data.
    batch_size : int
        The number of samples to return in every batch.
    groups : tuple of str
        The groups to be used to return the data. Each one of ('world', 'dev',
        'eval'). 'world' means training data and 'dev' means validation data.
    shuffle : bool
        If ``True``, the files of each group are shuffled before batching.
    limit : int or None
        If given, at most ``limit`` files per group are used.

    Returns
    -------
    list of :any:`Sequence`
        The requested sequences to be used.
    """
    seqs = []
    for grp in groups:
        files, labels = get_pad_files_labels(database, grp)
        if shuffle:
            files, labels = shuffle_data(files, labels)
        if limit is not None:
            files, labels = files[:limit], labels[:limit]
        seqs.append(
            PadSequence(
                files,
                labels,
                batch_size,
                preprocessor,
                database.original_directory,
                database.original_extension,
            )
        )
    return seqs
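A hedged usage sketch for the helpers above; `database` and `preprocessor` are placeholders for objects built elsewhere, and the input shape is an assumption matching the 112x92 images used earlier:

from keras.layers import Dense, Flatten
from keras.models import Sequential

# `database` and `preprocessor` are assumed to be constructed elsewhere
train_seq, dev_seq = get_pad_sequences(
    database, preprocessor, batch_size=32, groups=("world", "dev"), shuffle=True
)

# a deliberately tiny binary classifier, just to show the wiring;
# input_shape is an assumed preprocessed image shape
model = Sequential([Flatten(input_shape=(112, 92, 1)), Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss="binary_crossentropy")

# keras.utils.Sequence objects can be consumed by fit_generator
# (older standalone Keras; tf.keras also accepts them in model.fit)
model.fit_generator(train_seq, validation_data=dev_seq, epochs=10)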
@@ -11,4 +11,4 @@ verbose = true
[scripts]
recipe = bob.buildout:scripts
dependent-scripts = true
\ No newline at end of file
@@ -4,9 +4,18 @@
import glob
import os
import sys
import time

import pkg_resources

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
import sphinx_rtd_theme

# For inter-documentation mapping:
from bob.extension.utils import link_documentation
from bob.extension.utils import load_requirements

# -- General configuration -----------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
@@ -76,7 +85,6 @@ master_doc = "index"

# General information about the project.
project = u"bob.learn.tensorflow"

copyright = u"%s, Idiap Research Institute" % time.strftime("%Y")
@@ -134,9 +142,6 @@ owner = [u"Idiap Research Institute"]

# -- Options for HTML output ---------------------------------------------------

html_theme = "sphinx_rtd_theme"
@@ -234,9 +239,6 @@ autodoc_default_flags = [
    "show-inheritance",
]

sphinx_requirements = "extra-intersphinx.txt"
if os.path.exists(sphinx_requirements):
...
@@ -31,4 +31,3 @@ Indices and tables

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
@@ -169,12 +169,11 @@ There are several ways to provide data to Tensorflow graphs. In this section we
provide some examples on how to make the bridge between `bob.db` databases and
tensorflow `input_fn`.

The Generator input pipeline
****************************

The :any:`bob.learn.tensorflow.dataset.Generator` class can be used to convert any
database of bob to a ``tf.data.Dataset`` instance.
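A hedged sketch of that conversion, assuming (as the test snippet earlier suggests) that ``dataset_using_generator`` takes a list of samples and a ``reader`` callable that loads one sample; the stand-in samples here are hypothetical:

import numpy as np

from bob.learn.tensorflow.dataset.generator import dataset_using_generator

# hypothetical stand-in samples; in practice these come from a bob database
samples = [np.zeros((1, 112, 92), dtype="float32") for _ in range(4)]


def reader(sample):
    # load/decode one sample here; labels can be returned alongside the data
    return sample


dataset = dataset_using_generator(samples, reader)  # a tf.data.Dataset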
While building the input pipeline, you can manipulate your data in two
sections:
@@ -284,4 +283,3 @@ In this package we have crafted 4 types of estimators.

:py:class:`bob.learn.tensorflow.estimators.Triplet`

.. _tensorflow: https://www.tensorflow.org/
@@ -6,11 +6,12 @@
from setuptools import dist
from setuptools import setup

from bob.extension.utils import find_packages
from bob.extension.utils import load_requirements

dist.Distribution(dict(setup_requires=["bob.extension"]))

install_requires = load_requirements()

# The only thing we do in this file is to call the setup() function with all
...