Commit b4ffb2a6 authored by Amir MOHAMMADI

[pre-commit] update configs

parent 65f4e960
Merge request !32: make sure the algorithm runs outside dask wrappers
Pipeline #59495 passed with warnings
 [flake8]
-max-line-length = 88
-select = B,C,E,F,W,T4,B9,B950
-ignore = E501, W503, E203
+max-line-length = 80
+ignore = E501,W503,E302,E402,E203
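With max-line-length lowered to 80 (matching black and isort, configured below in pyproject.toml) and E501 still ignored, flake8 leaves line wrapping to black and does not flag lines that black cannot split, such as long string literals. A small, hedged illustration (logger name and counts are invented; the message mirrors the long f-string kept on one line later in this commit):

import logging

logger = logging.getLogger(__name__)

# Hypothetical values; only the length of the rendered source line matters here.
n_gaussians, n_samples = 512, 20000

# This call is black-formatted at 80 columns, yet the f-string itself is longer
# than 80 characters; with E501 ignored, flake8 accepts it.
logger.debug(
    f"Creating UBM machine with {n_gaussians} gaussians and {n_samples} samples"
)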
@@ -2,20 +2,20 @@
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
   - repo: https://github.com/timothycrosley/isort
-    rev: 5.8.0
+    rev: 5.9.3
     hooks:
       - id: isort
-        args: [--sl, --line-length, "88"]
+        args: [--settings-path, "pyproject.toml"]
   - repo: https://github.com/psf/black
-    rev: 20.8b1
+    rev: 21.7b0
     hooks:
       - id: black
   - repo: https://gitlab.com/pycqa/flake8
-    rev: 3.9.0
+    rev: 3.9.2
     hooks:
       - id: flake8
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.4.0
+    rev: v4.0.1
     hooks:
       - id: check-ast
       - id: check-case-conflict
@@ -23,21 +23,5 @@ repos:
       - id: end-of-file-fixer
       - id: debug-statements
       - id: check-added-large-files
-  - repo: local
-    hooks:
-      - id: sphinx-build
-        name: sphinx build
-        entry: python -m sphinx.cmd.build
-        args: [-a, -E, -W, doc, sphinx]
-        language: system
-        files: ^doc/
-        types: [file]
-        pass_filenames: false
-      - id: sphinx-doctest
-        name: sphinx doctest
-        entry: python -m sphinx.cmd.build
-        args: [-a, -E, -b, doctest, doc, sphinx]
-        language: system
-        files: ^doc/
-        types: [file]
-        pass_filenames: false
+      - id: check-yaml
+        exclude: .*/meta.yaml
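pre-commit interprets exclude as a Python regular expression matched against each file's path, so the new check-yaml hook skips recipe files such as a conda meta.yaml (which embeds Jinja templating and is not plain YAML) while still validating ordinary YAML files. A minimal sketch with hypothetical paths:

import re

# Same pattern as the exclude entry above.
exclude = re.compile(r".*/meta.yaml")

assert exclude.search("conda/meta.yaml")              # skipped by check-yaml
assert not exclude.search(".pre-commit-config.yaml")  # still checked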
@@ -13,8 +13,7 @@ This adds the notions of models, probes, enrollment, and scores to GMM.
 import copy
 import logging
 
-from typing import Callable
-from typing import Union
+from typing import Callable, Union
 
 import dask.array as da
 import numpy as np
 
@@ -23,10 +22,7 @@ from h5py import File as HDF5File
 from sklearn.base import BaseEstimator
 
 from bob.bio.base.pipelines.vanilla_biometrics import BioAlgorithm
-from bob.learn.em import GMMMachine
-from bob.learn.em import GMMStats
-from bob.learn.em import KMeansMachine
-from bob.learn.em import linear_scoring
+from bob.learn.em import GMMMachine, GMMStats, KMeansMachine, linear_scoring
 
 logger = logging.getLogger(__name__)
 
@@ -153,7 +149,9 @@ class GMM(BioAlgorithm, BaseEstimator):
             or feature.ndim != 2
             or feature.dtype != np.float64
         ):
-            raise ValueError(f"The given feature is not appropriate: \n{feature}")
+            raise ValueError(
+                f"The given feature is not appropriate: \n{feature}"
+            )
         if self.ubm is not None and feature.shape[1] != self.ubm.shape[1]:
             raise ValueError(
                 "The given feature is expected to have %d elements, but it has %d"
@@ -165,7 +163,11 @@ class GMM(BioAlgorithm, BaseEstimator):
         # Saves the UBM to file
         logger.debug("Saving model to file '%s'", ubm_file)
-        hdf5 = ubm_file if isinstance(ubm_file, HDF5File) else HDF5File(ubm_file, "w")
+        hdf5 = (
+            ubm_file
+            if isinstance(ubm_file, HDF5File)
+            else HDF5File(ubm_file, "w")
+        )
         self.ubm.save(hdf5)
 
     def load_model(self, ubm_file):
@@ -273,7 +275,9 @@ class GMM(BioAlgorithm, BaseEstimator):
             The probe data to compare to the models.
         """
-        stats = self.project(probe) if not isinstance(probe, GMMStats) else probe
+        stats = (
+            self.project(probe) if not isinstance(probe, GMMStats) else probe
+        )
         return self.scoring_function(
             models_means=biometric_references,
             ubm=self.ubm,
@@ -291,7 +295,9 @@ class GMM(BioAlgorithm, BaseEstimator):
         if array[0].ndim == 2:
             array = np.vstack(array)
-        logger.debug(f"Creating UBM machine with {self.number_of_gaussians} gaussians and {len(array)} samples")
+        logger.debug(
+            f"Creating UBM machine with {self.number_of_gaussians} gaussians and {len(array)} samples"
+        )
         self.ubm = GMMMachine(
             n_gaussians=self.number_of_gaussians,
...
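The save routine above now spells its path-or-handle dispatch out over several lines. The standalone sketch below shows the same pattern with plain h5py (the only dependency the original line uses); the function and dataset names are invented for illustration and are not part of bob's API:

import numpy as np
from h5py import File as HDF5File


def save_array(data, ubm_file):
    """Accept either a filename or an already-open h5py File, like the method above."""
    hdf5 = (
        ubm_file
        if isinstance(ubm_file, HDF5File)
        else HDF5File(ubm_file, "w")
    )
    # When a path was given, the file stays open after writing, mirroring the
    # original method, which leaves closing to the caller or garbage collector.
    hdf5.create_dataset("means", data=data)


means = np.zeros((2, 45))
save_array(means, "standalone.hdf5")      # plain path, opened here in "w" mode
with HDF5File("grouped.hdf5", "w") as f:  # already-open handle, used as-is
    save_array(means, f)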
@@ -27,8 +27,7 @@ import bob.bio.gmm
 from bob.bio.base.test import utils
 from bob.bio.gmm.algorithm import GMM
-from bob.learn.em import GMMMachine
-from bob.learn.em import GMMStats
+from bob.learn.em import GMMMachine, GMMStats
 
 logger = logging.getLogger(__name__)
 
@@ -44,7 +43,8 @@ def test_class():
     )
     assert isinstance(gmm1, GMM)
     assert isinstance(
-        gmm1, bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm
+        gmm1,
+        bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm,
     )
     assert gmm1.number_of_gaussians == 512
     assert "bob_fit_supports_dask_array" in gmm1._get_tags()
@@ -117,7 +117,9 @@ def test_enroll():
     )
     # Create a GMM object with that UBM
     gmm1 = GMM(
-        number_of_gaussians=2, enroll_update_means=True, enroll_update_variances=True
+        number_of_gaussians=2,
+        enroll_update_means=True,
+        enroll_update_variances=True,
     )
     gmm1.ubm = ubm
     # Enroll the biometric reference from random features
@@ -136,7 +138,9 @@ def test_enroll():
     gmm2 = gmm1.read_biometric_reference(reference_file)
     assert biometric_reference.is_similar_to(gmm2)
 
-    with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_bioref.hdf5") as fd:
+    with tempfile.NamedTemporaryFile(
+        prefix="bob_", suffix="_bioref.hdf5"
+    ) as fd:
         temp_file = fd.name
         gmm1.write_biometric_reference(biometric_reference, temp_file)
         assert GMMMachine.from_hdf5(temp_file, ubm).is_similar_to(gmm2)
@@ -148,11 +152,15 @@ def test_score():
         pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_ubm.hdf5")
     )
     biometric_reference = GMMMachine.from_hdf5(
-        pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_enrolled.hdf5"),
+        pkg_resources.resource_filename(
+            "bob.bio.gmm.test", "data/gmm_enrolled.hdf5"
+        ),
         ubm=gmm1.ubm,
     )
     probe = GMMStats.from_hdf5(
-        pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5")
+        pkg_resources.resource_filename(
+            "bob.bio.gmm.test", "data/gmm_projected.hdf5"
+        )
     )
     probe_data = utils.random_array((20, 45), -5.0, 5.0, seed=seed_value)
...
@@ -8,8 +8,7 @@ import pkg_resources
 import sphinx_rtd_theme
 
 # For inter-documentation mapping:
-from bob.extension.utils import link_documentation
-from bob.extension.utils import load_requirements
+from bob.extension.utils import link_documentation, load_requirements
 
 # -- General configuration -----------------------------------------------------
 
@@ -130,9 +129,7 @@ pygments_style = "sphinx"
 # Some variables which are useful for generated material
 project_variable = project.replace(".", "_")
-short_description = (
-    u"Tools for running biometric recognition experiments using GMM-based approximation"
-)
+short_description = u"Tools for running biometric recognition experiments using GMM-based approximation"
 owner = [u"Idiap Research Institute"]
...
 [build-system]
 requires = ["setuptools", "wheel", "bob.extension"]
 build-backend = "setuptools.build_meta"
+
+[tool.isort]
+profile = "black"
+line_length = 80
+order_by_type = true
+lines_between_types = 1
+
+[tool.black]
+line-length = 80
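These isort settings replace the old hook arguments (--sl forced one import per line at length 88). With profile = "black" and line_length = 80, imports from the same module are combined onto one line, and lines_between_types = 1 keeps a blank line between plain import statements and from ... import lines; this is exactly the import rewrite applied to the Python files earlier in this commit, for example:

# Old style (isort --sl): one import per line.
from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats
from bob.learn.em import KMeansMachine
from bob.learn.em import linear_scoring

# New style (profile = "black", line_length = 80): combined, and only wrapped
# when the line would exceed 80 columns.
from bob.learn.em import GMMMachine, GMMStats, KMeansMachine, linear_scoring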
@@ -33,11 +33,9 @@
 # allows you to test your package with new python dependencies w/o requiring
 # administrative interventions.
-from setuptools import dist
-from setuptools import setup
+from setuptools import dist, setup
 
-from bob.extension.utils import find_packages
-from bob.extension.utils import load_requirements
+from bob.extension.utils import find_packages, load_requirements
 
 dist.Distribution(dict(setup_requires=["bob.extension"]))
...