Skip to content
Snippets Groups Projects
Commit 36386b66 authored by Yannick DAYER
Browse files

IVector as Transformer instead of BioAlgorithm.

parent c4c09df0
No related branches found
No related tags found
1 merge request: !60 "Port of I-Vector to python"
......@@ -11,7 +11,7 @@ import numpy as np
from sklearn.base import BaseEstimator
from bob.learn.em import GMMMachine, GMMStats, linear_scoring
from bob.learn.em import GMMMachine, GMMStats
# Module-level logger. Pass the ``__name__`` variable (not the string literal
# "__name__") so the logger is named after this module and therefore inherits
# the logging configuration of its parent package.
logger = logging.getLogger(__name__)
......@@ -137,6 +137,7 @@ class IVectorMachine(BaseEstimator):
- dim_t: dimension of the i-vector
**Attributes**
T (c,d,t):
The total variability matrix :math:`T`
sigma (c,d):
......@@ -157,9 +158,10 @@ class IVectorMachine(BaseEstimator):
"""Initializes the IVectorMachine object.
**Parameters**
ubm:
ubm
The Universal Background Model.
dim_t:
dim_t
The dimension of the i-vector.
"""
......@@ -169,16 +171,12 @@ class IVectorMachine(BaseEstimator):
self.convergence_threshold = convergence_threshold
self.max_iterations = max_iterations
self.update_sigma = update_sigma
self.dim_c = self.ubm.n_gaussians
self.dim_d = self.ubm.means.shape[-1]
self.dim_c = None
self.dim_d = None
self.variance_floor = variance_floor
self.T = np.random.normal(
loc=0.0,
scale=1.0,
size=(self.dim_c, self.dim_d, self.dim_t),
)
self.sigma = copy.deepcopy(self.ubm.variances)
self.T = None
self.sigma = None
if self.convergence_threshold:
logger.info(
......@@ -263,6 +261,21 @@ class IVectorMachine(BaseEstimator):
``max_iterations`` is reached.
"""
if not isinstance(data[0], GMMStats):
if self.ubm is None: # Train a GMMMachine if not provided
self.ubm.fit(data)
data = self.ubm.transform(data) # Transform to GMMStats
self.dim_c = self.ubm.n_gaussians
self.dim_d = self.ubm.means.shape[-1]
self.T = np.random.normal(
loc=0.0,
scale=1.0,
size=(self.dim_c, self.dim_d, self.dim_t),
)
self.sigma = copy.deepcopy(self.ubm.variances)
for step in range(self.max_iterations):
logger.debug(
f"IVector step {step+1:{len(str(self.max_iterations))}d}/{self.max_iterations}."
......@@ -281,6 +294,7 @@ class IVectorMachine(BaseEstimator):
This takes data already projected onto the UBM.
**Returns:**
The IVector of the input stats.
"""
......@@ -292,7 +306,7 @@ class IVectorMachine(BaseEstimator):
),
)
def transform(self, data: List[np.ndarray]) -> List[np.ndarray]:
def transform(self, data: List[GMMStats]) -> List[np.ndarray]:
"""Transforms the data using the trained IVectorMachine.
This takes MFCC data, will project them onto the ubm, and compute the IVector
......@@ -301,39 +315,15 @@ class IVectorMachine(BaseEstimator):
**Parameters:**
data
The data (MFCC features) to transform. Arrays of shape (n_samples, n_features).
The data (MFCC features) to transform.
Arrays of shape (n_samples, n_features).
**Returns:**
The IVector for each sample. Shape: (dim_t,)
The IVector for each sample. Arrays of shape (dim_t,)
"""
return [self.project(self.ubm.acc_stats(d)) for d in data]
def enroll(self, stats: List[GMMStats]) -> IVectorStats:
    """Build a speaker model by projecting the given GMM statistics.

    Parameters
    ----------
    stats : List[GMMStats]
        GMM statistics belonging to the speaker being enrolled.

    Returns
    -------
    IVectorStats
        Whatever ``self.project`` returns for ``stats``.
    """
    # Enrollment is a plain delegation to the projection step.
    # NOTE(review): ``stats`` is forwarded as-is; confirm that ``project``
    # accepts a list rather than a single GMMStats object.
    speaker_model = self.project(stats)
    return speaker_model
def score(
    self, model: IVectorStats, probes: List[np.ndarray]
) -> List[float]:
    """Score ``probes`` against ``model`` using ``linear_scoring``.

    The machine's UBM is handed to ``linear_scoring`` together with a zero
    channel offset and frame-length normalization switched on.

    Parameters
    ----------
    model : IVectorStats
        The enrolled model to compare the probes against.
    probes : List[np.ndarray]
        The probe items, forwarded unchanged as ``test_stats``.

    Returns
    -------
    List[float]
        The value returned by ``linear_scoring``.
    """
    scoring_options = {
        "ubm": self.ubm,
        "test_stats": probes,
        "test_channel_offsets": 0,
        "frame_length_normalization": True,
    }
    return linear_scoring(model, **scoring_options)
return [self.project(self.ubm.acc_stats(d)) for d in data]
def _more_tags(self) -> Dict[str, Any]:
return {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment