From 797aeb754a77130e19e6e5b12130895d240e0047 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Mon, 9 May 2022 18:12:05 +0200 Subject: [PATCH] Improve support for writing and reading scalars in hdf5 --- bob/io/base/__init__.py | 14 ++++++++++---- bob/io/base/test/test_hdf5.py | 28 ++++++++++++++-------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/bob/io/base/__init__.py b/bob/io/base/__init__.py index fd99f4a..cced69b 100644 --- a/bob/io/base/__init__.py +++ b/bob/io/base/__init__.py @@ -95,8 +95,12 @@ def open_file(filename): raise RuntimeError( f"The file {filename} does not contain the key {key}" ) - - return np.array(f[key]) + dataset = f[key] + # if the data was saved as a string, load it back as string + string_dtype = h5py.check_string_dtype(dataset.dtype) + if string_dtype is not None: + dataset = dataset.asstr() + return dataset[()] elif extension in image_extensions: from ..image import to_bob @@ -235,8 +239,10 @@ def save(array, filename, create_directories=False): if create_directories: create_directories_safe(os.path.dirname(filename)) - # requires data is c-contiguous and aligned, will create a copy otherwise - array = np.require(array, requirements=("C_CONTIGUOUS", "ALIGNED")) + # if array is a string, don't create a numpy array + if not isinstance(array, str): + # requires data is c-contiguous and aligned, will create a copy otherwise + array = np.require(array, requirements=("C_CONTIGUOUS", "ALIGNED")) write_file(filename, array) diff --git a/bob/io/base/test/test_hdf5.py b/bob/io/base/test/test_hdf5.py index ec0e148..a5fa23b 100644 --- a/bob/io/base/test/test_hdf5.py +++ b/bob/io/base/test/test_hdf5.py @@ -7,29 +7,24 @@ """Tests for the base HDF5 infrastructure """ -import os import random +import tempfile import numpy as np from bob.io.base import load, save -from ..test_utils import temporary_filename - -def read_write_check(data): +def read_write_check(data, numpy_assert=True): """Testing loading and save different file types""" - tmpname = temporary_filename() - - try: - - save(data, tmpname) - data2 = load(tmpname) - finally: - os.unlink(tmpname) - - assert np.allclose(data, data2, atol=10e-5, rtol=10e-5) + with tempfile.NamedTemporaryFile(prefix="bobtest_", suffix=".hdf5") as f: + save(data, f.name) + data2 = load(f.name) + if numpy_assert: + assert np.allclose(data, data2, atol=10e-5, rtol=10e-5) + else: + assert data == data2 def test_type_support(): @@ -54,3 +49,8 @@ def test_type_support(): read_write_check(np.array(data, np.float64)) read_write_check(np.array(data, np.complex64)) read_write_check(np.array(data, np.complex128)) + + +def test_scalar_support(): + for oracle in (1, 1.0, 1j, "a", True): + read_write_check(oracle, numpy_assert=False) -- GitLab