Skip to content
Snippets Groups Projects
Commit f3086b37 authored by Amir MOHAMMADI
Browse files

Merge branch 'fixes' into 'master'

make sure the algorithm runs outside dask wrappers

See merge request !32
parents 104cee7a b4ffb2a6
No related branches found
No related tags found
1 merge request: !32 — make sure the algorithm runs outside dask wrappers
Pipeline #59496 passed with warnings
[flake8]
max-line-length = 88
select = B,C,E,F,W,T4,B9,B950
ignore = E501, W503, E203
max-line-length = 80
ignore = E501,W503,E302,E402,E203
......@@ -2,20 +2,20 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.8.0
rev: 5.9.3
hooks:
- id: isort
args: [--sl, --line-length, "88"]
- id: isort
args: [--settings-path, "pyproject.toml"]
- repo: https://github.com/psf/black
rev: 20.8b1
rev: 21.7b0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.0
rev: 3.9.2
hooks:
- id: flake8
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
rev: v4.0.1
hooks:
- id: check-ast
- id: check-case-conflict
......@@ -23,21 +23,5 @@ repos:
- id: end-of-file-fixer
- id: debug-statements
- id: check-added-large-files
- repo: local
hooks:
- id: sphinx-build
name: sphinx build
entry: python -m sphinx.cmd.build
args: [-a, -E, -W, doc, sphinx]
language: system
files: ^doc/
types: [file]
pass_filenames: false
- id: sphinx-doctest
name: sphinx doctest
entry: python -m sphinx.cmd.build
args: [-a, -E, -b, doctest, doc, sphinx]
language: system
files: ^doc/
types: [file]
pass_filenames: false
- id: check-yaml
exclude: .*/meta.yaml
......@@ -13,8 +13,7 @@ This adds the notions of models, probes, enrollment, and scores to GMM.
import copy
import logging
from typing import Callable
from typing import Union
from typing import Callable, Union
import dask.array as da
import numpy as np
......@@ -23,10 +22,7 @@ from h5py import File as HDF5File
from sklearn.base import BaseEstimator
from bob.bio.base.pipelines.vanilla_biometrics import BioAlgorithm
from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats
from bob.learn.em import KMeansMachine
from bob.learn.em import linear_scoring
from bob.learn.em import GMMMachine, GMMStats, KMeansMachine, linear_scoring
logger = logging.getLogger(__name__)
......@@ -153,7 +149,9 @@ class GMM(BioAlgorithm, BaseEstimator):
or feature.ndim != 2
or feature.dtype != np.float64
):
raise ValueError(f"The given feature is not appropriate: \n{feature}")
raise ValueError(
f"The given feature is not appropriate: \n{feature}"
)
if self.ubm is not None and feature.shape[1] != self.ubm.shape[1]:
raise ValueError(
"The given feature is expected to have %d elements, but it has %d"
......@@ -165,7 +163,11 @@ class GMM(BioAlgorithm, BaseEstimator):
# Saves the UBM to file
logger.debug("Saving model to file '%s'", ubm_file)
hdf5 = ubm_file if isinstance(ubm_file, HDF5File) else HDF5File(ubm_file, "w")
hdf5 = (
ubm_file
if isinstance(ubm_file, HDF5File)
else HDF5File(ubm_file, "w")
)
self.ubm.save(hdf5)
def load_model(self, ubm_file):
......@@ -199,7 +201,10 @@ class GMM(BioAlgorithm, BaseEstimator):
for feature in data:
self._check_feature(feature)
data = np.vstack(data)
# if input is a list (or SampleBatch) of 2 dimensional arrays, stack them
if data[0].ndim == 2:
data = np.vstack(data)
# Use the array to train a GMM and return it
logger.info("Enrolling with %d feature vectors", data.shape[0])
......@@ -270,7 +275,9 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the models.
"""
stats = self.project(probe) if not isinstance(probe, GMMStats) else probe
stats = (
self.project(probe) if not isinstance(probe, GMMStats) else probe
)
return self.scoring_function(
models_means=biometric_references,
ubm=self.ubm,
......@@ -284,9 +291,13 @@ class GMM(BioAlgorithm, BaseEstimator):
if isinstance(array, da.Array):
array = array.persist()
logger.debug("UBM with %d feature vectors", len(array))
# if input is a list (or SampleBatch) of 2 dimensional arrays, stack them
if array[0].ndim == 2:
array = np.vstack(array)
logger.debug(f"Creating UBM machine with {self.number_of_gaussians} gaussians")
logger.debug(
f"Creating UBM machine with {self.number_of_gaussians} gaussians and {len(array)} samples"
)
self.ubm = GMMMachine(
n_gaussians=self.number_of_gaussians,
......
......@@ -27,8 +27,7 @@ import bob.bio.gmm
from bob.bio.base.test import utils
from bob.bio.gmm.algorithm import GMM
from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats
from bob.learn.em import GMMMachine, GMMStats
logger = logging.getLogger(__name__)
......@@ -44,7 +43,8 @@ def test_class():
)
assert isinstance(gmm1, GMM)
assert isinstance(
gmm1, bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm
gmm1,
bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm,
)
assert gmm1.number_of_gaussians == 512
assert "bob_fit_supports_dask_array" in gmm1._get_tags()
......@@ -117,7 +117,9 @@ def test_enroll():
)
# Create a GMM object with that UBM
gmm1 = GMM(
number_of_gaussians=2, enroll_update_means=True, enroll_update_variances=True
number_of_gaussians=2,
enroll_update_means=True,
enroll_update_variances=True,
)
gmm1.ubm = ubm
# Enroll the biometric reference from random features
......@@ -136,7 +138,9 @@ def test_enroll():
gmm2 = gmm1.read_biometric_reference(reference_file)
assert biometric_reference.is_similar_to(gmm2)
with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_bioref.hdf5") as fd:
with tempfile.NamedTemporaryFile(
prefix="bob_", suffix="_bioref.hdf5"
) as fd:
temp_file = fd.name
gmm1.write_biometric_reference(biometric_reference, temp_file)
assert GMMMachine.from_hdf5(temp_file, ubm).is_similar_to(gmm2)
......@@ -148,11 +152,15 @@ def test_score():
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_ubm.hdf5")
)
biometric_reference = GMMMachine.from_hdf5(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_enrolled.hdf5"),
pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_enrolled.hdf5"
),
ubm=gmm1.ubm,
)
probe = GMMStats.from_hdf5(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5")
pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_projected.hdf5"
)
)
probe_data = utils.random_array((20, 45), -5.0, 5.0, seed=seed_value)
......
......@@ -8,8 +8,7 @@ import pkg_resources
import sphinx_rtd_theme
# For inter-documentation mapping:
from bob.extension.utils import link_documentation
from bob.extension.utils import load_requirements
from bob.extension.utils import link_documentation, load_requirements
# -- General configuration -----------------------------------------------------
......@@ -130,9 +129,7 @@ pygments_style = "sphinx"
# Some variables which are useful for generated material
project_variable = project.replace(".", "_")
short_description = (
u"Tools for running biometric recognition experiments using GMM-based approximation"
)
short_description = u"Tools for running biometric recognition experiments using GMM-based approximation"
owner = [u"Idiap Research Institute"]
......
[build-system]
requires = ["setuptools", "wheel", "bob.extension"]
build-backend = "setuptools.build_meta"
requires = ["setuptools", "wheel", "bob.extension"]
build-backend = "setuptools.build_meta"
[tool.isort]
profile = "black"
line_length = 80
order_by_type = true
lines_between_types = 1
[tool.black]
line-length = 80
......@@ -33,11 +33,9 @@
# allows you to test your package with new python dependencies w/o requiring
# administrative interventions.
from setuptools import dist
from setuptools import setup
from setuptools import dist, setup
from bob.extension.utils import find_packages
from bob.extension.utils import load_requirements
from bob.extension.utils import find_packages, load_requirements
dist.Distribution(dict(setup_requires=["bob.extension"]))
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment