Commit a0c54a5c authored by Yannick DAYER

Merge branch 'ivector' into 'master'

Port of I-Vector to python

See merge request !60
parents 9440d528 0f4c5437
Pipeline #65243 passed
Showing 703 additions and 7 deletions
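For context, a minimal usage sketch of the API this merge request adds. This is hedged: the data, hyper-parameters, and the sklearn-style `GMMMachine.fit` call below are illustrative assumptions, not taken from the diff itself.

import numpy as np

from bob.learn.em import GMMMachine, IVectorMachine

# Illustrative data: a list of (n_frames, n_features) float arrays.
data = [np.random.normal(loc=1.0, size=(20, 3)) for _ in range(4)]

# Train the Universal Background Model on all frames.
ubm = GMMMachine(n_gaussians=2)
ubm.fit(np.vstack(data))

# Project each sample onto the UBM (yielding GMMStats), then train the
# i-vector extractor on those statistics.
machine = IVectorMachine(ubm=ubm, dim_t=2, max_iterations=25)
machine.fit(ubm.transform(data))

# Extract one i-vector of shape (dim_t,) per sample.
ivectors = machine.transform(ubm.transform(data))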
@@ -2,6 +2,7 @@ import bob.extension
 from .factor_analysis import ISVMachine, JFAMachine
 from .gmm import GMMMachine, GMMStats
+from .ivector import IVectorMachine
 from .kmeans import KMeansMachine
 from .linear_scoring import linear_scoring  # noqa: F401
 from .wccn import WCCN
@@ -30,6 +31,13 @@ def __appropriate__(*args):
 __appropriate__(
-    KMeansMachine, GMMMachine, GMMStats, WCCN, Whitening, ISVMachine, JFAMachine
+    KMeansMachine,
+    GMMMachine,
+    GMMStats,
+    IVectorMachine,
+    WCCN,
+    Whitening,
+    ISVMachine,
+    JFAMachine,
 )
 __all__ = [_ for _ in dir() if not _.startswith("_")]
9 files added (binary HDF5 test data used by the new i-vector tests)
#!/usr/bin/env python
# @author: Yannick Dayer <yannick.dayer@idiap.ch>
# @date: Fri 06 May 2022 14:18:25 UTC+02
import copy
import logging
import operator
from typing import Any, Dict, List, Optional, Tuple, Union
import dask
import dask.bag
import numpy as np
from sklearn.base import BaseEstimator
from bob.learn.em import GMMMachine, GMMStats
logger = logging.getLogger(__name__)
class IVectorStats:
"""Stores I-Vector statistics. Can be used to accumulate multiple statistics.
**Attributes:**
nij_sigma_wij2: numpy.ndarray of shape (n_gaussians,dim_t,dim_t)
fnorm_sigma_wij: numpy.ndarray of shape (n_gaussians,n_features,dim_t)
snormij: numpy.ndarray of shape (n_gaussians,n_features)
nij: numpy.ndarray of shape (n_gaussians,)
"""
def __init__(self, dim_c, dim_d, dim_t):
self.dim_c = dim_c
self.dim_d = dim_d
self.dim_t = dim_t
# Accumulator storage variables
# nij sigma wij2: shape = (c,t,t)
self.nij_sigma_wij2 = np.zeros(
shape=(self.dim_c, self.dim_t, self.dim_t), dtype=float
)
# fnorm sigma wij: shape = (c,d,t)
self.fnorm_sigma_wij = np.zeros(
shape=(self.dim_c, self.dim_d, self.dim_t), dtype=float
)
# Snormij (used only when updating sigma)
self.snormij = np.zeros(
shape=(
self.dim_c,
self.dim_d,
),
dtype=float,
)
# Nij (used only when updating sigma)
self.nij = np.zeros(shape=(self.dim_c,), dtype=float)
@property
def shape(self) -> Tuple[int, int, int]:
return (self.dim_c, self.dim_d, self.dim_t)
def __add__(self, other):
if self.shape != other.shape:
raise ValueError("Cannot add stats of different shapes")
result = IVectorStats(self.dim_c, self.dim_d, self.dim_t)
result.nij_sigma_wij2 = self.nij_sigma_wij2 + other.nij_sigma_wij2
result.fnorm_sigma_wij = self.fnorm_sigma_wij + other.fnorm_sigma_wij
result.snormij = self.snormij + other.snormij
result.nij = self.nij + other.nij
return result
def __iadd__(self, other):
if self.shape != other.shape:
raise ValueError("Cannot add stats of different shapes")
self.nij_sigma_wij2 += other.nij_sigma_wij2
self.fnorm_sigma_wij += other.fnorm_sigma_wij
self.snormij += other.snormij
self.nij += other.nij
return self
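# Note: supporting "+" lets partial statistics, computed independently on
# chunks of the data, be reduced pairwise afterwards, e.g.:
#   total = e_step(machine, chunk_a) + e_step(machine, chunk_b)
# The pairwise dask reduction in IVectorMachine.fit relies on this.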
def compute_tct_sigmac_inv(T: np.ndarray, sigma: np.ndarray) -> np.ndarray:
"""Computes T_{c}^{T}.sigma_{c}^{-1}"""
# TT_sigma_inv (c,t,d) = T.T (c,t,d) / sigma (c,1,d)
Tct_sigmacInv = T.transpose(0, 2, 1) / sigma[:, None, :]
# Tt_sigma_inv (c,t,d)
return Tct_sigmacInv
def compute_tct_sigmac_inv_tc(T: np.ndarray, sigma: np.ndarray) -> np.ndarray:
"""Computes T_{c}^{T}.sigma_{c}^{-1}.T_{c}"""
tct_sigmac_inv = compute_tct_sigmac_inv(T, sigma)
# (c,t,t) = (c,t,d) @ (c,d,t)
Tct_sigmacInv_Tc = tct_sigmac_inv @ T
# Output: shape (c,t,t)
return Tct_sigmacInv_Tc
def compute_id_tt_sigma_inv_t(
    stats: GMMStats, T: np.ndarray, sigma: np.ndarray
) -> np.ndarray:
    """Computes :math:`I + \\sum_{c=1}^{C} N_{i,j,c} T_{c}^{T} \\Sigma_{c}^{-1} T_{c}`.

    Returns an array of shape (t,t).
    """
    dim_t = T.shape[-1]
tct_sigmac_inv_tc = compute_tct_sigmac_inv_tc(T, sigma)
output = np.eye(dim_t, dim_t) + np.einsum(
"c,ctu->tu", stats.n, tct_sigmac_inv_tc
)
# Output: (t,t)
return output
def compute_tt_sigma_inv_fnorm(
ubm_means: np.ndarray, stats: GMMStats, T: np.ndarray, sigma: np.ndarray
) -> np.ndarray:
"""Computes \f$(Id + \\sum_{c=1}^{C} N_{i,j,c} T^{T} \\Sigma_{c}^{-1} T)\f$
Returns an array of shape (t,)
"""
tct_sigmac_inv = compute_tct_sigmac_inv(T, sigma) # (c,t,d)
fnorm = stats.sum_px - stats.n[:, None] * ubm_means # (c,d)
# (t,) += (t,d) @ (d) [repeated c times]
output = np.einsum("ctd,cd->t", tct_sigmac_inv, fnorm)
# Output: shape (t,)
return output
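# The E-step below uses the helpers above to compute, for each sample i, the
# posterior moments of the latent i-vector w_i given its UBM statistics
# (a sketch of the equations implemented, written inline for reference):
#   E[w_i]       = (I + sum_c N_ic T_c^T Sigma_c^-1 T_c)^-1 T^T Sigma^-1 (F_i - N_i mu)
#   E[w_i w_i^T] = (I + sum_c N_ic T_c^T Sigma_c^-1 T_c)^-1 + E[w_i] E[w_i]^T
# These are `sigma_w_ij` and `sigma_w_ij2` in the code, respectively.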
def e_step(machine: "IVectorMachine", data: List[GMMStats]) -> IVectorStats:
"""Computes the expectation step of the e-m algorithm."""
stats = IVectorStats(machine.dim_c, machine.dim_d, machine.dim_t)
for sample in data:
Nij = sample.n
Fij = sample.sum_px
Sij = sample.sum_pxx
# Estimate latent variables
        TtSigmaInv_Fnorm = compute_tt_sigma_inv_fnorm(
            machine.ubm.means, sample, machine.T, machine.sigma
        )  # shape: (t,)
        I_TtSigmaInvNT = compute_id_tt_sigma_inv_t(
            sample, machine.T, machine.sigma
        )  # shape: (t,t)
# Latent variables
I_TtSigmaInvNT_inv = np.linalg.inv(I_TtSigmaInvNT) # shape: (t,t)
sigma_w_ij = np.dot(I_TtSigmaInvNT_inv, TtSigmaInv_Fnorm) # shape: (t,)
sigma_w_ij2 = I_TtSigmaInvNT_inv + np.outer(
sigma_w_ij, sigma_w_ij
) # shape: (t,t)
# Compute normalized statistics
Fnorm = Fij - Nij[:, None] * machine.ubm.means
Snorm = (
Sij
- (2 * Fij * machine.ubm.means)
+ (Nij[:, None] * machine.ubm.means * machine.ubm.means)
)
# Do the accumulation for each component
stats.snormij = stats.snormij + Snorm # shape: (c, d)
# (c,t,t) += (c,) * (t,t)
stats.nij_sigma_wij2 = stats.nij_sigma_wij2 + (
Nij[:, None, None] * sigma_w_ij2[None, :, :]
) # (c,t,t)
stats.nij = stats.nij + Nij
stats.fnorm_sigma_wij = stats.fnorm_sigma_wij + np.matmul(
Fnorm[:, :, None], sigma_w_ij[None, :]
) # (c,d,t)
return stats
def m_step(machine: "IVectorMachine", stats: IVectorStats) -> "IVectorMachine":
"""Updates the Machine with the maximization step of the e-m algorithm."""
logger.debug("Computing new machine parameters.")
A = stats.nij_sigma_wij2.transpose((0, 2, 1))
B = stats.fnorm_sigma_wij.transpose((0, 2, 1))
# Default value of X if any of A[c] is 0
X = np.zeros_like(B)
# Solve for all A[c] != 0
if any(mask := A.any(axis=(-2, -1))): # Prevents solving with 0 matrices
X[mask] = [
np.linalg.solve(A[c], B[c]) for c in range(len(mask)) if A[c].any()
]
# Update the machine
machine.T = X.transpose((0, 2, 1))
if machine.update_sigma:
fnorm_sigma_wij_tt = np.diagonal(
stats.fnorm_sigma_wij @ X, axis1=-2, axis2=-1
)
machine.sigma = (stats.snormij - fnorm_sigma_wij_tt) / stats.nij[
:, None
]
machine.sigma[
machine.sigma < machine.variance_floor
] = machine.variance_floor
return machine
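# In matrix form, the M-step above solves one linear system per Gaussian c:
#   A_c X_c = B_c,  with A_c = sum_i N_ic E[w_i w_i^T]      (nij_sigma_wij2)
#                   and  B_c = (sum_i Fnorm_ic E[w_i]^T)^T  (fnorm_sigma_wij, transposed),
# then sets T_c = X_c^T. Components with an all-zero A_c keep T_c = 0.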
class IVectorMachine(BaseEstimator):
"""Trains and projects data using I-Vector.
Dimensions:
- dim_c: number of Gaussians
- dim_d: number of features
- dim_t: dimension of the i-vector
**Attributes**
T (c,d,t):
The total variability matrix :math:`T`
sigma (c,d):
        The diagonal covariance matrix :math:`\\Sigma`
"""
def __init__(
self,
ubm: GMMMachine,
dim_t: int = 2,
convergence_threshold: Optional[float] = None,
max_iterations: int = 25,
update_sigma: bool = True,
variance_floor: float = 1e-10,
**kwargs,
) -> None:
"""Initializes the IVectorMachine object.
**Parameters**
ubm
The Universal Background Model.
dim_t
The dimension of the i-vector.
"""
super().__init__(**kwargs)
self.ubm = ubm
self.dim_t = dim_t
self.convergence_threshold = convergence_threshold
self.max_iterations = max_iterations
self.update_sigma = update_sigma
self.dim_c = None
self.dim_d = None
self.variance_floor = variance_floor
self.T = None
self.sigma = None
if self.convergence_threshold:
logger.info(
"The convergence threshold is ignored by IVectorMachine."
)
def fit(
self, X: Union[List[np.ndarray], dask.bag.Bag], y=None
) -> "IVectorMachine":
"""Trains the IVectorMachine.
Repeats the e-m steps until ``max_iterations`` is reached.
"""
chunky = False
if isinstance(X, dask.bag.Bag):
chunky = True
X = X.to_delayed()
self.dim_c = self.ubm.n_gaussians
self.dim_d = self.ubm.means.shape[-1]
self.T = np.random.normal(
loc=0.0,
scale=1.0,
size=(self.dim_c, self.dim_d, self.dim_t),
)
self.sigma = copy.deepcopy(self.ubm.variances)
for step in range(self.max_iterations):
if chunky:
stats = [
dask.delayed(e_step)(
machine=self,
data=xx,
)
for xx in X
]
# Workaround to prevent memory issues at compute with too many chunks.
# This adds pairs of stats together instead of sending all the stats to
# one worker.
while (l := len(stats)) > 1:
last = stats[-1]
stats = [
dask.delayed(operator.add)(stats[i], stats[l // 2 + i])
for i in range(l // 2)
]
if l % 2 != 0:
stats.append(last)
stats_sum = stats[0]
new_machine = dask.compute(
dask.delayed(m_step)(self, stats_sum)
)[0]
for attr in ["T", "sigma"]:
setattr(self, attr, getattr(new_machine, attr))
else:
stats = e_step(machine=self, data=X)
_ = m_step(self, stats)
logger.info(
f"IVector step {step+1:{len(str(self.max_iterations))}d}/{self.max_iterations}."
)
logger.info(f"Reached {step+1} steps.")
return self
def project(self, stats: GMMStats) -> np.ndarray:
"""Projects the GMMStats on the IVectorMachine.
This takes data already projected onto the UBM.
**Returns:**
The IVector of the input stats.
"""
return np.linalg.solve(
compute_id_tt_sigma_inv_t(stats, self.T, self.sigma),
compute_tt_sigma_inv_fnorm(
self.ubm.means, stats, self.T, self.sigma
),
)
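    # `project` is the deployment-time counterpart of the E-step: it solves
    #   (I + sum_c N_c T_c^T Sigma_c^-1 T_c) w = T^T Sigma^-1 (F - N mu)
    # for w, i.e. it returns E[w] for the given GMMStats.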
def transform(self, X: List[GMMStats]) -> List[np.ndarray]:
"""Transforms the data using the trained IVectorMachine.
This takes MFCC data, will project them onto the ubm, and compute the IVector
statistics.
**Parameters:**
data
The data (MFCC features) to transform.
Arrays of shape (n_samples, n_features).
**Returns:**
The IVector for each sample. Arrays of shape (dim_t,)
"""
return [self.project(x) for x in X]
def _more_tags(self) -> Dict[str, Any]:
return {
"requires_fit": True,
"bob_fit_supports_dask_bag": True,
}
#!/usr/bin/env python
# @author: Yannick Dayer <yannick.dayer@idiap.ch>
# @date: Fri 06 May 2022 12:59:21 UTC+02
import contextlib
import copy
import dask.bag
import dask.distributed
import numpy as np
from h5py import File as HDF5File
from pkg_resources import resource_filename
from bob.learn.em import GMMMachine, GMMStats, IVectorMachine
from bob.learn.em.ivector import e_step, m_step
from bob.learn.em.test.test_kmeans import to_numpy
@contextlib.contextmanager
def _dask_distributed_context():
try:
client = dask.distributed.Client()
with client.as_current():
yield client
finally:
client.close()
def to_dask_bag(*args):
"""Converts all args into dask Bags."""
result = []
for x in args:
x = np.asarray(x)
result.append(dask.bag.from_sequence(x, npartitions=x.shape[0] * 2))
if len(result) == 1:
return result[0]
return result
def test_ivector_machine_base():
# Create the UBM and set its values manually
ubm = GMMMachine(n_gaussians=2)
ubm.weights = np.array([0.4, 0.6], dtype=float)
ubm.means = np.array([[1, 7, 4], [4, 5, 3]], dtype=float)
ubm.variances = np.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0]], dtype=float)
machine = IVectorMachine(ubm=ubm, dim_t=4)
assert hasattr(machine, "ubm")
assert hasattr(machine, "T")
assert hasattr(machine, "sigma")
assert machine.T is None
assert machine.sigma is None
def test_ivector_machine_projection():
# Create the UBM and set its values manually
ubm = GMMMachine(n_gaussians=2)
ubm.weights = np.array([0.4, 0.6], dtype=float)
ubm.means = np.array([[1, 7, 4], [4, 5, 3]], dtype=float)
ubm.variances = np.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0]], dtype=float)
machine = IVectorMachine(ubm=ubm, dim_t=2)
machine.T = np.array(
[[[1, 2], [4, 1], [0, 3]], [[5, 8], [7, 10], [11, 1]]], dtype=float
)
machine.sigma = np.array([[1, 2, 1], [3, 2, 4]], dtype=float)
# Manually create a feature (usually projected with the UBM)
gmm_projection = GMMStats(ubm.n_gaussians, ubm.means.shape[-1])
gmm_projection.t = 1
gmm_projection.n = np.array([0.4, 0.6], dtype=float)
gmm_projection.sum_px = np.array([[1, 2, 3], [2, 4, 3]], dtype=float)
gmm_projection.sum_pxx = np.array([[10, 20, 30], [40, 50, 60]], dtype=float)
# Reference from C++ implementation
ivector_projection_ref = np.array([-0.04213415, 0.21463343])
ivector_projection = machine.project(gmm_projection)
np.testing.assert_almost_equal(
ivector_projection_ref, ivector_projection, decimal=7
)
def test_ivector_machine_transformer():
dim_t = 2
ubm = GMMMachine(n_gaussians=2)
ubm.means = np.array([[1, 7, 4], [4, 5, 3]], dtype=float)
ubm.variances = np.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0]], dtype=float)
machine = IVectorMachine(ubm=ubm, dim_t=dim_t)
machine.T = np.array(
[[[1, 2], [4, 1], [0, 3]], [[5, 8], [7, 10], [11, 1]]], dtype=float
)
machine.sigma = ubm.variances.copy()
assert hasattr(machine, "fit")
assert hasattr(machine, "transform")
transformed = machine.transform(ubm.transform([np.array([1, 2, 3])]))[0]
assert isinstance(transformed, np.ndarray)
np.testing.assert_almost_equal(
transformed, np.array([0.02774721, -0.35237828]), decimal=7
)
def test_ivector_machine_training():
gs1 = GMMStats.from_hdf5(
resource_filename("bob.learn.em", "data/ivector_gs1.hdf5")
)
gs2 = GMMStats.from_hdf5(
resource_filename("bob.learn.em", "data/ivector_gs2.hdf5")
)
data = [gs1, gs2]
# Define the ubm
ubm = GMMMachine(n_gaussians=2)
ubm.means = np.array([[1, 2, 3], [6, 7, 8]])
ubm.variances = np.ones((2, 3))
np.random.seed(0)
machine = IVectorMachine(ubm=ubm, dim_t=2)
machine.fit(data)
test_data = GMMStats(2, 3)
test_data.t = 1
test_data.log_likelihood = -0.5
test_data.n = np.array([0.5, 0.5])
test_data.sum_px = np.array([[8, 0, 4], [6, 6, 6]])
test_data.sum_pxx = np.array([[10, 20, 30], [60, 70, 80]])
projected = machine.project(test_data)
proj_reference = np.array([0.94234370, -0.61558459])
np.testing.assert_almost_equal(projected, proj_reference, decimal=4)
def _load_references_from_file(filename):
"""Loads the IVectorStats references, T, and sigma for one step"""
with HDF5File(filename, "r") as f:
keys = (
"nij_sigma_wij2",
"fnorm_sigma_wij",
"nij",
"snormij",
"T",
"sigma",
)
ret = {k: f[k][()] for k in keys}
return ret
def test_trainer_nosigma():
# Ubm
ubm = GMMMachine(2)
ubm.means = np.array([[1.0, 7, 4], [4, 5, 3]])
ubm.variances = np.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0]])
ubm.weights = np.array([0.4, 0.6])
data = [
GMMStats.from_hdf5(
resource_filename("bob.learn.em", f"data/ivector_gs{i+1}.hdf5")
)
for i in range(2)
]
references = [
_load_references_from_file(
resource_filename(
"bob.learn.em", f"data/ivector_ref_nosigma_step{i+1}.hdf5"
)
)
for i in range(2)
]
# Machine
m = IVectorMachine(ubm, dim_t=2, update_sigma=False)
# Manual Initialization
m.dim_c = ubm.n_gaussians
m.dim_d = ubm.shape[-1]
m.T = np.array([[[1.0, 2], [4, 1], [0, 3]], [[5, 8], [7, 10], [11, 1]]])
init_sigma = np.array([[1.0, 2.0, 1.0], [3.0, 2.0, 4.0]])
m.sigma = copy.deepcopy(init_sigma)
stats = None
for it in range(2):
# E-Step
stats = e_step(m, data)
np.testing.assert_almost_equal(
references[it]["nij_sigma_wij2"], stats.nij_sigma_wij2, decimal=5
)
np.testing.assert_almost_equal(
references[it]["fnorm_sigma_wij"], stats.fnorm_sigma_wij, decimal=5
)
np.testing.assert_almost_equal(
references[it]["snormij"], stats.snormij, decimal=5
)
np.testing.assert_almost_equal(
references[it]["nij"], stats.nij, decimal=5
)
# M-Step
m_step(m, stats)
np.testing.assert_almost_equal(references[it]["T"], m.T, decimal=5)
np.testing.assert_equal(
init_sigma, m.sigma
) # sigma should not be updated
def test_trainer_update_sigma():
# Ubm
ubm = GMMMachine(n_gaussians=2)
ubm.weights = np.array([0.4, 0.6])
ubm.means = np.array([[1.0, 7, 4], [4, 5, 3]])
ubm.variances = np.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0]])
data = [
GMMStats.from_hdf5(
resource_filename("bob.learn.em", f"data/ivector_gs{i+1}.hdf5")
)
for i in range(2)
]
references = [
_load_references_from_file(
resource_filename(
"bob.learn.em", f"data/ivector_ref_step{i+1}.hdf5"
)
)
for i in range(2)
]
# Machine
m = IVectorMachine(
ubm, dim_t=2, variance_floor=1e-5
) # update_sigma is True by default
# Manual Initialization
m.dim_c = ubm.n_gaussians
m.dim_d = ubm.shape[-1]
m.T = np.array([[[1.0, 2], [4, 1], [0, 3]], [[5, 8], [7, 10], [11, 1]]])
m.sigma = np.array([[1.0, 2.0, 1.0], [3.0, 2.0, 4.0]])
for it in range(2):
# E-Step
stats = e_step(m, data)
np.testing.assert_almost_equal(
references[it]["nij_sigma_wij2"], stats.nij_sigma_wij2, decimal=5
)
np.testing.assert_almost_equal(
references[it]["fnorm_sigma_wij"], stats.fnorm_sigma_wij, decimal=5
)
np.testing.assert_almost_equal(
references[it]["snormij"], stats.snormij, decimal=5
)
np.testing.assert_almost_equal(
references[it]["nij"], stats.nij, decimal=5
)
# M-Step
m_step(m, stats)
np.testing.assert_almost_equal(references[it]["T"], m.T, decimal=5)
np.testing.assert_almost_equal(
references[it]["sigma"], m.sigma, decimal=5
)
def test_ivector_fit():
# Ubm
ubm = GMMMachine(n_gaussians=2)
ubm.weights = np.array([0.4, 0.6])
ubm.means = np.array([[1.0, 7, 4], [4, 5, 3]])
ubm.variances = np.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0]])
fit_data_file = resource_filename(
"bob.learn.em", "data/ivector_fit_data.hdf5"
)
with HDF5File(fit_data_file, "r") as f:
fit_data = f["array"][()]
test_data_file = resource_filename(
"bob.learn.em", "data/ivector_test_data.hdf5"
)
with HDF5File(test_data_file, "r") as f:
test_data = f["array"][()]
reference_result_file = resource_filename(
"bob.learn.em", "data/ivector_results.hdf5"
)
with HDF5File(reference_result_file, "r") as f:
reference_result = f["array"][()]
# Serial test
np.random.seed(0)
fit_data = to_numpy(fit_data)
projected_data = ubm.transform(fit_data)
m = IVectorMachine(ubm=ubm, dim_t=2, max_iterations=2)
m.fit(projected_data)
result = m.transform(ubm.transform(test_data))
np.testing.assert_almost_equal(result, reference_result, decimal=5)
# Parallel test
with _dask_distributed_context():
for transform in [to_numpy, to_dask_bag]:
np.random.seed(0)
fit_data = transform(fit_data)
projected_data = ubm.transform(fit_data)
projected_data = transform(projected_data)
m = IVectorMachine(ubm=ubm, dim_t=2, max_iterations=2)
m.fit(projected_data)
result = m.transform(ubm.transform(test_data))
np.testing.assert_almost_equal(
np.array(result), reference_result, decimal=5
)
@@ -55,7 +55,7 @@ def run_whitening(with_dask):
     t = Whitening()
     t.fit(data)
-    s = t.transform(sample)
+    s = t.transform([sample])
     # Makes sure results are good
     eps = 1e-4
@@ -65,7 +65,7 @@ def run_whitening(with_dask):
     # Runs whitening (second method)
     m2 = t.fit(data)
-    s2 = t.transform(sample)
+    s2 = t.transform([sample])
     # Makes sure results are good
     eps = 1e-4
@@ -113,7 +113,7 @@ def run_wccn(with_dask):
     # Runs WCCN (first method)
     t = WCCN()
     t.fit(X, y=y)
-    s = t.transform(sample)
+    s = t.transform([sample])
     # Makes sure results are good
     eps = 1e-4
@@ -123,7 +123,7 @@ def run_wccn(with_dask):
     # Runs WCCN (second method)
     t.fit(X, y)
-    s2 = t.transform(sample)
+    s2 = t.transform([sample])
     # Makes sure results are good
     eps = 1e-4
@@ -50,6 +50,8 @@ class WCCN(TransformerMixin, BaseEstimator):
         from scipy.linalg import cholesky, inv

+        X = numerical_module.array(X)
+
         possible_labels = set(y)
         y_ = numerical_module.array(y)
@@ -89,4 +91,7 @@ class WCCN(TransformerMixin, BaseEstimator):
     def transform(self, X):
-        return ((X - self.input_subtract) / self.input_divide) @ self.weights
+        return [
+            ((x - self.input_subtract) / self.input_divide) @ self.weights
+            for x in X
+        ]
@@ -59,7 +59,7 @@ class Whitening(TransformerMixin, BaseEstimator):
         # 1. Computes the mean vector and the covariance matrix of the training set
         mu = numerical_module.mean(X, axis=0)
-        cov = numerical_module.cov(X.T)
+        cov = numerical_module.cov(numerical_module.transpose(X))
         # 2. Computes the inverse of the covariance matrix
         inv_cov = pinv(cov) if self.pinv else inv(cov)
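Taken together, the Whitening and WCCN hunks above change `transform` to accept a list of sample arrays and return one projection per element, instead of a single 2D array. A minimal sketch of the updated call convention (`data` and `sample` are illustrative placeholders, not part of the diff):

import numpy as np

from bob.learn.em import Whitening

data = np.random.normal(size=(10, 3))  # training set: (n_samples, n_features)
sample = data[0]                       # a single sample, shape (n_features,)

t = Whitening()
t.fit(data)
s = t.transform([sample])[0]  # list in, list out: one projected array per input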