From 74b56e509f2cb4e0a288804c9b9789af648e19ac Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Mon, 21 Feb 2022 17:29:08 +0100
Subject: [PATCH] Support h5py v3; add tests for loaded objects.

---
 bob/learn/em/data/gmm_MAP.hdf5 | Bin 12920 -> 12016 bytes
 bob/learn/em/data/gmm_ML.hdf5  | Bin 12920 -> 12016 bytes
 bob/learn/em/mixture/gmm.py    |  58 ++++++++++++++++-----------------
 bob/learn/em/test/test_gmm.py  |  28 +++++++++++-----
 4 files changed, 49 insertions(+), 37 deletions(-)

diff --git a/bob/learn/em/data/gmm_MAP.hdf5 b/bob/learn/em/data/gmm_MAP.hdf5
index 0f57a2bde913bba76c2d9ab5a07cf81b07346d0f..8106590f1721a6c6084c63bc9637b52a5b39b24e 100644
GIT binary patch
delta 801
zcmey7@*#GD2Ga+<iCPv4tPG3{3=9$s5Fo%H0A?sKgDHj!3=^BwBpFyCJeVp2O$a|=
z<H7fgLK`3)h6PYsL1|)?9G8GPgd4#=@u1nn0ts0j1`h@n1_lO323`gc1`dX_%$(Ht
zvecsD%=|ot$#YqyO+gL;X#(5L0b+qvGccg*EHBC|NiB*`&PgmT2J4iD>C^>jg6d=f
zu^9v=9u#F}WMG}#&E_u&^CnaUSOJJwF!A8h$p$POObQH>CrYX@PMDm?A})pS0Z0u4
zg9U_Q5Kx?WP;T-D77nHl%o7jFalt~xgKgr$X<)mmS;Hkcpu&*wglc7AV3;xS;M~a;
zk`hc6?305e`xXmu2u#+{<DIOlP_ubCM>#v#fY*}Yk`hn@&>agh0IaB+EtQ#pVZp?M
zYbWo}TEGpn2;^CAnRHGKsD6RTg&N|MS24GNHA%~+b6P-U7%^llamrYs%S?Wtt-#r#
z4e^-3=7k*J87CXCu<@`kL_owC9w>uCCj-j+paSAGKzR&mAYK5JcR^|5!F^!oMH`fh
zSwpQ-fXYIsNm>n)1@r}&DzreM@1W1Y^gtJ62F%+RbS55r1~(%bml<Fcpx~{Bs=yY!
E01z^QrT_o{

literal 12920
zcmeD5aB<`1lHy_j0S*oZ76t(j3y%LoL4^@S2+I8r;W02IKpBisx&unDV1h6h89<PM
zK?1^M5QLhKt}Z0V)s=yPi2-IljD~7sFkpeO6d)9Xh6aSu0HH?7VIBe=u8sj9FD5_}
z+654Yfq|hx3&H@UGLRH{I(wiEk++b9(4>-#3~US_b0Gu+10w?m*a)x~0|Nsyg9KPg
zl93UdvO&2AA_q~&%D@4ZXJP_V3>?s$zzj18tP5-;m=HoTziaK1L<R;}P60~~I1%jU
z@50Ey$^ecG1_p)~1j6+I)Ws*@<pD$hJzg>(!VEv8pbQ)oBvcs~6qs=+fC|IPi3(Pz
zz#yY&m#ze$Spbw05uP6G>56c82!Jz?00#p!hd6*ehsa+{;B1G;UkcE0fTb9)at6$B
zm;jL&p(H39fx;BzeNX_pJNx^9XjriXW`oN)kO~F{Lp=j91Dc%~biq6Z2AkxZ#NuKF
z^`!hHy`0p<qCCCST)o`PijvZzRK4`vTs?PRU*E*!jLf`Lbvq>UVCLlJKn;RY4g65X
zC^Z@aqaiRF0;3@?8UmvsFd71*Aut*OqaiRF0;3@?8UmvsK=lv+_2*&Z0ieD)2!kvk
zJU&D2IH(46sBUPDSHZ@kBS3+}z`$Sur6bVBLvvFT^NJZ57|IfhG86NXQ$cKyQW!tJ
zB%>&`I3qtN1tJ3*Z-E91NGVk?vGFIUU=JT(GvQ`1`41W(VqjpvJ<etUbt<88w!t-y
zO*lMA9dC3XG8_iycq3+bK!#5l86hbWRt~^u*t`en>BR#YPK5G%1y#~LxrShd2PAz%
zLIX7KqX8a|mS6(S2{AB$%1nqH2ZNiVPp~Tk14BqqD2M_pz@2|0py2`04Rfy$G#7$I
zyTJ26BPBc{heUY5;%TJNpnMNX;~@Wns)Uajk2|*hcw%<~Bmk=amR6^yPk&<HyOqPx
z@ZvRlMg|5`&4w&KEkTEa)ia7j{s%i$R!lg*%$LC-MJh^5(%##Fq4A0Roz|oF3>zLk
zeD}!F!Qu9wf}@O^>>C`^Po8+-0GZ!m*kA*(0T&IC9VJIYU^E0qLtr!nMnhmU1ZWTf
zkaa(7@P2Lw8-7bXpwb$UZUmHvNx^!B60(?b$RabK91S@r0|&JMDqf%oW#FJ->jpVw
za44XWF#Y-!QA4U7ots*c7@wAzlNw)^T2!2wp9c+puy}b<W=U#MJa~x-R6H*}J+ZX7
zI1@6u!oW~cl$e>9TEvi?pI4Szl%ATGoQfD)VPIg$O{|Dd%Pc9$%uA0iE=etbjJ7Z^
zl&5B<XOt8(loq5UmZZkRhHGGa*f0%DIZPa;jRDma;I1Cng@dJ+4_Q~sh;v=7fIir{
zu;o4w&L}xzLIATof-J0Nge<3H0QLJofq{!Ae;g<pT0TK3kR=QZSk~vkR)#_q(3pbM
zbIj1u7X}6fJ^U8JT%TcxUma|;D+0274!`1IAUkmBsu~(4PzoBBnB!^mO;-aK9<!kV
K0i^~>cmMz>03&?>

diff --git a/bob/learn/em/data/gmm_ML.hdf5 b/bob/learn/em/data/gmm_ML.hdf5
index 20362881b1661a826a8773d1658a8df559099d46..5e667e2498f69cf59303fc30c675c9f6f7b1fcc3 100644
GIT binary patch
delta 797
zcmey7@*#GD2Ga+<iCPv4tPG3{3=9$s5Fo%H0A?sKgDHj!3=^BwBpFyCJeVp2O$a|=
z<H7fgLK`3)h6PYsL1|)?9G8GPgd4#=@u1nn0ts0j1`h@n1_lO323`gc1`dX_%$(Ht
zvecsD%=|ot$#YqyO+gL;X#(5L0b+qvGccg*EHBC|NiB*`&PgmT2J4iD>C^>jg6d=f
zu^9v=9u#F}WMG}#&E_u&^CnaUSOJJwF!A8h$p$POObQH>CrYX@PMDm?A})pS0Z0u4
zg9U_Q5Kx?WP;T-D77nHl%o7jFalt~xgKgr$X<)mmS;Hkcpu&*wglc7AV3;xS;M~a;
zk`hc6?305e`xXmu2u#+{<K3*Q@R=Q~`L$%Yqy$tmx<f&l!HT-sQkfYT7EC<2cJdCb
z1>7)8Kpy3mN$1pn>KB+?s3AUi6>}R{leBC)rv+4o5ktljr;HW4%;X2!3Y;C<5Kjqg
zUdZvCak2pm8xIRZ1VoJCfifs$GN8N<Dj;41l*gb3;sro?7nCL*+y{1Ev_ZL;HPk8v
zs4Rq<q}4E4Kwp5VLJJh~4*DES4|G9hz`T7yXX3$Ua5JKDnE_S-3f^j{3T(j(05K|o
AWB>pF

delta 1137
zcmewm`y*w722+L6L@f(;Mg~R(1_p`AlFafGI}{`hSQx-S0ZcM5XlOtf4I2->XS7NH
zsbyebxBwAkXwZT%K!$*%Bp_S{K?VUZ>wz|yVz7{$c+iZQk%4WpH(R(G0|!JDq6Wfd
zU=V@|gUr#EU6wdGKzYI>Rt~NJsDc-g6A#KU9++&%DlUaE7b2Mfp%{KhO*|+!xq+2~
zNr8FdK{+m%FDh6k9-IcZUjXV%g7!~h1UXoKvXP>jFuL;?7$!_S_?MTNfq_8)>L13<
zf7zJ@q;-*mm>3usKp5mcnD;fHt|Dmtgvk-?5=;@2lZ)8<CSKrxx?gqj0WBjn3#cJb
zt3l>7U=2f<AB0#8G#!Z21oqifux*!>($zemng}>ILAMDST1+Y)Y7tOPpvZ)}4V&{&
zHANDusbS)U50fNV7zH--N_=LYc!6UQD;qZ(g9X&4j>#KY<+(hdJPip@O6Y*{BxFIn
z8Bm^v9Ei68$}7+W@fx5!4w;Du^T46t%_ct?<f$kU+{dArn_7|>pO%@E8ef)LRGgWg
z2Q?QgUS5=0l3Em>oRe5wJXuj`Ipc(h2iqqLs4H*^=tI;AY@W#Roe^w*wN|-iG}LY=
z1yO^MXBZeVCLYY4#L70gPn(O0K@a2~n2roXP~Lc;&A}95F!A6sxXIN}1yBmF$zV<1
PY>*^48>)j~5}W`4z%HhJ

diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py
index 0be5a17..63171cf 100644
--- a/bob/learn/em/mixture/gmm.py
+++ b/bob/learn/em/mixture/gmm.py
@@ -80,42 +80,42 @@ class GMMStats:
         if isinstance(hdf5, str):
             hdf5 = HDF5File(hdf5, "r")
         try:
-            version_major, version_minor = hdf5.get("meta_file_version")[()].split(".")
+            version_major, version_minor = hdf5.attrs["file_version"].split(".")
             logger.debug(
                 f"Reading a GMMStats HDF5 file of version {version_major}.{version_minor}"
             )
-        except TypeError:
+        except (KeyError, RuntimeError):
             version_major, version_minor = 0, 0
         if int(version_major) >= 1:
-            if hdf5["meta_writer_class"][()] != str(cls):
-                logger.warning(f"{hdf5['meta_writer_class'][()]} is not {cls}.")
+            if hdf5.attrs["writer_class"] != str(cls):
+                logger.warning(f"{hdf5.attrs['writer_class']} is not {cls}.")
             self = cls(
                 n_gaussians=hdf5["n_gaussians"][()], n_features=hdf5["n_features"][()]
             )
             self.log_likelihood = hdf5["log_likelihood"][()]
             self.t = hdf5["T"][()]
-            self.n = hdf5["n"][()]
-            self.sum_px = hdf5["sumPx"][()]
-            self.sum_pxx = hdf5["sumPxx"][()]
+            self.n = hdf5["n"][...]
+            self.sum_px = hdf5["sumPx"][...]
+            self.sum_pxx = hdf5["sumPxx"][...]
         else:  # Legacy file version
             logger.info("Loading a legacy HDF5 stats file.")
             self = cls(
                 n_gaussians=int(hdf5["n_gaussians"][()]),
                 n_features=int(hdf5["n_inputs"][()]),
             )
-            self.log_likelihood = float(hdf5["log_liklihood"][()])
+            self.log_likelihood = hdf5["log_liklihood"][()]
             self.t = int(hdf5["T"][()])
-            self.n = hdf5["n"][()].reshape((self.n_gaussians,))
-            self.sum_px = hdf5["sumPx"][()].reshape(self.shape)
-            self.sum_pxx = hdf5["sumPxx"][()].reshape(self.shape)
+            self.n = np.reshape(hdf5["n"], (self.n_gaussians,))
+            self.sum_px = np.reshape(hdf5["sumPx"], (self.shape))
+            self.sum_pxx = np.reshape(hdf5["sumPxx"], (self.shape))
         return self
 
     def save(self, hdf5):
         """Saves the current statistsics in an `HDF5File` object."""
         if isinstance(hdf5, str):
             hdf5 = HDF5File(hdf5, "w")
-        hdf5["meta_file_version"] = "1.0"
-        hdf5["meta_writer_class"] = str(self.__class__)
+        hdf5.attrs["file_version"] = "1.0"
+        hdf5.attrs["writer_class"] = str(self.__class__)
         hdf5["n_gaussians"] = self.n_gaussians
         hdf5["n_features"] = self.n_features
         hdf5["log_likelihood"] = float(self.log_likelihood)
@@ -438,16 +438,16 @@ class GMMMachine(BaseEstimator):
         if isinstance(hdf5, str):
             hdf5 = HDF5File(hdf5, "r")
         try:
-            version_major, version_minor = hdf5.get("meta_file_version")[()].split(".")
+            version_major, version_minor = hdf5.attrs["file_version"].split(".")
             logger.debug(
                 f"Reading a GMMMachine HDF5 file of version {version_major}.{version_minor}"
             )
-        except (TypeError, RuntimeError):
+        except (KeyError, RuntimeError):
             version_major, version_minor = 0, 0
         if int(version_major) >= 1:
-            if hdf5["meta_writer_class"][()] != str(cls):
-                logger.warning(f"{hdf5['meta_writer_class'][()]} is not {cls}.")
-            if hdf5["trainer"][()] == "map" and ubm is None:
+            if hdf5.attrs["writer_class"] != str(cls):
+                logger.warning(f"{hdf5.attrs['writer_class']} is not {cls}.")
+            if hdf5["trainer"][()] in ("map", b"map") and ubm is None:  # h5py 3: bytes
                 raise ValueError("The UBM is needed when loading a MAP machine.")
             self = cls(
                 n_gaussians=hdf5["n_gaussians"][()],
@@ -455,30 +455,30 @@ class GMMMachine(BaseEstimator):
                 ubm=ubm,
                 convergence_threshold=1e-5,
                 max_fitting_steps=hdf5["max_fitting_steps"][()],
-                weights=hdf5["weights"][()],
+                weights=hdf5["weights"][...],
                 k_means_trainer=None,
                 update_means=hdf5["update_means"][()],
                 update_variances=hdf5["update_variances"][()],
                 update_weights=hdf5["update_weights"][()],
             )
             gaussians_group = hdf5["gaussians"]
-            self.means = gaussians_group["means"][()]
-            self.variances = gaussians_group["variances"][()]
-            self.variance_thresholds = gaussians_group["variance_thresholds"][()]
+            self.means = gaussians_group["means"][...]
+            self.variances = gaussians_group["variances"][...]
+            self.variance_thresholds = gaussians_group["variance_thresholds"][...]
         else:  # Legacy file version
             logger.info("Loading a legacy HDF5 machine file.")
-            n_gaussians = int(hdf5["m_n_gaussians"][()])
+            n_gaussians = hdf5["m_n_gaussians"][()]
             g_means = []
             g_variances = []
             g_variance_thresholds = []
             for i in range(n_gaussians):
                 gaussian_group = hdf5[f"m_gaussians{i}"]
-                g_means.append(gaussian_group["m_mean"][()])
-                g_variances.append(gaussian_group["m_variance"][()])
+                g_means.append(gaussian_group["m_mean"][...])
+                g_variances.append(gaussian_group["m_variance"][...])
                 g_variance_thresholds.append(
-                    gaussian_group["m_variance_thresholds"][()]
+                    gaussian_group["m_variance_thresholds"][...]
                 )
-            weights = hdf5["m_weights"][()].reshape(n_gaussians)
+            weights = np.reshape(hdf5["m_weights"], (n_gaussians,))
             self = cls(n_gaussians=n_gaussians, ubm=ubm, weights=weights)
             self.means = np.array(g_means).reshape(n_gaussians, -1)
             self.variances = np.array(g_variances).reshape(n_gaussians, -1)
@@ -491,8 +491,8 @@ class GMMMachine(BaseEstimator):
         """Saves the current statistics in an `HDF5File` object."""
         if isinstance(hdf5, str):
             hdf5 = HDF5File(hdf5, "w")
-        hdf5["meta_file_version"] = "1.0"
-        hdf5["meta_writer_class"] = str(self.__class__)
+        hdf5.attrs["file_version"] = "1.0"
+        hdf5.attrs["writer_class"] = str(self.__class__)
         hdf5["n_gaussians"] = self.n_gaussians
         hdf5["trainer"] = self.trainer
         hdf5["convergence_threshold"] = self.convergence_threshold
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 785b0ae..e99e160 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -55,6 +55,10 @@ def test_GMMStats():
   assert (gs != gs_loaded ) is False
   assert gs.is_similar_to(gs_loaded)
 
+  assert type(gs_loaded.n_gaussians) is np.int64
+  assert type(gs_loaded.n_features) is np.int64
+  assert type(gs_loaded.log_likelihood) is np.float64
+
   # Saves and load from file using `load`
   filename = str(tempfile.mkstemp(".hdf5")[1])
   gs.save(hdf5=HDF5File(filename, "w"))
@@ -183,14 +187,22 @@ def test_GMMMachine():
   assert gmm.is_similar_to(gmm6) is False
 
   # Saving and loading
-  filename = tempfile.mkstemp(suffix=".hdf5")[1]
-  gmm.save(HDF5File(filename, "w"))
-  gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r"))
-  assert gmm == gmm1
-  gmm.save(filename)
-  gmm1 = GMMMachine.from_hdf5(filename)
-  assert gmm == gmm1
-  os.unlink(filename)
+  with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
+      filename = f.name
+      gmm.save(HDF5File(filename, "w"))
+      gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r"))
+      assert type(gmm1.n_gaussians) is np.int64
+      assert type(gmm1.update_means) is np.bool_
+      assert type(gmm1.update_variances) is np.bool_
+      assert type(gmm1.update_weights) is np.bool_
+      assert type(gmm1.trainer) is str
+      assert gmm1.ubm is None
+      assert gmm == gmm1
+  with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
+      filename = f.name
+      gmm.save(filename)
+      gmm1 = GMMMachine.from_hdf5(filename)
+      assert gmm == gmm1
 
   # Weights
   n_gaussians = 5
-- 
GitLab