-
André Anjos authored
test_kmeans.py 2.21 KiB
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
# Thu Feb 16 17:57:10 2012 +0200
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
"""Tests the KMeans machine
"""
import os, sys
import unittest
import bob
import numpy, math
import tempfile
def equals(x, y, epsilon):
    """Tell whether *x* and *y* are within *epsilon* of each other.

    The comparison is strict: a difference exactly equal to *epsilon*
    is NOT considered equal.  Works element-wise when given arrays.
    """
    difference = abs(x - y)
    return difference < epsilon
class KMeansMachineTest(unittest.TestCase):
    """Performs various KMeans machine-related tests."""

    def test01_KMeansMachine(self):
        """Exercise construction, accessors, distances, I/O, resize and comparisons."""
        # Two 3-dimensional means, plus a probe vector that will replace mean 0
        means = numpy.array([[3, 70, 0], [4, 72, 0]], 'float64')
        mean = numpy.array([3, 70, 1], 'float64')

        # Initializes a KMeansMachine with 2 means of dimension 3
        km = bob.machine.KMeansMachine(2, 3)
        km.means = means
        self.assertEqual(km.dim_c, 2)
        self.assertEqual(km.dim_d, 3)

        # Sets and gets
        self.assertTrue(numpy.array_equal(km.means, means))
        self.assertTrue(numpy.array_equal(km.get_mean(0), means[0, :]))
        self.assertTrue(numpy.array_equal(km.get_mean(1), means[1, :]))
        km.set_mean(0, mean)
        self.assertTrue(numpy.array_equal(km.get_mean(0), mean))

        # Distance and closest mean.  The expected value 6 for mean 1 matches
        # the squared Euclidean distance (3-4)^2 + (70-72)^2 + (1-0)^2.
        eps = 1e-10
        self.assertTrue(equals(km.get_distance_from_mean(mean, 0), 0, eps))
        self.assertTrue(equals(km.get_distance_from_mean(mean, 1), 6, eps))
        (index, dist) = km.get_closest_mean(mean)
        self.assertEqual(index, 0)
        self.assertTrue(equals(dist, 0, eps))
        self.assertTrue(equals(km.get_min_distance(mean), 0, eps))

        # Loads and saves.  mkstemp() returns an already-open OS descriptor
        # together with the path; close it so it does not leak, and always
        # remove the temporary file, even if an assertion below fails.
        fd, filename = tempfile.mkstemp(".hdf5")
        os.close(fd)
        filename = str(filename)
        try:
            km.save(bob.io.HDF5File(filename, 'w'))
            km_loaded = bob.machine.KMeansMachine(bob.io.HDF5File(filename))
            self.assertTrue(km == km_loaded)
        finally:
            os.unlink(filename)

        # Resize
        km.resize(4, 5)
        self.assertEqual(km.dim_c, 4)
        self.assertEqual(km.dim_d, 5)

        # Copy constructor and comparison operators.  The operators themselves
        # are under test here, so they are invoked explicitly.
        km.resize(2, 3)
        km2 = bob.machine.KMeansMachine(km)
        self.assertTrue(km2 == km)
        self.assertFalse(km2 != km)
        self.assertTrue(km2.is_similar_to(km))
        means2 = numpy.array([[3, 70, 0], [4, 72, 2]], 'float64')
        km2.means = means2
        self.assertFalse(km2 == km)
        self.assertTrue(km2 != km)
        self.assertFalse(km2.is_similar_to(km))