Skip to content
Snippets Groups Projects
Commit 94d25ec8 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Fixed access to splits subsets in tests

parent 6151cb51
No related branches found
No related tags found
No related merge requests found
Pipeline #76484 failed
......@@ -15,7 +15,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module(
"ptbench.data.shenzhen.default"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3
......@@ -49,7 +50,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module(
f"ptbench.data.shenzhen.fold_{str(f)}"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3
......@@ -83,7 +85,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module(
f"ptbench.data.shenzhen.fold_{str(f)}"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3
......@@ -128,11 +131,11 @@ def test_loading():
assert isinstance(data, torch.Tensor)
assert data.size(0) == 3 # check 3 channels
assert data.size(0) == 1 # check 1 channel
assert data.size(1) == data.size(2) # check square image
assert (
torchvision.transforms.ToPILImage()(data).mode == "RGB"
torchvision.transforms.ToPILImage()(data).mode == "L"
) # Check colors
assert "label" in metadata
......@@ -143,7 +146,7 @@ def test_loading():
datamodule = importlib.import_module(
"ptbench.data.shenzhen.default"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
raw_data_loader = datamodule.raw_data_loader
# Need to use private function so we can limit the number of samples to use
......@@ -171,7 +174,7 @@ def test_check():
assert (
check_database_split_loading(
database_split.subsets, raw_data_loader, limit=limit
database_split, raw_data_loader, limit=limit
)
== 0
)
......@@ -186,7 +189,7 @@ def test_check():
assert (
check_database_split_loading(
database_split.subsets, raw_data_loader, limit=limit
database_split, raw_data_loader, limit=limit
)
== 0
)
......@@ -14,7 +14,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module(
"ptbench.data.montgomery.default"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3
......@@ -48,7 +49,7 @@ def test_protocol_consistency():
datamodule = importlib.import_module(
f"ptbench.data.montgomery.fold_{str(f)}"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3
......@@ -82,7 +83,7 @@ def test_protocol_consistency():
datamodule = importlib.import_module(
f"ptbench.data.montgomery.fold_{str(f)}"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3
......@@ -120,6 +121,8 @@ def test_loading():
from ptbench.data.datamodule import _DelayedLoadingDataset
def _check_sample(s):
assert len(s) == 2
data = s[0]
metadata = s[1]
......@@ -140,7 +143,7 @@ def test_loading():
datamodule = importlib.import_module(
"ptbench.data.montgomery.default"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
raw_data_loader = datamodule.raw_data_loader
# Need to use private function so we can limit the number of samples to use
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment