Skip to content
Snippets Groups Projects
Commit de9a8fec authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[plotting] Fix the random grouping option

parent 43186bae
No related branches found
No related tags found
No related merge requests found
Pipeline #49345 passed
#!/usr/bin/env python
import numpy
import bob.learn.em
import numpy as np
from numpy.random import default_rng
def grouping(scores, gformat='random', npoints=500, seed=None, **kwargs):
def grouping(scores, gformat="random", npoints=500, seed=None, **kwargs):
scores = numpy.asarray(scores)
scores = np.asarray(scores)
if scores.size == 0:
return scores
if(gformat == "kmeans"):
if gformat == "kmeans":
kmeans_machine = bob.learn.em.KMeansMachine(npoints, 2)
kmeans_trainer = bob.learn.em.KMeansTrainer()
bob.learn.em.train(
kmeans_trainer, kmeans_machine, scores, max_iterations=500,
convergence_threshold=0.1)
kmeans_trainer,
kmeans_machine,
scores,
max_iterations=500,
convergence_threshold=0.1,
)
scores = kmeans_machine.means
elif(gformat == "random"):
if seed is not None:
numpy.random.seed(seed)
scores_indexes = numpy.array(
numpy.random.rand(npoints) * scores.shape[0], dtype=int)
scores = scores[scores_indexes]
elif gformat == "random":
rng = default_rng(seed)
scores = rng.choice(scores, npoints, replace=False)
return scores
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment