From 797aeb754a77130e19e6e5b12130895d240e0047 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 9 May 2022 18:12:05 +0200
Subject: [PATCH] Improve support for writing and reading scalars in hdf5

---
 bob/io/base/__init__.py       | 14 ++++++++++----
 bob/io/base/test/test_hdf5.py | 28 ++++++++++++++--------------
 2 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/bob/io/base/__init__.py b/bob/io/base/__init__.py
index fd99f4a..cced69b 100644
--- a/bob/io/base/__init__.py
+++ b/bob/io/base/__init__.py
@@ -95,8 +95,12 @@ def open_file(filename):
                 raise RuntimeError(
                     f"The file {filename} does not contain the key {key}"
                 )
-
-            return np.array(f[key])
+            dataset = f[key]
+            # if the data was saved as a string, load it back as string
+            string_dtype = h5py.check_string_dtype(dataset.dtype)
+            if string_dtype is not None:
+                dataset = dataset.asstr()
+            return dataset[()]
     elif extension in image_extensions:
         from ..image import to_bob
 
@@ -235,8 +239,10 @@ def save(array, filename, create_directories=False):
     if create_directories:
         create_directories_safe(os.path.dirname(filename))
 
-    # requires data is c-contiguous and aligned, will create a copy otherwise
-    array = np.require(array, requirements=("C_CONTIGUOUS", "ALIGNED"))
+    # if array is a string, don't create a numpy array
+    if not isinstance(array, str):
+        # requires data is c-contiguous and aligned, will create a copy otherwise
+        array = np.require(array, requirements=("C_CONTIGUOUS", "ALIGNED"))
 
     write_file(filename, array)
 
diff --git a/bob/io/base/test/test_hdf5.py b/bob/io/base/test/test_hdf5.py
index ec0e148..a5fa23b 100644
--- a/bob/io/base/test/test_hdf5.py
+++ b/bob/io/base/test/test_hdf5.py
@@ -7,29 +7,24 @@
 """Tests for the base HDF5 infrastructure
 """
 
-import os
 import random
+import tempfile
 
 import numpy as np
 
 from bob.io.base import load, save
 
-from ..test_utils import temporary_filename
 
-
-def read_write_check(data):
+def read_write_check(data, numpy_assert=True):
     """Testing loading and save different file types"""
 
-    tmpname = temporary_filename()
-
-    try:
-
-        save(data, tmpname)
-        data2 = load(tmpname)
-    finally:
-        os.unlink(tmpname)
-
-    assert np.allclose(data, data2, atol=10e-5, rtol=10e-5)
+    with tempfile.NamedTemporaryFile(prefix="bobtest_", suffix=".hdf5") as f:
+        save(data, f.name)
+        data2 = load(f.name)
+        if numpy_assert:
+            assert np.allclose(data, data2, atol=10e-5, rtol=10e-5)
+        else:
+            assert data == data2
 
 
 def test_type_support():
@@ -54,3 +49,8 @@ def test_type_support():
     read_write_check(np.array(data, np.float64))
     read_write_check(np.array(data, np.complex64))
     read_write_check(np.array(data, np.complex128))
+
+
+def test_scalar_support():
+    for oracle in (1, 1.0, 1j, "a", True):
+        read_write_check(oracle, numpy_assert=False)
-- 
GitLab