Skip to content
Snippets Groups Projects
Commit 26bebe08 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

[refactor] Rename k_means as kmeans

parent 63c831ff
No related branches found
No related tags found
1 merge request!49Fix np memory issues and rename k_means to kmeans
Pipeline #59378 failed
import bob.extension import bob.extension
from .gmm import GMMMachine, GMMStats from .gmm import GMMMachine, GMMStats
from .k_means import KMeansMachine from .kmeans import KMeansMachine
from .linear_scoring import linear_scoring # noqa: F401 from .linear_scoring import linear_scoring # noqa: F401
from .wccn import WCCN from .wccn import WCCN
from .whitening import Whitening from .whitening import Whitening
......
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
from h5py import File as HDF5File from h5py import File as HDF5File
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
from .k_means import ( from .kmeans import (
KMeansMachine, KMeansMachine,
array_to_delayed_list, array_to_delayed_list,
check_and_persist_dask_input, check_and_persist_dask_input,
......
File moved
...@@ -16,7 +16,7 @@ import dask.array as da ...@@ -16,7 +16,7 @@ import dask.array as da
import numpy as np import numpy as np
import scipy.spatial.distance import scipy.spatial.distance
from bob.learn.em import KMeansMachine, k_means from bob.learn.em import KMeansMachine, kmeans
def to_numpy(*args): def to_numpy(*args):
...@@ -187,6 +187,6 @@ def test_get_centroids_distance(): ...@@ -187,6 +187,6 @@ def test_get_centroids_distance():
oracle = scipy.spatial.distance.cdist(means, data, metric="sqeuclidean") oracle = scipy.spatial.distance.cdist(means, data, metric="sqeuclidean")
for transform in (to_numpy,): for transform in (to_numpy,):
data, means = transform(data, means) data, means = transform(data, means)
dist = k_means.get_centroids_distance(data, means) dist = kmeans.get_centroids_distance(data, means)
np.testing.assert_allclose(dist, oracle) np.testing.assert_allclose(dist, oracle)
assert type(data) is type(dist), (type(data), type(dist)) assert type(data) is type(dist), (type(data), type(dist))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment