From 94d25ec88d7a28ddbaff36ec7aeaedc30d6c63a8 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 26 Jul 2023 09:58:07 +0200 Subject: [PATCH] Fixed access to splits subsets in tests --- tests/test_ch.py | 19 +++++++++++-------- tests/test_mc.py | 11 +++++++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/test_ch.py b/tests/test_ch.py index c678e087..787fa95e 100644 --- a/tests/test_ch.py +++ b/tests/test_ch.py @@ -15,7 +15,8 @@ def test_protocol_consistency(): datamodule = importlib.import_module( "ptbench.data.shenzhen.default" ).datamodule - subset = datamodule.database_split.subsets + + subset = datamodule.database_split assert len(subset) == 3 @@ -49,7 +50,8 @@ def test_protocol_consistency(): datamodule = importlib.import_module( f"ptbench.data.shenzhen.fold_{str(f)}" ).datamodule - subset = datamodule.database_split.subsets + + subset = datamodule.database_split assert len(subset) == 3 @@ -83,7 +85,8 @@ def test_protocol_consistency(): datamodule = importlib.import_module( f"ptbench.data.shenzhen.fold_{str(f)}" ).datamodule - subset = datamodule.database_split.subsets + + subset = datamodule.database_split assert len(subset) == 3 @@ -128,11 +131,11 @@ def test_loading(): assert isinstance(data, torch.Tensor) - assert data.size(0) == 3 # check 3 channels + assert data.size(0) == 1 # check 1 channel assert data.size(1) == data.size(2) # check square image assert ( - torchvision.transforms.ToPILImage()(data).mode == "RGB" + torchvision.transforms.ToPILImage()(data).mode == "L" ) # Check colors assert "label" in metadata @@ -143,7 +146,7 @@ def test_loading(): datamodule = importlib.import_module( "ptbench.data.shenzhen.default" ).datamodule - subset = datamodule.database_split.subsets + subset = datamodule.database_split raw_data_loader = datamodule.raw_data_loader # Need to use private function so we can limit the number of samples to use @@ -171,7 +174,7 @@ def test_check(): assert ( check_database_split_loading( - database_split.subsets, raw_data_loader, limit=limit + database_split, raw_data_loader, limit=limit ) == 0 ) @@ -186,7 +189,7 @@ def test_check(): assert ( check_database_split_loading( - database_split.subsets, raw_data_loader, limit=limit + database_split, raw_data_loader, limit=limit ) == 0 ) diff --git a/tests/test_mc.py b/tests/test_mc.py index 25bd4709..3cc6adb0 100644 --- a/tests/test_mc.py +++ b/tests/test_mc.py @@ -14,7 +14,8 @@ def test_protocol_consistency(): datamodule = importlib.import_module( "ptbench.data.montgomery.default" ).datamodule - subset = datamodule.database_split.subsets + + subset = datamodule.database_split assert len(subset) == 3 @@ -48,7 +49,7 @@ def test_protocol_consistency(): datamodule = importlib.import_module( f"ptbench.data.montgomery.fold_{str(f)}" ).datamodule - subset = datamodule.database_split.subsets + subset = datamodule.database_split assert len(subset) == 3 @@ -82,7 +83,7 @@ def test_protocol_consistency(): datamodule = importlib.import_module( f"ptbench.data.montgomery.fold_{str(f)}" ).datamodule - subset = datamodule.database_split.subsets + subset = datamodule.database_split assert len(subset) == 3 @@ -120,6 +121,8 @@ def test_loading(): from ptbench.data.datamodule import _DelayedLoadingDataset def _check_sample(s): + assert len(s) == 2 + data = s[0] metadata = s[1] @@ -140,7 +143,7 @@ def test_loading(): datamodule = importlib.import_module( "ptbench.data.montgomery.default" ).datamodule - subset = datamodule.database_split.subsets + subset = datamodule.database_split raw_data_loader = datamodule.raw_data_loader # Need to use private function so we can limit the number of samples to use -- GitLab