From 615a5061e58beace95d21d199e9b544a7f211f2b Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Mon, 29 Nov 2021 10:43:47 +0100 Subject: [PATCH] fix GMM_ML test and legacy files loading --- bob/learn/em/data/gmm_ML.hdf5 | Bin 9128 -> 12920 bytes bob/learn/em/mixture/gmm.py | 25 +++++++++++-------------- bob/learn/em/test/test_gmm.py | 19 +++++++++---------- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/bob/learn/em/data/gmm_ML.hdf5 b/bob/learn/em/data/gmm_ML.hdf5 index 74269fe3d824aa877e59609097828ee922d5dbc4..0326c186f11ad38387b3a4c4b1e3cc7a66d32b80 100644 GIT binary patch literal 12920 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%LoL4^@S2+I8r;W02IKpBisx&unDV1h6h89<PM zK?1^M5QLhKt}Z0V)s=yPi2-IljD~7sFkpeO6d)9Xh6aSu0HH?7VIBe=u8sj9FD5_} z+654Yfq|hx3&H@UGLRH{I(wiEk++b9(4>-#3~US_b0Gu+10w?m*a)x~0|Nsyg9KPg zl93UdvO&2AA_q~&%D@4ZXJP_V3>?s$zzj18tP5-;m=HoTUw`hBL<R;}P60~~I1%jU z@50Ey$^ecG1_p)~1j6+I)Ws*@<pD$hJzg>(!VEv8pbQ)oBvcs~6qs=+fC|IPi3(Pz zz#yY&m#ze$Spbw05uP6G>56c82!Jz?00#p!hd6*ehsa+{;B1G;UkcE0fTb9)at6$B zm;jL&p(H39fx;BzeNX_pJNx^9XjriXW`oN)kO~F{Lp=j91Dc%~biq6Z2AkxZ#NuKF z^`!hHy`0p<qCCCST)o`PijvZzRK4`vTs?PRU*E*!jLf`Lbvq>UVCLlJKn;RY4g65X zC^Z@aqaiRF0;3@?8UmvsFd71*Aut*OqaiRF0;3@?8UmvsK=lv+_2*&Z0ieD)2!kvk zJU&D2IH(46sBUPDSHZ@kBS3+}z`$Sur6bVBLvvFT^NJZ57|IfhG86NXQ$cKyQW!tJ zB%>&`I3qtN1tJ3*Z-E91NGVk?vGFIUU=JT(GvQ`1`41W(VqjpvJ<etUbt<88w!t-y zO*lMA9dC3XG8_iycq3+bK!#5l86hbWRt~^u*t`en>BR#YPK5G%1y#~LxrShd2PAz% zLIX7KqX8a|mS6(S2{AB$%1nqH2ZNiVPp~Tk14BqqD2M_pz@2|0py2`04Rfy$G#7$I zA@e{ZB|IXBM0mjBX{6Agd=E<FApe4@gg+}kUh+)}us;D30M&m>tJBk`Ke6xK%3)}D z@tQp&BV)~@hw}Vdf)0*nB|eM%4|Zrda=2!NFM~teufGve_TCN*jZf_Fv>vr**zoY- zyGM==4!8dl9A(^O-{7Er^27rN$ovk&1{;VCxM+y%C^;GeqaiRF0;3@?8UmvsK!Xr~ ztovbu_j5bg@LS>mmDYfCBcMD?3f3!>kj0ck7MTI%XvjesIH(O!@d8aK0|y0LH^?D_ zLjjG1>DRA_8dB}(+|-i9__WNN)cCU0qT<Z_JZSiX#mkE_OHzyC!Ane_;(77uiKWHG znUK*H28NQN#LT?ZB8KGryt34y^whlMRK(B<0|P^DVnuvfW=TnAUV40SNooOPw1t78 zJT)^tqokOjv>+w1BsCs3Tm$37hG}5RVd5}t45+REclE$794x(j$hulaoa<@@^uf-B zE%$+NM#&Kq0+{6yWMMTUWH}uJsNV+)3|ut%<3Q2S@(D_TEMZ{4vOW*CG8C$S#uTKU zV}_2tFfcIa;kOXx`V2$->R_W?5s>9`_!SQW*?~(})zBz`QqZu(98aTfx*E9fm<<gG JC^bmJ0{|tZAMyYI literal 9128 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%Lo!3t%F5S05L!ed}afHD}NbO)4P!31G2GJqfh zg9L=jAP6-dU0q0!t1ANoBLmEQ7!B3NV88-lc|fR9a)gC|hpS@%$jcERf&r9LAdC~x zbOzxuFyzMP#iu8h78hqG<`pwQ_?dYHr6nK^m=BjV0O`p^s5XS~%TqJcGhoUY7#Y|Y zz-b#yfb4~&WJX4a0E7gIgLE@6Ff#~%)eCSiFmQl{9Ka+i0|!`~iHQlUg@Zu?%x7jy zfGC5i1}kP@U=V`xL5kBWvLhH68emcbN(cM-yD&1aGH5_u(ZL2`fZPdUB|wXh4N!3l zkRSsCg9X%=Fm*d1!N_n!3Zeii&7i;x;o~Bq+88QWajAfa!^(*nN)Y~FBWa(mGN3^U zrJw;YSko1#mH?#<Sp3533H0;?D+e|}{R1u(7#J8hU>c#~)TRi<7pxe9(ag~NEvR4* z&B+c7Cfp1r|Dov{cX?0%jWB3Nf_elq-YTFn14H48FZ6s33lA7=^C8efT1(L3u2Vy` z?Ehc~hLvZ<em5VrZ}|GopZk#`q_AYzU<2{h2AlkvuDypJh&sUZzb=(JanzT=VeaAe z-um|54hI%goW9Po$^O88|J<C14jB4nBARdbXbHjfbN#t4YB~Lh{d@L+E4>%5*+c9Z zB}YSGGz3ONU^E0qLtr!ns2Ku;`t=P$5Ze?W)F?T^LIBpEg!L0&h~f_bXycecgB`y* zXhV?UhZuf!uvYd7Vf^Y~>GOsNesu#SOTYd}12hU?AqVmC;On2@AGd+^Ctx)B{j~|h zB0NBSZ%|CaX!65j0W>_I6v$Al<5>gUf5o3ZVBr9xNgvl+0Sy2sg)<xm$G9HUc?`Mn zxv7bHpmEc(#G=f^yyR4fOl~}kA77GDlv<pTpOXR_QcaJ~%P-0Wsn3f?9%qGx8*F|6 zVhZ(0LhXtK8Hio5avQ=MB}YmK;2I}<p-9^hpkF<_0a~yO?s^!#eRo10fA~W^#SkEk zUmdJ{cY{C%hlS1xS^Va~jM*WFU)?~-(l1>dfJOl<<RCsCx$V0P!y-Ij^)ZYlzkT-r q8Xiyzl02}?R}S~~-3w>{Kq;KzFxtL@1`SCR{mQ!!&<Gsd<sAUOdLD=X diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py index d64b6f1..4370d5b 100644 --- a/bob/learn/em/mixture/gmm.py +++ b/bob/learn/em/mixture/gmm.py @@ -439,7 +439,7 @@ class GMMMachine(BaseEstimator): try: version_major, version_minor = hdf5.get("meta_file_version")[()].split(".") logger.debug( - f"Reading a GMMStats HDF5 file of version {version_major}.{version_minor}" + f"Reading a GMMMachine HDF5 file of version {version_major}.{version_minor}" ) except (TypeError, RuntimeError): version_major, version_minor = 0, 0 @@ -460,12 +460,12 @@ class GMMMachine(BaseEstimator): update_variances=hdf5["update_variances"][()], update_weights=hdf5["update_weights"][()], ) - gaussians = hdf5["gaussians"] - self.means = gaussians["means"][()] - self.variances = gaussians["variances"][()] - self.variance_thresholds = gaussians["variance_thresholds"][()] + gaussians_group = hdf5["gaussians"] + self.means = gaussians_group["means"][()] + self.variances = gaussians_group["variances"][()] + self.variance_thresholds = gaussians_group["variance_thresholds"][()] else: # Legacy file version - logger.info("Loading a legacy HDF5 stats file.") + logger.info("Loading a legacy HDF5 machine file.") n_gaussians = int(hdf5["m_n_gaussians"][()]) g_means = [] g_variances = [] @@ -475,14 +475,11 @@ class GMMMachine(BaseEstimator): g_means.append(gaussian_group["m_mean"][()]) g_variances.append(gaussian_group["m_variance"][()]) g_variance_thresholds.append(gaussian_group["m_variance_thresholds"][()]) - self = cls( - n_gaussians=n_gaussians, - ubm=ubm, - weights=hdf5["m_weights"][()], - ) - self.means = np.array(g_means) - self.variances = np.array(g_variances) - self.variance_thresholds = np.array(g_variance_thresholds) + weights = hdf5["m_weights"][()].reshape(n_gaussians) + self = cls(n_gaussians=n_gaussians, ubm=ubm, weights=weights) + self.means = np.array(g_means).reshape(n_gaussians,-1) + self.variances = np.array(g_variances).reshape(n_gaussians,-1) + self.variance_thresholds = np.array(g_variance_thresholds).reshape(n_gaussians,-1) return self def save(self, hdf5): diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 40d535a..d69415d 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -547,7 +547,7 @@ def test_map_transformer(): ## Tests from `test_em.py` def loadGMM(): - gmm = GMMMachine(2) + gmm = GMMMachine(n_gaussians=2) gmm.weights = bob.io.base.load(datafile("gmm.init_weights.hdf5", __name__, path="../data/")) gmm.means = bob.io.base.load(datafile("gmm.init_means.hdf5", __name__, path="../data/")) @@ -578,13 +578,11 @@ def test_gmm_ML_1(): gmm.update_weights = True gmm = gmm.fit(ar) - #config = HDF5File(datafile("gmm_ML.hdf5", __name__), "w") - #gmm.save(config) + # Generate reference + # gmm.save(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data"), "w")) - gmm_ref = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data"), "r")) # TODO update the ref file(s) - gmm_ref_32bit_debug = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML_32bit_debug.hdf5", __name__, path="../data/"), "r")) - gmm_ref_32bit_release = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML_32bit_release.hdf5", __name__, path="../data/"), "r")) - assert (gmm == gmm_ref) # or (gmm == gmm_ref_32bit_release) or (gmm == gmm_ref_32bit_debug) + gmm_ref = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data"), "r")) + assert gmm == gmm_ref def test_gmm_ML_2(): @@ -592,7 +590,7 @@ def test_gmm_ML_2(): ar = bob.io.base.load(datafile("dataNormalized.hdf5", __name__, path="../data/")) # Initialize GMMMachine - gmm = GMMMachine(5, 45) + gmm = GMMMachine(n_gaussians=5) gmm.means = bob.io.base.load(datafile("meansAfterKMeans.hdf5", __name__, path="../data/")).astype("float64") gmm.variances = bob.io.base.load(datafile("variancesAfterKMeans.hdf5", __name__, path="../data/")).astype("float64") gmm.weights = np.exp(bob.io.base.load(datafile("weightsAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")) @@ -628,7 +626,7 @@ def test_gmm_ML_parallel(): ar = da.array(bob.io.base.load(datafile("dataNormalized.hdf5", __name__, path="../data/"))) # Initialize GMMMachine - gmm = GMMMachine(5, 45) + gmm = GMMMachine(n_gaussians=5) gmm.means = bob.io.base.load(datafile("meansAfterKMeans.hdf5", __name__, path="../data/")).astype("float64") gmm.variances = bob.io.base.load(datafile("variancesAfterKMeans.hdf5", __name__, path="../data/")).astype("float64") gmm.weights = np.exp(bob.io.base.load(datafile("weightsAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")) @@ -703,6 +701,8 @@ def test_gmm_MAP_2(): gmm.variances = variances gmm.weights = weights + gmm = gmm.fit(data) + gmm_adapted = GMMMachine( n_gaussians=2, trainer="map", @@ -717,7 +717,6 @@ def test_gmm_MAP_2(): gmm_adapted.variances = variances gmm_adapted.weights = weights - gmm = gmm.fit(data) gmm_adapted = gmm_adapted.fit(data) -- GitLab