Skip to content
Snippets Groups Projects
Commit 797aeb75 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Improve support for writing and reading scalars in hdf5

parent cfed5de9
No related branches found
No related tags found
1 merge request!42Improve support for writing and reading scalars in hdf5
Pipeline #61041 passed
......@@ -95,8 +95,12 @@ def open_file(filename):
raise RuntimeError(
f"The file {filename} does not contain the key {key}"
)
return np.array(f[key])
dataset = f[key]
# if the data was saved as a string, load it back as string
string_dtype = h5py.check_string_dtype(dataset.dtype)
if string_dtype is not None:
dataset = dataset.asstr()
return dataset[()]
elif extension in image_extensions:
from ..image import to_bob
......@@ -235,8 +239,10 @@ def save(array, filename, create_directories=False):
if create_directories:
create_directories_safe(os.path.dirname(filename))
# requires data is c-contiguous and aligned, will create a copy otherwise
array = np.require(array, requirements=("C_CONTIGUOUS", "ALIGNED"))
# if array is a string, don't create a numpy array
if not isinstance(array, str):
# requires data is c-contiguous and aligned, will create a copy otherwise
array = np.require(array, requirements=("C_CONTIGUOUS", "ALIGNED"))
write_file(filename, array)
......
......@@ -7,29 +7,24 @@
"""Tests for the base HDF5 infrastructure
"""
import os
import random
import tempfile
import numpy as np
from bob.io.base import load, save
from ..test_utils import temporary_filename
def read_write_check(data):
def read_write_check(data, numpy_assert=True):
"""Testing loading and save different file types"""
tmpname = temporary_filename()
try:
save(data, tmpname)
data2 = load(tmpname)
finally:
os.unlink(tmpname)
assert np.allclose(data, data2, atol=10e-5, rtol=10e-5)
with tempfile.NamedTemporaryFile(prefix="bobtest_", suffix=".hdf5") as f:
save(data, f.name)
data2 = load(f.name)
if numpy_assert:
assert np.allclose(data, data2, atol=10e-5, rtol=10e-5)
else:
assert data == data2
def test_type_support():
......@@ -54,3 +49,8 @@ def test_type_support():
read_write_check(np.array(data, np.float64))
read_write_check(np.array(data, np.complex64))
read_write_check(np.array(data, np.complex128))
def test_scalar_support():
for oracle in (1, 1.0, 1j, "a", True):
read_write_check(oracle, numpy_assert=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment