diff --git a/tests/test_ch.py b/tests/test_ch.py index c678e0870ed6b69abbbb75ecd2f0185168020864..787fa95e49dab86692560cd987ff513b4f57c4c0 100644 --- a/tests/test_ch.py +++ b/tests/test_ch.py @@ -15,7 +15,8 @@ def test_protocol_consistency(): datamodule = importlib.import_module( "ptbench.data.shenzhen.default" ).datamodule - subset = datamodule.database_split.subsets + + subset = datamodule.database_split assert len(subset) == 3 @@ -49,7 +50,8 @@ def test_protocol_consistency(): datamodule = importlib.import_module( f"ptbench.data.shenzhen.fold_{str(f)}" ).datamodule - subset = datamodule.database_split.subsets + + subset = datamodule.database_split assert len(subset) == 3 @@ -83,7 +85,8 @@ def test_protocol_consistency(): datamodule = importlib.import_module( f"ptbench.data.shenzhen.fold_{str(f)}" ).datamodule - subset = datamodule.database_split.subsets + + subset = datamodule.database_split assert len(subset) == 3 @@ -128,11 +131,11 @@ def test_loading(): assert isinstance(data, torch.Tensor) - assert data.size(0) == 3 # check 3 channels + assert data.size(0) == 1 # check 1 channel assert data.size(1) == data.size(2) # check square image assert ( - torchvision.transforms.ToPILImage()(data).mode == "RGB" + torchvision.transforms.ToPILImage()(data).mode == "L" ) # Check colors assert "label" in metadata @@ -143,7 +146,7 @@ def test_loading(): datamodule = importlib.import_module( "ptbench.data.shenzhen.default" ).datamodule - subset = datamodule.database_split.subsets + subset = datamodule.database_split raw_data_loader = datamodule.raw_data_loader # Need to use private function so we can limit the number of samples to use @@ -171,7 +174,7 @@ def test_check(): assert ( check_database_split_loading( - database_split.subsets, raw_data_loader, limit=limit + database_split, raw_data_loader, limit=limit ) == 0 ) @@ -186,7 +189,7 @@ def test_check(): assert ( check_database_split_loading( - database_split.subsets, raw_data_loader, limit=limit + database_split, raw_data_loader, limit=limit ) == 0 ) diff --git a/tests/test_mc.py b/tests/test_mc.py index 25bd47095584469e86d41adac23926e52ecadbb4..3cc6adb0de48e74f9b8cf9cff7c140c4c6c7a32a 100644 --- a/tests/test_mc.py +++ b/tests/test_mc.py @@ -14,7 +14,8 @@ def test_protocol_consistency(): datamodule = importlib.import_module( "ptbench.data.montgomery.default" ).datamodule - subset = datamodule.database_split.subsets + + subset = datamodule.database_split assert len(subset) == 3 @@ -48,7 +49,7 @@ def test_protocol_consistency(): datamodule = importlib.import_module( f"ptbench.data.montgomery.fold_{str(f)}" ).datamodule - subset = datamodule.database_split.subsets + subset = datamodule.database_split assert len(subset) == 3 @@ -82,7 +83,7 @@ def test_protocol_consistency(): datamodule = importlib.import_module( f"ptbench.data.montgomery.fold_{str(f)}" ).datamodule - subset = datamodule.database_split.subsets + subset = datamodule.database_split assert len(subset) == 3 @@ -120,6 +121,8 @@ def test_loading(): from ptbench.data.datamodule import _DelayedLoadingDataset def _check_sample(s): + assert len(s) == 2 + data = s[0] metadata = s[1] @@ -140,7 +143,7 @@ def test_loading(): datamodule = importlib.import_module( "ptbench.data.montgomery.default" ).datamodule - subset = datamodule.database_split.subsets + subset = datamodule.database_split raw_data_loader = datamodule.raw_data_loader # Need to use private function so we can limit the number of samples to use