diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py index 423fbc516e32fdcc4b53c66b874827bddfec2470..a122da9bf79cde683ddcf6819afbe9cf1037818e 100644 --- a/bob/bio/gmm/algorithm/GMM.py +++ b/bob/bio/gmm/algorithm/GMM.py @@ -13,8 +13,8 @@ This adds the notions of models, probes, enrollment, and scores to GMM. import copy import logging -from typing import Union from typing import Callable +from typing import Union import dask import dask.array as da @@ -24,9 +24,9 @@ from h5py import File as HDF5File from sklearn.base import BaseEstimator from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm -from bob.learn.em import KMeansMachine from bob.learn.em import GMMMachine from bob.learn.em import GMMStats +from bob.learn.em import KMeansMachine from bob.learn.em import linear_scoring logger = logging.getLogger(__name__) @@ -50,7 +50,9 @@ class GMM(BioAlgorithm, BaseEstimator): number_of_gaussians: int, # parameters of UBM training kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means - kmeans_init_iterations: Union[int,None] = None, # Maximum number of iterations for K-Means init + kmeans_init_iterations: Union[ + int, None + ] = None, # Maximum number of iterations for K-Means init kmeans_oversampling_factor: int = 64, ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training training_threshold: float = 5e-4, # Threshold to end the ML training @@ -117,7 +119,11 @@ class GMM(BioAlgorithm, BaseEstimator): # Copy parameters self.number_of_gaussians = number_of_gaussians self.kmeans_training_iterations = kmeans_training_iterations - self.kmeans_init_iterations = kmeans_training_iterations if kmeans_init_iterations is None else kmeans_init_iterations + self.kmeans_init_iterations = ( + kmeans_training_iterations + if kmeans_init_iterations is None + else kmeans_init_iterations + ) self.kmeans_oversampling_factor = kmeans_oversampling_factor self.ubm_training_iterations = ubm_training_iterations self.training_threshold = training_threshold diff --git a/bob/bio/gmm/test/test_gmm.py b/bob/bio/gmm/test/test_gmm.py index 4a9acb404149b5060a3af453e74dc21a63d89008..0c9f126a3c9162e9cc32afb7f198f526860ab51f 100644 --- a/bob/bio/gmm/test/test_gmm.py +++ b/bob/bio/gmm/test/test_gmm.py @@ -18,7 +18,6 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. import logging -import os import tempfile import numpy