From fe1b5191c7a59960f66de1689882c689a3648308 Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Wed, 19 Oct 2022 20:11:08 +0200 Subject: [PATCH] [test] Added a GMM loading test for legacy files. --- bob/learn/em/data/gmm_MAP.hdf5 | Bin 12016 -> 11960 bytes bob/learn/em/data/gmm_ML.hdf5 | Bin 12016 -> 11960 bytes bob/learn/em/data/gmm_ML_fitted.hdf5 | Bin 0 -> 11960 bytes bob/learn/em/data/gmm_ML_legacy.hdf5 | Bin 0 -> 9128 bytes bob/learn/em/data/stats.hdf5 | Bin 22635 -> 9552 bytes bob/learn/em/gmm.py | 2 +- bob/learn/em/test/test_gmm.py | 50 +++++++++++++++++---------- 7 files changed, 32 insertions(+), 20 deletions(-) create mode 100644 bob/learn/em/data/gmm_ML_fitted.hdf5 create mode 100644 bob/learn/em/data/gmm_ML_legacy.hdf5 diff --git a/bob/learn/em/data/gmm_MAP.hdf5 b/bob/learn/em/data/gmm_MAP.hdf5 index 8106590f1721a6c6084c63bc9637b52a5b39b24e..8d3574922e7d20ca228ef146023433466dd1e2c2 100644 GIT binary patch delta 493 zcmewmyCZgj2IG#2nwIh`42%p63=#|wAiy91W+*U&DTW;r8`Xtgus}F4ML#q)?qXxo za$tuDGctf6!VE#E83}3-`38-N2hAocFv~D9OtxkAXJnY%%WTc72$JLg5ey7Wo8K}= zGjbs`Ob%i-W4ggS@t_<h%u<8R8(Eh#PQ1V|Swm^X#1jIOI1CsSCmV9e2)O(D`X(l4 zWag!++kwoRoXVlh)WN@bB1b;|#0&hBBsA)|{#+Neoc_fAJ$t~F-iz1l85tPPDNFz1 z(-L$DxV`_T$p2sm`J$B?D|{IoivMqvm9qDCU|4xp?054~`-ZRY{J9@FIviL~ar!#T zCi?^X{d033IyiuAVqn-{gP(@TZtj<eWZ$Hqz%=;-v(V;8N`5Si0h<lg|1(aGV3%N8 zpfS0KJ#MmrrU0YCWJgVPCI^*?2j#dLAbJ@BR3{#ko4i4jgGoVY;=z5BFX&WF{-m{x o$w3PgJPJ@%D<&V*R%e=_Gx6ZF%@=f(7$<jWWiVDx_ST&W05rsWVgLXD delta 469 zcmdlH`yqCM2IGf`nwIjc42%p63=#|wAiy91W+*U&DTWIZ8`XsvSRovkA_L8hyV#ht z6qF#sj0_-%FhdY(hJZRmK7xJXL9@vU%rcBDlWm#(8CfRxGFuDlf+RUW1Oo#Tm=c(H zP;|0@V$fzS7B)sMgzm{!tY(ZKCKs}*GkLI0JUDH$0h<8h<OU51rV94Svozu+UND%% zVZf+6*^om<A~&<5q_ikiFFiL`&)wJ8H!(RQGcQ%$ZgMP#GGoK$Mvij+NfH{-7E7zs z)2BbN@7>B_Xn66OJtG5ysb)hKpO&D*!Ri@BBL9OODk~<OU*^l;kRlbOC28;Nz|i=_ z{!Z&rdxi}UAHI9!=-_bsPr*^fP4*2A>L*V;aBu+ofnkHq=7}7&?3)x6m?mD}nC!v6 zV)6_fq0L50ek_a~n;X^sGfocBW}Ey^gNx~b^2CF3oEZ>(3?C*NYN|6asDZ=+pkfy` rH)`HzWUQE+sIAWQKo=whi5>=q3!67;b23eS#Vj!Sm9E3&B!zVV4Tg6g diff --git a/bob/learn/em/data/gmm_ML.hdf5 b/bob/learn/em/data/gmm_ML.hdf5 index 5e667e2498f69cf59303fc30c675c9f6f7b1fcc3..4a5bd139a9a8e681d4da0d0fa4b39faac98f3972 100644 GIT binary patch delta 493 zcmewmyCZgj2IG#2nwIh`42%p63=#|wAiy91W+*U&DTW;r8`Xtgus}F4ML#q)?qXxo za$tuDGctf6!VE#E83}3-`38-N2hAocFv~D9OtxkAXJnY%%WTc72$JLg5ey7Wo8K}= zGjbs`Ob%i-W4ggS@t_<h%u<8R8(Eh#PQ1V|Swm^X#1jIOI1CsSCmV9e2)O(D`X(l4 zWag!++kwoRoXVlh)WN@bB1b;|#0&hBBqZuTW<2iL`s0b+36MMk0|VEe>!Oy^pV+@= z54h5M@tQp&1B1<nKo4mxL5I6e4b`&$gB@O%N}V|B%iu8g@Op23dv6DZm1o6%Hy^cc z`1;PD`;nu=fdv((ud{5jKd|3FH|L>)1H|MFHrVLR^%9xvn-mn7CVyZS+Wbh#kA*Q{ zv!VKb#>o-v5=;v;CKs{CO*YUJU^JNQsHx86pfd5G99ILxB!&Rhi3jB-Z_wmmQc#+B zaNpz$I#rWDX)R-N&;kXI0#wzC$p^L7nP%urJos$$1sx^E$z56*jMbC9b*BOVe`JCe delta 485 zcmdlH`yqCM2IGf`nwIjc42%p63=#|wAiy91W+*U&DTWIZ8`XsvSRovkA_L8hyV#ht z6qF#sj0_-%FhdY(hJZRmK7xJXL9@vU%rcBDlWm#(8CfRxGFuDlf+RUW1Oo#Tm=c(H zP;|0@V$fzS7B)sMgzm{!tY(ZKCKs}*GkLI0JUDH$0h<8h<OU51rV94Svozu+UND%% zVZf+6*^om<A~&<5q_ikiFFiL`&)wJ8H!(RQGcQ%$ZgMP#GGoK$Mvij+NfHv-e^!3H z<eL;=e*$bX1B1oV>h$#KPwacQau^z3yk^hH$XN5}p*+8qpo8OCiO(YcgB@Cq9Ijd6 z%is|A>u-dVy|)8H;}iQktw-${HavXz?vbN|!|gu>M;SNSH#n%DJn_K60b=q7o6QqB zYS}j_C@@XDz%ki_eZ}M%IzpR`l>As2J2p3}{b!sUpv^Y<p9UAx1LcVa<v24S`WQY; zHq=ySVo(E#1wh3vY;M%N&&XIYIZ<1k>47dt3KCrm3>P+U)aGQG{EAs%@+)13$w><9 E0Hq>-UH||9 diff --git a/bob/learn/em/data/gmm_ML_fitted.hdf5 b/bob/learn/em/data/gmm_ML_fitted.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..2e38a2b1ead3822d9029977eeb2894efa029edf0 GIT binary patch literal 11960 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%Lo!45r$5S05L!ed}afHD}NbO)4P!31G2GO#d! z<Rl<m1_7w~3d|6J9T18kB*@j3fq{tuW<HFDN;AA*fv{la{Lp~#10d8WIm|=A!_^TS z)f=GxTmYpV*dYu?1`uRmfTa&XXgW(!gUB~P(+h-0JP9)gJzeuKcz_LIWZ-2GVc=j$ z%gjlQFH0>d&dkpPNk}p>f?WX0O`!D8z`y|#2ImD31DyOJLSQj;^U8}dOHzyClXDV_ zi$UrMnWqR+1T_zo^%xmg!RA5;1_nk9`vt&K0vs@xIe<)HU}fL{i!(7Xft}3(4N+!B zC<AH^STUG@#RQl$Xo+Aye-}`gD8O6@XTStd=?PHo2~c1%fWi+VfS#T#Ai@j+;0S<n zuu#ybVYmT~cq|H_La_8}zy=i<WE3b2!3hIQfWighGbopVnLz@aesHB%MEU|NV?d-Y z4ye~bIR@F&GmtrhmYo-$o>*F3oSB#h&aMn4MTwbtsYMLQ`FUljMd_(|$*J)r8AYkZ z8TmOW3=9mpi52l_nI$EedFk=RC8-4vT?`E6shQ~+CB+P-1u2Oosqwj~Py?a-vc#gy z#JuEGm~xmpm^KDfSAgON5@hJ13K3;saCi3i0mmIw0hrCezyM38hI$5IrO<4^088dJ z;F3u}Jt;p)FDEszC{HgnSI^zo*EcaaBQq~m-43h^VGc}xZVpreN_Fr<8Kcx_2#kin zXb6mkz-S1JhQMeDjE2By2#kinXb6mkz-S1Jh5*e%0M;)D(~zMYaNqyW%8!?PlLG8d zKy9{ITAiLg{fT|=Rt`hMi`VQK85!#yJ(TCy5_E7rEAd(6f3QRQk;64Bd>I^~e*KM* zviEjiXn10Or}e0P!`g=r-#v14Xt?>O;3(rJdj}`=lP4ZHK*k#xHrPOI7^Ox-U^E0q zLtr!nMnhmU1n3n4#LnjkD%eA3{u~%gxEW0TLnjCj^Ev4Ag%Z&D9B88g?xDdwpF=o2 zNS!xTz!45ZW8M@qJRsg-WQ5EVf!mD?pp_et2!_r7k)B>OAO?eH%^_w%NZ8Cfgg4M6 zW_Uo-HzYJb^IaO?iA@P6$lM!fejOsm!QkfT6YR>szz`A?3ZlRYaOWQjXm~($Bm6y5 z!^3hkJQzT0H6X#j%xDE+fKwYHbucgtuJi#}*Tsl)T~~lM#O=`OX9#DM95EpPTh|o< z^)YNc%>uObTCjx>=nEiVe8l1k$kGT{sflZu4`e+M3j@qv17)ZKa8MagaR(JB0|(Ur z6%SB_GH_7f#cT`=3<^p(6bunz%=!=#ag31WAOm{)0G&^Md%+r7{-MWk#ejw!BwaB> z+sX_K4qEt~19SZfUHs}`<>L$;{OX2_?7*d~=s}sTFvA1A-3eKDM|!)v8tPw&d5H8g Hxcg}Ub%*#p literal 0 HcmV?d00001 diff --git a/bob/learn/em/data/gmm_ML_legacy.hdf5 b/bob/learn/em/data/gmm_ML_legacy.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..74269fe3d824aa877e59609097828ee922d5dbc4 GIT binary patch literal 9128 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%Lo!3t%F5S05L!ed}afHD}NbO)4P!31G2GJqfh zg9L=jAP6-dU0q0!t1ANoBLmEQ7!B3NV88-lc|fR9a)gC|hpS@%$jcERf&r9LAdC~x zbOzxuFyzMP#iu8h78hqG<`pwQ_?dYHr6nK^m=BjV0O`p^s5XS~%TqJcGhoUY7#Y|Y zz-b#yfb4~&WJX4a0E7gIgLE@6Ff#~%)eCSiFmQl{9Ka+i0|!`~iHQlUg@Zu?%x7jy zfGC5i1}kP@U=V`xL5kBWvLhH68emcbN(cM-yD&1aGH5_u(ZL2`fZPdUB|wXh4N!3l zkRSsCg9X%=Fm*d1!N_n!3Zeii&7i;x;o~Bq+88QWajAfa!^(*nN)Y~FBWa(mGN3^U zrJw;YSko1#mH?#<Sp3533H0;?D+e|}{R1u(7#J8hU>c#~)TRi<7pxe9(ag~NEvR4* z&B+c7Cfp1r|Dov{cX?0%jWB3Nf_elq-YTFn14H48FZ6s33lA7=^C8efT1(L3u2Vy` z?Ehc~hLvZ<em5VrZ}|GopZk#`q_AYzU<2{h2AlkvuDypJh&sUZzb=(JanzT=VeaAe z-um|54hI%goW9Po$^O88|J<C14jB4nBARdbXbHjfbN#t4YB~Lh{d@L+E4>%5*+c9Z zB}YSGGz3ONU^E0qLtr!ns2Ku;`t=P$5Ze?W)F?T^LIBpEg!L0&h~f_bXycecgB`y* zXhV?UhZuf!uvYd7Vf^Y~>GOsNesu#SOTYd}12hU?AqVmC;On2@AGd+^Ctx)B{j~|h zB0NBSZ%|CaX!65j0W>_I6v$Al<5>gUf5o3ZVBr9xNgvl+0Sy2sg)<xm$G9HUc?`Mn zxv7bHpmEc(#G=f^yyR4fOl~}kA77GDlv<pTpOXR_QcaJ~%P-0Wsn3f?9%qGx8*F|6 zVhZ(0LhXtK8Hio5avQ=MB}YmK;2I}<p-9^hpkF<_0a~yO?s^!#eRo10fA~W^#SkEk zUmdJ{cY{C%hlS1xS^Va~jM*WFU)?~-(l1>dfJOl<<RCsCx$V0P!y-Ij^)ZYlzkT-r q8Xiyzl02}?R}S~~-3w>{Kq;KzFxtL@1`SCR{mQ!!&<Gsd<sAUOdLD=X literal 0 HcmV?d00001 diff --git a/bob/learn/em/data/stats.hdf5 b/bob/learn/em/data/stats.hdf5 index c4a13700ec20079fdaacbd3841e8289910e9dd82..b125212ffea531c0e29450f96e1ebe116d38e5b4 100644 GIT binary patch literal 9552 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%LoL4Ybm2+I8r;W02IKpBisx&unDV1h6h8CV!V zauN_Og8<Zg1!joA4hY2%66EU2z`)1=Gap7nr5RqZKv)3~YLpydA>iTa2#)#<P=7~2 zX$N*FgGQ7fG`*syYaRv<u#JoiybK}?91LlhIjQkwsYS(^`FS7-Nk&GnPEc+FrGEwn z4v;W7FMt@}<PQ-7i=mrWUX)pqS`?q0lUQ5~QcuV{d5|Kgd7!Mv$iNCV7eat?7DN)_ z76t}p1_7{?00+!v4j>a4SQ$9L;!I3TV5J-kketiR2xUP10agqq7J$jYP6Ye;yD&1a zGgv?!Dxe5u;Gi^M;tDVZb{Z0D3=9fN*yUlOkWge`Fkpj;V5cFm#K3R^oItRv0f|CV z9|MDdGKhy8gEI{S1A~JKZZ)LHP(NL9K!XP6bC_#}M!JHzHZMLsv9!22GcgaGT^aJ? z(^3;lN{dn<T!x(d^!S|2?9`mhjQso*28Ix*MsQ)lz`#&knj267qbor96`}xE&Ovz8 zC*7U>eZUzVR8TP>?1d#iLp=kqrJyne!~mBw3=9l5;F3u}Jt;p)FDEszC{HgnSI^zo zH@GCRq*&b!q!f%-@PnyQVl)IsLtr!nMnhmU1V%$(Gz3ONU^E0qLtr!nMnhmU1V%$( z07C%QuZPipZ(d%uI%v%S0q~#z14D)774yfbSq=(P>lKPK{2V3*g*<mFZ*_Ru(!X%# z%rg!*qQdpuV+tHhzdhJ<Mq!gfbX%VAsU3?Q4hw8L9A{?W7-i{Hx#{RchqrsoKl&^e zbsRvij#@n$0;3@?8UmvsFd71*Api{l!t*&2;Q2s71$*esp96ylH-pK4=mY`o`9c94 z^Exx2HV;mUaCne9Zz_Q!9EQgH7G`)ryu-){X&}PdjWC*-f!y?>05KRmYmS^YU^DN? pVgtv<3=hamF(YKUoB<XNxM=dj#{e2W=o8YQS;LVS9sz?gJOC~Wn%e*X literal 22635 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%LoL3RX02+I8r;W02IKpBisx&unDV1h6h89<PM zK?1^M5QLhKt}Z0V)s=yPkpX5tjD~7sFkpeO0wB~VIl@A~!_^TS^&6o6h9*ab7xoZ= zid)@zr-kjdCT~2P*y!8-_zaKCVx75K=CX5Nd?@Sx=|2DIqCSRCYx#~=+?p*R^6Adn zJy#!QdjCH()A+o_oer++#r>KE27kXLoR@6%<zwKs6i|nl%gDga08ZOr0+d2Ez!Hp1 z5CI4Y5(f*z(=SY%nLz-oUx0&wfdj0>0Zg(oaDc^`m>~vnGbn)hj8Kh?V0lJ{l+?7G z#FA77POvVJ84_T1%uLJ`tPnAf4iE+#F9fv(Bw)DZvM&RJ2SgQ>$Y4Kz7f_feK+XFh z1!dr%AiiZ_xM7Dw0aO@PK5XEJ3gDo?o@8KPP+-QP04fZLB?bluHK+g%$^t6hVF+d5 zpumZbfq~(I7Y+p^3WLf!P{{}@=U_B90durr2_LEs6we?gjE02|NNg03hQMeDjE2By z2#kinXb6mkz-R~z{SW}Py}@lgFahd!Kw7%!{Sj#Q8{Qv*@tGMQ{Rcq>duUI=fx(2E z!Q?-*Jxzsv$pPpH0?36RYr#!Y%zg>1XFyq;;A%(1`V+WJH^b46200j{2kdANV-$~u zz-S1JhQMeDjE2By2#kinXb2385FoZ44Qj-KJPaG3h4E=PuFoL@vJlj?hhT716tn$I z-En<f?Pwl|TX2%`INH&$Z~{9Xr|waS(GVC7fzc2c4S~@R7!85Z5Eu=Cp%wyzxE&n; zU&1p8+R=p8btsIsqla43A9e3&2#kinXb6mkz-S1JhQMeDP!s}~?P$o#WJYGlNGgP6 zV1TVpfYGpVX&4_y({7yyWSu%Q6Uz)Eh>t-|16d2g@Zny}@oLCA6^ItXB(CvjO~P8i zDn5c-17<-8kRw4Xm@tfnI~XK6ibq3WGz3ONU^E0qLtr!nMnhmU1V%$(Gz5lZ2w?W7 zApH;Yb@H%&0jwVd<I}GH1X+L2%)~qcx&j_#B?yDNBAERqs46POLSB4&Vrg-4W@273 z*u*?A$-oc-rSsx5^9o8!Ad(CW48^6n0Tm1k3_1Df@j02<Ihh&x`6)2*3dqC?NGsa( z1xSb@jH~}ANWfG-G>3o!7NiA6g98~PFp5V*U^E0qLtr!nMnhmU1V%$(Gz5l42w=AR z2mL(C18+#kgPN$IX$cU9H&ZeD1CV|UL<?aOS38>6d6W|%*C1H~awLcg6Nb@n2ZJO> z@n{H)hQMeDjE2By2#kinXb6mkz-S1JhQN>vfr?wnNev93^Lv?5n5#GJN&07DdfL%3 uYp+vb+k9pQi=S+Y6}Ns@fBv-Ts$(`E1B3KQ)}0l%gcDL2m>C%Q8H4~FE_kv4 diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index 0583f1b..19680f7 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -598,7 +598,7 @@ class GMMMachine(BaseEstimator): return self.means.shape @classmethod - def from_hdf5(cls, hdf5, ubm=None): + def from_hdf5(cls, hdf5: Union[str, HDF5File], ubm: "GMMMachine" = None): """Creates a new GMMMachine object from an `HDF5File` object.""" if isinstance(hdf5, str): hdf5 = HDF5File(hdf5, "r") diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 81793ef..cb9ebc2 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -294,9 +294,23 @@ def test_GMMMachine(): ) +def test_GMMMachine_legacy_loading(): + """Tests that old GMMMachine checkpoints are loaded correctly.""" + reference_file = resource_filename("bob.learn.em", "data/gmm_ML.hdf5") + legacy_gmm_file = resource_filename( + "bob.learn.em", "data/gmm_ML_legacy.hdf5" + ) + gmm = GMMMachine.from_hdf5(legacy_gmm_file) + assert isinstance(gmm, GMMMachine) + assert isinstance(gmm.n_gaussians, np.int64), type(gmm.n_gaussians) + assert isinstance(gmm.weights, np.ndarray), type(gmm.weights) + reference = GMMMachine.from_hdf5(reference_file) + np.testing.assert_allclose(gmm.variances, reference.variances) + assert gmm.is_similar_to(reference) + + def test_GMMMachine_stats(): """Tests a GMMMachine (statistics)""" - arrayset = load_array( resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5") ) @@ -802,7 +816,9 @@ def test_gmm_ML_1(): resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5") ) gmm_ref = GMMMachine.from_hdf5( - HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r") + HDF5File( + resource_filename("bob.learn.em", "data/gmm_ML_fitted.hdf5"), "r" + ) ) for transform in (to_numpy, to_dask_array): @@ -823,8 +839,6 @@ def test_gmm_ML_1(): gmm.update_means = True gmm.update_variances = True gmm.update_weights = True - # Generate reference - # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "w")) gmm = gmm.fit(ar) @@ -911,16 +925,6 @@ def test_gmm_MAP_1(): gmmprior = GMMMachine.from_hdf5( HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r") ) - gmm = GMMMachine.from_hdf5( - HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r"), - ubm=gmmprior, - ) - gmm.update_means = True - gmm.update_variances = False - gmm.update_weights = False - - # Generate reference - # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "w")) gmm_ref = GMMMachine.from_hdf5( HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "r") @@ -928,13 +932,21 @@ def test_gmm_MAP_1(): for transform in (to_numpy, to_dask_array): ar = transform(ar) - gmm = gmm.fit(ar) - - np.testing.assert_almost_equal(gmm.means, gmm_ref.means, decimal=3) + gmm = GMMMachine.from_hdf5( + HDF5File( + resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r" + ), + ubm=gmmprior, + ) + gmm.update_means = True + gmm.update_variances = False + gmm.update_weights = False + gmm.fit(ar) + np.testing.assert_almost_equal(gmm.means, gmm_ref.means, decimal=7) np.testing.assert_almost_equal( - gmm.variances, gmm_ref.variances, decimal=3 + gmm.variances, gmm_ref.variances, decimal=7 ) - np.testing.assert_almost_equal(gmm.weights, gmm_ref.weights, decimal=3) + np.testing.assert_almost_equal(gmm.weights, gmm_ref.weights, decimal=7) def test_gmm_MAP_2(): -- GitLab