Skip to content
Snippets Groups Projects
Commit d079cddf authored by Yannick DAYER's avatar Yannick DAYER
Browse files

h5py instead of bob.io.base H5File

parent ed1032dc
No related branches found
No related tags found
1 merge request!26Python implementation of GMM
Pipeline #56661 failed
...@@ -17,11 +17,11 @@ from typing import Callable ...@@ -17,11 +17,11 @@ from typing import Callable
import dask.array as da import dask.array as da
import numpy as np import numpy as np
import dask import dask
from h5py import File as HDF5File
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
import bob.core import bob.core
import bob.io.base
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm
from bob.learn.em.mixture import GMMMachine from bob.learn.em.mixture import GMMMachine
...@@ -150,13 +150,13 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -150,13 +150,13 @@ class GMM(BioAlgorithm, BaseEstimator):
hdf5 = ( hdf5 = (
ubm_file ubm_file
if isinstance(ubm_file, bob.io.base.HDF5File) if isinstance(ubm_file, HDF5File)
else bob.io.base.HDF5File(ubm_file, "w") else HDF5File(ubm_file, "w")
) )
self.ubm.save(hdf5) self.ubm.save(hdf5)
def load_ubm(self, ubm_file): def load_ubm(self, ubm_file):
hdf5file = bob.io.base.HDF5File(ubm_file) hdf5file = HDF5File(ubm_file)
logger.debug("Loading model from file '%s'", ubm_file) logger.debug("Loading model from file '%s'", ubm_file)
# read UBM # read UBM
self.ubm = GMMMachine.from_hdf5(hdf5file) self.ubm = GMMMachine.from_hdf5(hdf5file)
...@@ -177,7 +177,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -177,7 +177,7 @@ class GMM(BioAlgorithm, BaseEstimator):
def read_feature(self, feature_file): def read_feature(self, feature_file):
"""Read the type of features that we require, namely GMM_Stats""" """Read the type of features that we require, namely GMM_Stats"""
return GMMStats.from_hdf5(bob.io.base.HDF5File(feature_file)) return GMMStats.from_hdf5(HDF5File(feature_file))
def write_feature(self, feature, feature_file): def write_feature(self, feature, feature_file):
"""Write the features (GMM_Stats)""" """Write the features (GMM_Stats)"""
...@@ -213,7 +213,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -213,7 +213,7 @@ class GMM(BioAlgorithm, BaseEstimator):
def read_model(self, model_file): def read_model(self, model_file):
"""Reads the model, which is a GMM machine""" """Reads the model, which is a GMM machine"""
return GMMMachine.from_hdf5(bob.io.base.HDF5File(model_file), ubm=self.ubm) return GMMMachine.from_hdf5(HDF5File(model_file), ubm=self.ubm)
def write_model(self, model, model_file): def write_model(self, model, model_file):
"""Write the features (GMM_Stats)""" """Write the features (GMM_Stats)"""
...@@ -232,7 +232,6 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -232,7 +232,6 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the model. The probe data to compare to the model.
""" """
# import ipdb; ipdb.set_trace()
assert isinstance(biometric_reference, GMMMachine) assert isinstance(biometric_reference, GMMMachine)
stats = self.project(data) stats = self.project(data)
return self.scoring_function( return self.scoring_function(
...@@ -287,9 +286,6 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -287,9 +286,6 @@ class GMM(BioAlgorithm, BaseEstimator):
def fit(self, X, y=None, **kwargs): def fit(self, X, y=None, **kwargs):
"""Trains the UBM.""" """Trains the UBM."""
# TODO: Delayed to dask array
if not all(isinstance(x, da.Array) for x in X):
raise ValueError(f"This function only supports dask arrays, {type(X[0])}")
# Stack all the samples in a 2D array of features # Stack all the samples in a 2D array of features
array = da.vstack(X) array = da.vstack(X)
...@@ -332,7 +328,6 @@ def delayed_to_da(delayed, meta=None): ...@@ -332,7 +328,6 @@ def delayed_to_da(delayed, meta=None):
"""Converts one dask.delayed object to a dask.array""" """Converts one dask.delayed object to a dask.array"""
if meta is None: if meta is None:
meta = np.array(delayed.data.compute()) meta = np.array(delayed.data.compute())
print(meta.shape)
darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False) darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False)
return darray, meta return darray, meta
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment