diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py index 78fe7e33886cd47265d0a7f217bcb55f799c9fa4..606e40cbf5ca8e9fb16b8c7ae3c894a440042b0a 100644 --- a/src/ptbench/data/split.py +++ b/src/ptbench/data/split.py @@ -238,8 +238,8 @@ def check_database_split_loading( "Checking if can load all samples in all subsets of this split..." ) errors = 0 - for subset in database_split.keys(): - samples = subset if not limit else subset[:limit] + for subset, samples in database_split.items(): + samples = samples if not limit else samples[:limit] for pos, sample in enumerate(samples): try: data, _ = loader.sample(sample) diff --git a/tests/test_ch.py b/tests/test_ch.py index 510b1171741a9a4ea0176976a7aa45dddaa427fe..659e2c35ae092f3a90f7d072ba033786bb80bdf9 100644 --- a/tests/test_ch.py +++ b/tests/test_ch.py @@ -132,8 +132,6 @@ def test_loading(): metadata = s[1] assert isinstance(data, torch.Tensor) - - print(data.shape) assert _check_size(data.shape) # Check size assert ( @@ -176,7 +174,7 @@ def test_check(): assert ( check_database_split_loading( - database_split, raw_data_loader, limit=limit + database_split.subsets, raw_data_loader, limit=limit ) == 0 ) @@ -191,7 +189,7 @@ def test_check(): assert ( check_database_split_loading( - database_split, raw_data_loader, limit=limit + database_split.subsets, raw_data_loader, limit=limit ) == 0 )