diff --git a/bob/io/base/__init__.py b/bob/io/base/__init__.py index fd99f4a45ffeeb243817bca227344b44c64a3199..cced69bcb98f5cc04a97b56dbf4a9d998f355eaf 100644 --- a/bob/io/base/__init__.py +++ b/bob/io/base/__init__.py @@ -95,8 +95,12 @@ def open_file(filename): raise RuntimeError( f"The file {filename} does not contain the key {key}" ) - - return np.array(f[key]) + dataset = f[key] + # if the data was saved as a string, load it back as string + string_dtype = h5py.check_string_dtype(dataset.dtype) + if string_dtype is not None: + dataset = dataset.asstr() + return dataset[()] elif extension in image_extensions: from ..image import to_bob @@ -235,8 +239,10 @@ def save(array, filename, create_directories=False): if create_directories: create_directories_safe(os.path.dirname(filename)) - # requires data is c-contiguous and aligned, will create a copy otherwise - array = np.require(array, requirements=("C_CONTIGUOUS", "ALIGNED")) + # if array is a string, don't create a numpy array + if not isinstance(array, str): + # requires data is c-contiguous and aligned, will create a copy otherwise + array = np.require(array, requirements=("C_CONTIGUOUS", "ALIGNED")) write_file(filename, array) diff --git a/bob/io/base/test/test_hdf5.py b/bob/io/base/test/test_hdf5.py index ec0e148ce0500e23765b76577919f6280bc11a98..a5fa23b6cf290c3a7e92aa40421f8e1345105b7d 100644 --- a/bob/io/base/test/test_hdf5.py +++ b/bob/io/base/test/test_hdf5.py @@ -7,29 +7,24 @@ """Tests for the base HDF5 infrastructure """ -import os import random +import tempfile import numpy as np from bob.io.base import load, save -from ..test_utils import temporary_filename - -def read_write_check(data): +def read_write_check(data, numpy_assert=True): """Testing loading and save different file types""" - tmpname = temporary_filename() - - try: - - save(data, tmpname) - data2 = load(tmpname) - finally: - os.unlink(tmpname) - - assert np.allclose(data, data2, atol=10e-5, rtol=10e-5) + with tempfile.NamedTemporaryFile(prefix="bobtest_", suffix=".hdf5") as f: + save(data, f.name) + data2 = load(f.name) + if numpy_assert: + assert np.allclose(data, data2, atol=10e-5, rtol=10e-5) + else: + assert data == data2 def test_type_support(): @@ -54,3 +49,8 @@ def test_type_support(): read_write_check(np.array(data, np.float64)) read_write_check(np.array(data, np.complex64)) read_write_check(np.array(data, np.complex128)) + + +def test_scalar_support(): + for oracle in (1, 1.0, 1j, "a", True): + read_write_check(oracle, numpy_assert=False)