Skip to content
Snippets Groups Projects
Commit de9a8fec authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[plotting] Fix the random grouping option

parent 43186bae
No related branches found
Tags v2.0.3
No related merge requests found
Pipeline #49345 passed
#!/usr/bin/env python #!/usr/bin/env python
import numpy
import bob.learn.em import bob.learn.em
import numpy as np
from numpy.random import default_rng
def grouping(scores, gformat='random', npoints=500, seed=None, **kwargs): def grouping(scores, gformat="random", npoints=500, seed=None, **kwargs):
scores = numpy.asarray(scores) scores = np.asarray(scores)
if scores.size == 0: if scores.size == 0:
return scores return scores
if(gformat == "kmeans"): if gformat == "kmeans":
kmeans_machine = bob.learn.em.KMeansMachine(npoints, 2) kmeans_machine = bob.learn.em.KMeansMachine(npoints, 2)
kmeans_trainer = bob.learn.em.KMeansTrainer() kmeans_trainer = bob.learn.em.KMeansTrainer()
bob.learn.em.train( bob.learn.em.train(
kmeans_trainer, kmeans_machine, scores, max_iterations=500, kmeans_trainer,
convergence_threshold=0.1) kmeans_machine,
scores,
max_iterations=500,
convergence_threshold=0.1,
)
scores = kmeans_machine.means scores = kmeans_machine.means
elif(gformat == "random"): elif gformat == "random":
if seed is not None: rng = default_rng(seed)
numpy.random.seed(seed) scores = rng.choice(scores, npoints, replace=False)
scores_indexes = numpy.array(
numpy.random.rand(npoints) * scores.shape[0], dtype=int)
scores = scores[scores_indexes]
return scores return scores
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment