From ccc390cc1affdb4d85bf3e4f4a62385e7fca7a81 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 18 Jul 2023 15:58:50 +0200 Subject: [PATCH] Fixed issue with check_database_split_loading --- src/ptbench/data/split.py | 4 ++-- tests/test_ch.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py index 78fe7e33..606e40cb 100644 --- a/src/ptbench/data/split.py +++ b/src/ptbench/data/split.py @@ -238,8 +238,8 @@ def check_database_split_loading( "Checking if can load all samples in all subsets of this split..." ) errors = 0 - for subset in database_split.keys(): - samples = subset if not limit else subset[:limit] + for subset, samples in database_split.items(): + samples = samples if not limit else samples[:limit] for pos, sample in enumerate(samples): try: data, _ = loader.sample(sample) diff --git a/tests/test_ch.py b/tests/test_ch.py index 510b1171..659e2c35 100644 --- a/tests/test_ch.py +++ b/tests/test_ch.py @@ -132,8 +132,6 @@ def test_loading(): metadata = s[1] assert isinstance(data, torch.Tensor) - - print(data.shape) assert _check_size(data.shape) # Check size assert ( @@ -176,7 +174,7 @@ def test_check(): assert ( check_database_split_loading( - database_split, raw_data_loader, limit=limit + database_split.subsets, raw_data_loader, limit=limit ) == 0 ) @@ -191,7 +189,7 @@ def test_check(): assert ( check_database_split_loading( - database_split, raw_data_loader, limit=limit + database_split.subsets, raw_data_loader, limit=limit ) == 0 ) -- GitLab