Skip to content
Snippets Groups Projects
Commit 363378b5 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

[refactor] isort, flake8

parent 89d8f391
Branches
No related tags found
1 merge request!28Adapt to `bob.learn.em` API changes
Pipeline #58617 passed
......@@ -13,8 +13,8 @@ This adds the notions of models, probes, enrollment, and scores to GMM.
import copy
import logging
from typing import Union
from typing import Callable
from typing import Union
import dask
import dask.array as da
......@@ -24,9 +24,9 @@ from h5py import File as HDF5File
from sklearn.base import BaseEstimator
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm
from bob.learn.em import KMeansMachine
from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats
from bob.learn.em import KMeansMachine
from bob.learn.em import linear_scoring
logger = logging.getLogger(__name__)
......@@ -50,7 +50,9 @@ class GMM(BioAlgorithm, BaseEstimator):
number_of_gaussians: int,
# parameters of UBM training
kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means
kmeans_init_iterations: Union[int,None] = None, # Maximum number of iterations for K-Means init
kmeans_init_iterations: Union[
int, None
] = None, # Maximum number of iterations for K-Means init
kmeans_oversampling_factor: int = 64,
ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training
training_threshold: float = 5e-4, # Threshold to end the ML training
......@@ -117,7 +119,11 @@ class GMM(BioAlgorithm, BaseEstimator):
# Copy parameters
self.number_of_gaussians = number_of_gaussians
self.kmeans_training_iterations = kmeans_training_iterations
self.kmeans_init_iterations = kmeans_training_iterations if kmeans_init_iterations is None else kmeans_init_iterations
self.kmeans_init_iterations = (
kmeans_training_iterations
if kmeans_init_iterations is None
else kmeans_init_iterations
)
self.kmeans_oversampling_factor = kmeans_oversampling_factor
self.ubm_training_iterations = ubm_training_iterations
self.training_threshold = training_threshold
......
......@@ -18,7 +18,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
import os
import tempfile
import numpy
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment