From 615a5061e58beace95d21d199e9b544a7f211f2b Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Mon, 29 Nov 2021 10:43:47 +0100
Subject: [PATCH] Fix GMM_ML test and legacy file loading

---
 bob/learn/em/data/gmm_ML.hdf5 | Bin 9128 -> 12920 bytes
 bob/learn/em/mixture/gmm.py   |  25 +++++++++++--------------
 bob/learn/em/test/test_gmm.py |  19 +++++++++----------
 3 files changed, 20 insertions(+), 24 deletions(-)

diff --git a/bob/learn/em/data/gmm_ML.hdf5 b/bob/learn/em/data/gmm_ML.hdf5
index 74269fe3d824aa877e59609097828ee922d5dbc4..0326c186f11ad38387b3a4c4b1e3cc7a66d32b80 100644
GIT binary patch
literal 12920
zcmeD5aB<`1lHy_j0S*oZ76t(j3y%LoL4^@S2+I8r;W02IKpBisx&unDV1h6h89<PM
zK?1^M5QLhKt}Z0V)s=yPi2-IljD~7sFkpeO6d)9Xh6aSu0HH?7VIBe=u8sj9FD5_}
z+654Yfq|hx3&H@UGLRH{I(wiEk++b9(4>-#3~US_b0Gu+10w?m*a)x~0|Nsyg9KPg
zl93UdvO&2AA_q~&%D@4ZXJP_V3>?s$zzj18tP5-;m=HoTUw`hBL<R;}P60~~I1%jU
z@50Ey$^ecG1_p)~1j6+I)Ws*@<pD$hJzg>(!VEv8pbQ)oBvcs~6qs=+fC|IPi3(Pz
zz#yY&m#ze$Spbw05uP6G>56c82!Jz?00#p!hd6*ehsa+{;B1G;UkcE0fTb9)at6$B
zm;jL&p(H39fx;BzeNX_pJNx^9XjriXW`oN)kO~F{Lp=j91Dc%~biq6Z2AkxZ#NuKF
z^`!hHy`0p<qCCCST)o`PijvZzRK4`vTs?PRU*E*!jLf`Lbvq>UVCLlJKn;RY4g65X
zC^Z@aqaiRF0;3@?8UmvsFd71*Aut*OqaiRF0;3@?8UmvsK=lv+_2*&Z0ieD)2!kvk
zJU&D2IH(46sBUPDSHZ@kBS3+}z`$Sur6bVBLvvFT^NJZ57|IfhG86NXQ$cKyQW!tJ
zB%>&`I3qtN1tJ3*Z-E91NGVk?vGFIUU=JT(GvQ`1`41W(VqjpvJ<etUbt<88w!t-y
zO*lMA9dC3XG8_iycq3+bK!#5l86hbWRt~^u*t`en>BR#YPK5G%1y#~LxrShd2PAz%
zLIX7KqX8a|mS6(S2{AB$%1nqH2ZNiVPp~Tk14BqqD2M_pz@2|0py2`04Rfy$G#7$I
zA@e{ZB|IXBM0mjBX{6Agd=E<FApe4@gg+}kUh+)}us;D30M&m>tJBk`Ke6xK%3)}D
z@tQp&BV)~@hw}Vdf)0*nB|eM%4|Zrda=2!NFM~teufGve_TCN*jZf_Fv>vr**zoY-
zyGM==4!8dl9A(^O-{7Er^27rN$ovk&1{;VCxM+y%C^;GeqaiRF0;3@?8UmvsK!Xr~
ztovbu_j5bg@LS>mmDYfCBcMD?3f3!>kj0ck7MTI%XvjesIH(O!@d8aK0|y0LH^?D_
zLjjG1>DRA_8dB}(+|-i9__WNN)cCU0qT<Z_JZSiX#mkE_OHzyC!Ane_;(77uiKWHG
znUK*H28NQN#LT?ZB8KGryt34y^whlMRK(B<0|P^DVnuvfW=TnAUV40SNooOPw1t78
zJT)^tqokOjv>+w1BsCs3Tm$37hG}5RVd5}t45+REclE$794x(j$hulaoa<@@^uf-B
zE%$+NM#&Kq0+{6yWMMTUWH}uJsNV+)3|ut%<3Q2S@(D_TEMZ{4vOW*CG8C$S#uTKU
zV}_2tFfcIa;kOXx`V2$->R_W?5s>9`_!SQW*?~(})zBz`QqZu(98aTfx*E9fm<<gG
JC^bmJ0{|tZAMyYI

literal 9128
zcmeD5aB<`1lHy_j0S*oZ76t(j3y%Lo!3t%F5S05L!ed}afHD}NbO)4P!31G2GJqfh
zg9L=jAP6-dU0q0!t1ANoBLmEQ7!B3NV88-lc|fR9a)gC|hpS@%$jcERf&r9LAdC~x
zbOzxuFyzMP#iu8h78hqG<`pwQ_?dYHr6nK^m=BjV0O`p^s5XS~%TqJcGhoUY7#Y|Y
zz-b#yfb4~&WJX4a0E7gIgLE@6Ff#~%)eCSiFmQl{9Ka+i0|!`~iHQlUg@Zu?%x7jy
zfGC5i1}kP@U=V`xL5kBWvLhH68emcbN(cM-yD&1aGH5_u(ZL2`fZPdUB|wXh4N!3l
zkRSsCg9X%=Fm*d1!N_n!3Zeii&7i;x;o~Bq+88QWajAfa!^(*nN)Y~FBWa(mGN3^U
zrJw;YSko1#mH?#<Sp3533H0;?D+e|}{R1u(7#J8hU>c#~)TRi<7pxe9(ag~NEvR4*
z&B+c7Cfp1r|Dov{cX?0%jWB3Nf_elq-YTFn14H48FZ6s33lA7=^C8efT1(L3u2Vy`
z?Ehc~hLvZ<em5VrZ}|GopZk#`q_AYzU<2{h2AlkvuDypJh&sUZzb=(JanzT=VeaAe
z-um|54hI%goW9Po$^O88|J<C14jB4nBARdbXbHjfbN#t4YB~Lh{d@L+E4>%5*+c9Z
zB}YSGGz3ONU^E0qLtr!ns2Ku;`t=P$5Ze?W)F?T^LIBpEg!L0&h~f_bXycecgB`y*
zXhV?UhZuf!uvYd7Vf^Y~>GOsNesu#SOTYd}12hU?AqVmC;On2@AGd+^Ctx)B{j~|h
zB0NBSZ%|CaX!65j0W>_I6v$Al<5>gUf5o3ZVBr9xNgvl+0Sy2sg)<xm$G9HUc?`Mn
zxv7bHpmEc(#G=f^yyR4fOl~}kA77GDlv<pTpOXR_QcaJ~%P-0Wsn3f?9%qGx8*F|6
zVhZ(0LhXtK8Hio5avQ=MB}YmK;2I}<p-9^hpkF<_0a~yO?s^!#eRo10fA~W^#SkEk
zUmdJ{cY{C%hlS1xS^Va~jM*WFU)?~-(l1>dfJOl<<RCsCx$V0P!y-Ij^)ZYlzkT-r
q8Xiyzl02}?R}S~~-3w>{Kq;KzFxtL@1`SCR{mQ!!&<Gsd<sAUOdLD=X

diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py
index d64b6f1..4370d5b 100644
--- a/bob/learn/em/mixture/gmm.py
+++ b/bob/learn/em/mixture/gmm.py
@@ -439,7 +439,7 @@ class GMMMachine(BaseEstimator):
         try:
             version_major, version_minor = hdf5.get("meta_file_version")[()].split(".")
             logger.debug(
-                f"Reading a GMMStats HDF5 file of version {version_major}.{version_minor}"
+                f"Reading a GMMMachine HDF5 file of version {version_major}.{version_minor}"
             )
         except (TypeError, RuntimeError):
             version_major, version_minor = 0, 0
@@ -460,12 +460,12 @@ class GMMMachine(BaseEstimator):
                 update_variances=hdf5["update_variances"][()],
                 update_weights=hdf5["update_weights"][()],
             )
-            gaussians = hdf5["gaussians"]
-            self.means = gaussians["means"][()]
-            self.variances = gaussians["variances"][()]
-            self.variance_thresholds = gaussians["variance_thresholds"][()]
+            gaussians_group = hdf5["gaussians"]
+            self.means = gaussians_group["means"][()]
+            self.variances = gaussians_group["variances"][()]
+            self.variance_thresholds = gaussians_group["variance_thresholds"][()]
         else:  # Legacy file version
-            logger.info("Loading a legacy HDF5 stats file.")
+            logger.info("Loading a legacy HDF5 machine file.")
             n_gaussians = int(hdf5["m_n_gaussians"][()])
             g_means = []
             g_variances = []
@@ -475,14 +475,11 @@ class GMMMachine(BaseEstimator):
                 g_means.append(gaussian_group["m_mean"][()])
                 g_variances.append(gaussian_group["m_variance"][()])
                 g_variance_thresholds.append(gaussian_group["m_variance_thresholds"][()])
-            self = cls(
-                n_gaussians=n_gaussians,
-                ubm=ubm,
-                weights=hdf5["m_weights"][()],
-            )
-            self.means = np.array(g_means)
-            self.variances = np.array(g_variances)
-            self.variance_thresholds = np.array(g_variance_thresholds)
+            weights = hdf5["m_weights"][()].reshape(n_gaussians)
+            self = cls(n_gaussians=n_gaussians, ubm=ubm, weights=weights)
+            self.means = np.array(g_means).reshape(n_gaussians,-1)
+            self.variances = np.array(g_variances).reshape(n_gaussians,-1)
+            self.variance_thresholds = np.array(g_variance_thresholds).reshape(n_gaussians,-1)
         return self
 
     def save(self, hdf5):
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 40d535a..d69415d 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -547,7 +547,7 @@ def test_map_transformer():
 ## Tests from `test_em.py`
 
 def loadGMM():
-    gmm = GMMMachine(2)
+    gmm = GMMMachine(n_gaussians=2)
 
     gmm.weights = bob.io.base.load(datafile("gmm.init_weights.hdf5", __name__, path="../data/"))
     gmm.means = bob.io.base.load(datafile("gmm.init_means.hdf5", __name__, path="../data/"))
@@ -578,13 +578,11 @@ def test_gmm_ML_1():
     gmm.update_weights = True
     gmm = gmm.fit(ar)
 
-    #config = HDF5File(datafile("gmm_ML.hdf5", __name__), "w")
-    #gmm.save(config)
+    # Uncomment the next line to regenerate the reference file:
+    # gmm.save(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data"), "w"))
 
-    gmm_ref = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data"), "r")) # TODO update the ref file(s)
-    gmm_ref_32bit_debug = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML_32bit_debug.hdf5", __name__, path="../data/"), "r"))
-    gmm_ref_32bit_release = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML_32bit_release.hdf5", __name__, path="../data/"), "r"))
-    assert (gmm == gmm_ref)  # or (gmm == gmm_ref_32bit_release) or (gmm == gmm_ref_32bit_debug)
+    gmm_ref = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data"), "r"))
+    assert gmm == gmm_ref
 
 
 def test_gmm_ML_2():
@@ -592,7 +590,7 @@ def test_gmm_ML_2():
     ar = bob.io.base.load(datafile("dataNormalized.hdf5", __name__, path="../data/"))
 
     # Initialize GMMMachine
-    gmm = GMMMachine(5, 45)
+    gmm = GMMMachine(n_gaussians=5)
     gmm.means = bob.io.base.load(datafile("meansAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")
     gmm.variances = bob.io.base.load(datafile("variancesAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")
     gmm.weights = np.exp(bob.io.base.load(datafile("weightsAfterKMeans.hdf5", __name__, path="../data/")).astype("float64"))
@@ -628,7 +626,7 @@ def test_gmm_ML_parallel():
     ar = da.array(bob.io.base.load(datafile("dataNormalized.hdf5", __name__, path="../data/")))
 
     # Initialize GMMMachine
-    gmm = GMMMachine(5, 45)
+    gmm = GMMMachine(n_gaussians=5)
     gmm.means = bob.io.base.load(datafile("meansAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")
     gmm.variances = bob.io.base.load(datafile("variancesAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")
     gmm.weights = np.exp(bob.io.base.load(datafile("weightsAfterKMeans.hdf5", __name__, path="../data/")).astype("float64"))
@@ -703,6 +701,8 @@ def test_gmm_MAP_2():
     gmm.variances = variances
     gmm.weights = weights
 
+    gmm = gmm.fit(data)
+
     gmm_adapted = GMMMachine(
         n_gaussians=2,
         trainer="map",
@@ -717,7 +717,6 @@ def test_gmm_MAP_2():
     gmm_adapted.variances = variances
     gmm_adapted.weights = weights
 
-    gmm = gmm.fit(data)
 
     gmm_adapted = gmm_adapted.fit(data)
 
-- 
GitLab