From 74b56e509f2cb4e0a288804c9b9789af648e19ac Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Mon, 21 Feb 2022 17:29:08 +0100 Subject: [PATCH] Support h5py v3; add tests for loaded objects. --- bob/learn/em/data/gmm_MAP.hdf5 | Bin 12920 -> 12016 bytes bob/learn/em/data/gmm_ML.hdf5 | Bin 12920 -> 12016 bytes bob/learn/em/mixture/gmm.py | 58 ++++++++++++++++----------------- bob/learn/em/test/test_gmm.py | 28 +++++++++++----- 4 files changed, 49 insertions(+), 37 deletions(-) diff --git a/bob/learn/em/data/gmm_MAP.hdf5 b/bob/learn/em/data/gmm_MAP.hdf5 index 0f57a2bde913bba76c2d9ab5a07cf81b07346d0f..8106590f1721a6c6084c63bc9637b52a5b39b24e 100644 GIT binary patch delta 801 zcmey7@*#GD2Ga+<iCPv4tPG3{3=9$s5Fo%H0A?sKgDHj!3=^BwBpFyCJeVp2O$a|= z<H7fgLK`3)h6PYsL1|)?9G8GPgd4#=@u1nn0ts0j1`h@n1_lO323`gc1`dX_%$(Ht zvecsD%=|ot$#YqyO+gL;X#(5L0b+qvGccg*EHBC|NiB*`&PgmT2J4iD>C^>jg6d=f zu^9v=9u#F}WMG}#&E_u&^CnaUSOJJwF!A8h$p$POObQH>CrYX@PMDm?A})pS0Z0u4 zg9U_Q5Kx?WP;T-D77nHl%o7jFalt~xgKgr$X<)mmS;Hkcpu&*wglc7AV3;xS;M~a; zk`hc6?305e`xXmu2u#+{<DIOlP_ubCM>#v#fY*}Yk`hn@&>agh0IaB+EtQ#pVZp?M zYbWo}TEGpn2;^CAnRHGKsD6RTg&N|MS24GNHA%~+b6P-U7%^llamrYs%S?Wtt-#r# z4e^-3=7k*J87CXCu<@`kL_owC9w>uCCj-j+paSAGKzR&mAYK5JcR^|5!F^!oMH`fh zSwpQ-fXYIsNm>n)1@r}&DzreM@1W1Y^gtJ62F%+RbS55r1~(%bml<Fcpx~{Bs=yY! E01z^QrT_o{ literal 12920 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%LoL4^@S2+I8r;W02IKpBisx&unDV1h6h89<PM zK?1^M5QLhKt}Z0V)s=yPi2-IljD~7sFkpeO6d)9Xh6aSu0HH?7VIBe=u8sj9FD5_} z+654Yfq|hx3&H@UGLRH{I(wiEk++b9(4>-#3~US_b0Gu+10w?m*a)x~0|Nsyg9KPg zl93UdvO&2AA_q~&%D@4ZXJP_V3>?s$zzj18tP5-;m=HoTziaK1L<R;}P60~~I1%jU z@50Ey$^ecG1_p)~1j6+I)Ws*@<pD$hJzg>(!VEv8pbQ)oBvcs~6qs=+fC|IPi3(Pz zz#yY&m#ze$Spbw05uP6G>56c82!Jz?00#p!hd6*ehsa+{;B1G;UkcE0fTb9)at6$B zm;jL&p(H39fx;BzeNX_pJNx^9XjriXW`oN)kO~F{Lp=j91Dc%~biq6Z2AkxZ#NuKF z^`!hHy`0p<qCCCST)o`PijvZzRK4`vTs?PRU*E*!jLf`Lbvq>UVCLlJKn;RY4g65X zC^Z@aqaiRF0;3@?8UmvsFd71*Aut*OqaiRF0;3@?8UmvsK=lv+_2*&Z0ieD)2!kvk zJU&D2IH(46sBUPDSHZ@kBS3+}z`$Sur6bVBLvvFT^NJZ57|IfhG86NXQ$cKyQW!tJ zB%>&`I3qtN1tJ3*Z-E91NGVk?vGFIUU=JT(GvQ`1`41W(VqjpvJ<etUbt<88w!t-y zO*lMA9dC3XG8_iycq3+bK!#5l86hbWRt~^u*t`en>BR#YPK5G%1y#~LxrShd2PAz% zLIX7KqX8a|mS6(S2{AB$%1nqH2ZNiVPp~Tk14BqqD2M_pz@2|0py2`04Rfy$G#7$I zyTJ26BPBc{heUY5;%TJNpnMNX;~@Wns)Uajk2|*hcw%<~Bmk=amR6^yPk&<HyOqPx z@ZvRlMg|5`&4w&KEkTEa)ia7j{s%i$R!lg*%$LC-MJh^5(%##Fq4A0Roz|oF3>zLk zeD}!F!Qu9wf}@O^>>C`^Po8+-0GZ!m*kA*(0T&IC9VJIYU^E0qLtr!nMnhmU1ZWTf zkaa(7@P2Lw8-7bXpwb$UZUmHvNx^!B60(?b$RabK91S@r0|&JMDqf%oW#FJ->jpVw za44XWF#Y-!QA4U7ots*c7@wAzlNw)^T2!2wp9c+puy}b<W=U#MJa~x-R6H*}J+ZX7 zI1@6u!oW~cl$e>9TEvi?pI4Szl%ATGoQfD)VPIg$O{|Dd%Pc9$%uA0iE=etbjJ7Z^ zl&5B<XOt8(loq5UmZZkRhHGGa*f0%DIZPa;jRDma;I1Cng@dJ+4_Q~sh;v=7fIir{ zu;o4w&L}xzLIATof-J0Nge<3H0QLJofq{!Ae;g<pT0TK3kR=QZSk~vkR)#_q(3pbM zbIj1u7X}6fJ^U8JT%TcxUma|;D+0274!`1IAUkmBsu~(4PzoBBnB!^mO;-aK9<!kV K0i^~>cmMz>03&?> diff --git a/bob/learn/em/data/gmm_ML.hdf5 b/bob/learn/em/data/gmm_ML.hdf5 index 20362881b1661a826a8773d1658a8df559099d46..5e667e2498f69cf59303fc30c675c9f6f7b1fcc3 100644 GIT binary patch delta 797 zcmey7@*#GD2Ga+<iCPv4tPG3{3=9$s5Fo%H0A?sKgDHj!3=^BwBpFyCJeVp2O$a|= z<H7fgLK`3)h6PYsL1|)?9G8GPgd4#=@u1nn0ts0j1`h@n1_lO323`gc1`dX_%$(Ht zvecsD%=|ot$#YqyO+gL;X#(5L0b+qvGccg*EHBC|NiB*`&PgmT2J4iD>C^>jg6d=f zu^9v=9u#F}WMG}#&E_u&^CnaUSOJJwF!A8h$p$POObQH>CrYX@PMDm?A})pS0Z0u4 zg9U_Q5Kx?WP;T-D77nHl%o7jFalt~xgKgr$X<)mmS;Hkcpu&*wglc7AV3;xS;M~a; zk`hc6?305e`xXmu2u#+{<K3*Q@R=Q~`L$%Yqy$tmx<f&l!HT-sQkfYT7EC<2cJdCb z1>7)8Kpy3mN$1pn>KB+?s3AUi6>}R{leBC)rv+4o5ktljr;HW4%;X2!3Y;C<5Kjqg 
zUdZvCak2pm8xIRZ1VoJCfifs$GN8N<Dj;41l*gb3;sro?7nCL*+y{1Ev_ZL;HPk8v zs4Rq<q}4E4Kwp5VLJJh~4*DES4|G9hz`T7yXX3$Ua5JKDnE_S-3f^j{3T(j(05K|o AWB>pF delta 1137 zcmewm`y*w722+L6L@f(;Mg~R(1_p`AlFafGI}{`hSQx-S0ZcM5XlOtf4I2->XS7NH zsbyebxBwAkXwZT%K!$*%Bp_S{K?VUZ>wz|yVz7{$c+iZQk%4WpH(R(G0|!JDq6Wfd zU=V@|gUr#EU6wdGKzYI>Rt~NJsDc-g6A#KU9++&%DlUaE7b2Mfp%{KhO*|+!xq+2~ zNr8FdK{+m%FDh6k9-IcZUjXV%g7!~h1UXoKvXP>jFuL;?7$!_S_?MTNfq_8)>L13< zf7zJ@q;-*mm>3usKp5mcnD;fHt|Dmtgvk-?5=;@2lZ)8<CSKrxx?gqj0WBjn3#cJb zt3l>7U=2f<AB0#8G#!Z21oqifux*!>($zemng}>ILAMDST1+Y)Y7tOPpvZ)}4V&{& zHANDusbS)U50fNV7zH--N_=LYc!6UQD;qZ(g9X&4j>#KY<+(hdJPip@O6Y*{BxFIn z8Bm^v9Ei68$}7+W@fx5!4w;Du^T46t%_ct?<f$kU+{dArn_7|>pO%@E8ef)LRGgWg z2Q?QgUS5=0l3Em>oRe5wJXuj`Ipc(h2iqqLs4H*^=tI;AY@W#Roe^w*wN|-iG}LY= z1yO^MXBZeVCLYY4#L70gPn(O0K@a2~n2roXP~Lc;&A}95F!A6sxXIN}1yBmF$zV<1 PY>*^48>)j~5}W`4z%HhJ diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py index 0be5a17..63171cf 100644 --- a/bob/learn/em/mixture/gmm.py +++ b/bob/learn/em/mixture/gmm.py @@ -80,42 +80,42 @@ class GMMStats: if isinstance(hdf5, str): hdf5 = HDF5File(hdf5, "r") try: - version_major, version_minor = hdf5.get("meta_file_version")[()].split(".") + version_major, version_minor = hdf5.attrs["file_version"].split(".") logger.debug( f"Reading a GMMStats HDF5 file of version {version_major}.{version_minor}" ) - except TypeError: + except (KeyError, RuntimeError): version_major, version_minor = 0, 0 if int(version_major) >= 1: - if hdf5["meta_writer_class"][()] != str(cls): - logger.warning(f"{hdf5['meta_writer_class'][()]} is not {cls}.") + if hdf5.attrs["writer_class"] != str(cls): + logger.warning(f"{hdf5.attrs['writer_class']} is not {cls}.") self = cls( n_gaussians=hdf5["n_gaussians"][()], n_features=hdf5["n_features"][()] ) self.log_likelihood = hdf5["log_likelihood"][()] self.t = hdf5["T"][()] - self.n = hdf5["n"][()] - self.sum_px = hdf5["sumPx"][()] - self.sum_pxx = hdf5["sumPxx"][()] + self.n = hdf5["n"][...] + self.sum_px = hdf5["sumPx"][...] + self.sum_pxx = hdf5["sumPxx"][...] 
else: # Legacy file version logger.info("Loading a legacy HDF5 stats file.") self = cls( n_gaussians=int(hdf5["n_gaussians"][()]), n_features=int(hdf5["n_inputs"][()]), ) - self.log_likelihood = float(hdf5["log_liklihood"][()]) + self.log_likelihood = hdf5["log_liklihood"][()] self.t = int(hdf5["T"][()]) - self.n = hdf5["n"][()].reshape((self.n_gaussians,)) - self.sum_px = hdf5["sumPx"][()].reshape(self.shape) - self.sum_pxx = hdf5["sumPxx"][()].reshape(self.shape) + self.n = np.reshape(hdf5["n"], (self.n_gaussians,)) + self.sum_px = np.reshape(hdf5["sumPx"], (self.shape)) + self.sum_pxx = np.reshape(hdf5["sumPxx"], (self.shape)) return self def save(self, hdf5): """Saves the current statistsics in an `HDF5File` object.""" if isinstance(hdf5, str): hdf5 = HDF5File(hdf5, "w") - hdf5["meta_file_version"] = "1.0" - hdf5["meta_writer_class"] = str(self.__class__) + hdf5.attrs["file_version"] = "1.0" + hdf5.attrs["writer_class"] = str(self.__class__) hdf5["n_gaussians"] = self.n_gaussians hdf5["n_features"] = self.n_features hdf5["log_likelihood"] = float(self.log_likelihood) @@ -438,16 +438,16 @@ class GMMMachine(BaseEstimator): if isinstance(hdf5, str): hdf5 = HDF5File(hdf5, "r") try: - version_major, version_minor = hdf5.get("meta_file_version")[()].split(".") + version_major, version_minor = hdf5.attrs["file_version"].split(".") logger.debug( f"Reading a GMMMachine HDF5 file of version {version_major}.{version_minor}" ) - except (TypeError, RuntimeError): + except (KeyError, RuntimeError): version_major, version_minor = 0, 0 if int(version_major) >= 1: - if hdf5["meta_writer_class"][()] != str(cls): - logger.warning(f"{hdf5['meta_writer_class'][()]} is not {cls}.") - if hdf5["trainer"][()] == "map" and ubm is None: + if hdf5.attrs["writer_class"] != str(cls): + logger.warning(f"{hdf5.attrs['writer_class']} is not {cls}.") + if hdf5["trainer"] == "map" and ubm is None: raise ValueError("The UBM is needed when loading a MAP machine.") self = cls( n_gaussians=hdf5["n_gaussians"][()], @@ -455,30 +455,30 @@ class GMMMachine(BaseEstimator): ubm=ubm, convergence_threshold=1e-5, max_fitting_steps=hdf5["max_fitting_steps"][()], - weights=hdf5["weights"][()], + weights=hdf5["weights"][...], k_means_trainer=None, update_means=hdf5["update_means"][()], update_variances=hdf5["update_variances"][()], update_weights=hdf5["update_weights"][()], ) gaussians_group = hdf5["gaussians"] - self.means = gaussians_group["means"][()] - self.variances = gaussians_group["variances"][()] - self.variance_thresholds = gaussians_group["variance_thresholds"][()] + self.means = gaussians_group["means"][...] + self.variances = gaussians_group["variances"][...] + self.variance_thresholds = gaussians_group["variance_thresholds"][...] else: # Legacy file version logger.info("Loading a legacy HDF5 machine file.") - n_gaussians = int(hdf5["m_n_gaussians"][()]) + n_gaussians = hdf5["m_n_gaussians"][()] g_means = [] g_variances = [] g_variance_thresholds = [] for i in range(n_gaussians): gaussian_group = hdf5[f"m_gaussians{i}"] - g_means.append(gaussian_group["m_mean"][()]) - g_variances.append(gaussian_group["m_variance"][()]) + g_means.append(gaussian_group["m_mean"][...]) + g_variances.append(gaussian_group["m_variance"][...]) g_variance_thresholds.append( - gaussian_group["m_variance_thresholds"][()] + gaussian_group["m_variance_thresholds"][...] 
) - weights = hdf5["m_weights"][()].reshape(n_gaussians) + weights = np.reshape(hdf5["m_weights"], (n_gaussians,)) self = cls(n_gaussians=n_gaussians, ubm=ubm, weights=weights) self.means = np.array(g_means).reshape(n_gaussians, -1) self.variances = np.array(g_variances).reshape(n_gaussians, -1) @@ -491,8 +491,8 @@ class GMMMachine(BaseEstimator): """Saves the current statistics in an `HDF5File` object.""" if isinstance(hdf5, str): hdf5 = HDF5File(hdf5, "w") - hdf5["meta_file_version"] = "1.0" - hdf5["meta_writer_class"] = str(self.__class__) + hdf5.attrs["file_version"] = "1.0" + hdf5.attrs["writer_class"] = str(self.__class__) hdf5["n_gaussians"] = self.n_gaussians hdf5["trainer"] = self.trainer hdf5["convergence_threshold"] = self.convergence_threshold diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 785b0ae..e99e160 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -55,6 +55,10 @@ def test_GMMStats(): assert (gs != gs_loaded ) is False assert gs.is_similar_to(gs_loaded) + assert type(gs_loaded.n_gaussians) is np.int64 + assert type(gs_loaded.n_features) is np.int64 + assert type(gs_loaded.log_likelihood) is np.float64 + # Saves and load from file using `load` filename = str(tempfile.mkstemp(".hdf5")[1]) gs.save(hdf5=HDF5File(filename, "w")) @@ -183,14 +187,22 @@ def test_GMMMachine(): assert gmm.is_similar_to(gmm6) is False # Saving and loading - filename = tempfile.mkstemp(suffix=".hdf5")[1] - gmm.save(HDF5File(filename, "w")) - gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r")) - assert gmm == gmm1 - gmm.save(filename) - gmm1 = GMMMachine.from_hdf5(filename) - assert gmm == gmm1 - os.unlink(filename) + with tempfile.NamedTemporaryFile(suffix=".hdf5") as f: + filename = f.name + gmm.save(HDF5File(filename, "w")) + gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r")) + assert type(gmm1.n_gaussians) is np.int64 + assert type(gmm1.update_means) is np.bool_ + assert type(gmm1.update_variances) is np.bool_ + assert type(gmm1.update_weights) is np.bool_ + assert type(gmm1.trainer) is str + assert gmm1.ubm is None + assert gmm == gmm1 + with tempfile.NamedTemporaryFile(suffix=".hdf5") as f: + filename = f.name + gmm.save(filename) + gmm1 = GMMMachine.from_hdf5(filename) + assert gmm == gmm1 # Weights n_gaussians = 5 -- GitLab
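Editor's note (not part of the patch): a minimal, self-contained sketch of the h5py v3 conventions this patch switches to — file metadata stored as HDF5 attributes (which h5py >= 3 returns as `str`, while string *datasets* now read back as `bytes`), scalar datasets read with `[()]`, array datasets read with `[...]`, and a missing attribute raising `KeyError` rather than `TypeError`. The helper names (`save_stats`, `load_stats`) and the direct use of `h5py.File` instead of bob's `HDF5File` wrapper are illustrative assumptions only.

    import numpy as np
    import h5py

    def save_stats(path, n_gaussians, n_features, weights):
        # Illustrative writer following the patch's layout: metadata in attrs,
        # numeric values in datasets.
        with h5py.File(path, "w") as f:
            # Attributes decode to `str` under h5py v3; string datasets would
            # come back as `bytes`, which is why the metadata moved to attrs.
            f.attrs["file_version"] = "1.0"
            f.attrs["writer_class"] = "GMMStats"
            f["n_gaussians"] = n_gaussians   # scalar dataset
            f["n_features"] = n_features     # scalar dataset
            f["weights"] = weights           # array dataset

    def load_stats(path):
        # Illustrative reader mirroring the patch's access patterns.
        with h5py.File(path, "r") as f:
            try:
                # A missing attribute raises KeyError, hence the new except clause.
                major, minor = f.attrs["file_version"].split(".")
            except KeyError:
                major, minor = 0, 0
            n_gaussians = f["n_gaussians"][()]  # scalar -> numpy scalar
            weights = f["weights"][...]         # array  -> numpy ndarray
            return int(major), n_gaussians, weights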
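A second hedged sketch, reusing the illustrative helpers above, of the round-trip pattern the updated tests follow: write through a `tempfile.NamedTemporaryFile` so the file is removed automatically (replacing the manual `os.unlink`), reload, and check the types of the loaded values.

    import tempfile
    import numpy as np

    with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
        # Save and reload through the same path; the file is deleted when the
        # context manager exits.
        save_stats(f.name, n_gaussians=2, n_features=3,
                   weights=np.full((2,), 0.5))
        major, n_gaussians, weights = load_stats(f.name)
        assert major == 1
        # Scalar datasets come back as numpy scalars, not plain Python ints.
        assert isinstance(n_gaussians, np.integer)
        np.testing.assert_allclose(weights, 0.5)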