From 46bacfef6adf47b5007549271c65bba55c2d4f20 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 21 Apr 2020 11:22:01 +0200
Subject: [PATCH] [engine] Deploy h5py instead of bob.io.base for HDF5 I/O

---
 bob/ip/binseg/engine/evaluator.py | 5 +++--
 bob/ip/binseg/engine/predictor.py | 8 +++++---
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/bob/ip/binseg/engine/evaluator.py b/bob/ip/binseg/engine/evaluator.py
index 5b4f2a39..c3124ef9 100644
--- a/bob/ip/binseg/engine/evaluator.py
+++ b/bob/ip/binseg/engine/evaluator.py
@@ -13,7 +13,7 @@ from tqdm import tqdm
 import torch
 import torchvision.transforms.functional as VF
 
-import bob.io.base
+import h5py
 
 from ..utils.metric import base_metrics
 from ..utils.plot import precision_recall_f1iso_confintval
@@ -231,7 +231,8 @@ def run(data_loader, predictions_folder, output_folder, overlayed_folder=None,
         image = sample[1].to("cpu")
         gt = sample[2].to("cpu")
         pred_fullpath = os.path.join(predictions_folder, stem + ".hdf5")
-        pred = bob.io.base.load(pred_fullpath).astype("float32")
+        with h5py.File(pred_fullpath, "r") as f:
+            pred = f["array"][:]
         pred = torch.from_numpy(pred)
         if stem in data:
             raise RuntimeError(f"{stem} entry already exists in data. "
diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
index ca22cb3e..33a91d0b 100644
--- a/bob/ip/binseg/engine/predictor.py
+++ b/bob/ip/binseg/engine/predictor.py
@@ -12,7 +12,7 @@ from tqdm import tqdm
 import torch
 import torchvision.transforms.functional as VF
 
-import bob.io.base
+import h5py
 
 from ..utils.summary import summary
 
@@ -44,8 +44,10 @@ def _save_hdf5(stem, prob, output_folder):
     if not os.path.exists(fulldir):
         tqdm.write(f"Creating directory {fulldir}...")
         os.makedirs(fulldir, exist_ok=True)
-    bob.io.base.save(prob.cpu().squeeze(0).numpy(), fullpath)
-
+    with h5py.File(fullpath, 'w') as f:
+        data = prob.cpu().squeeze(0).numpy()
+        f.create_dataset("array", data=data, compression="gzip",
+                compression_opts=9)
 
 def _save_image(stem, extension, data, output_folder):
     """Saves a PIL image into a file
-- 
GitLab