Skip to content
Snippets Groups Projects
Commit b82e7485 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Fixed access to splits subsets in tests

parent 947d5ac2
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -15,7 +15,8 @@ def test_protocol_consistency(): ...@@ -15,7 +15,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module( datamodule = importlib.import_module(
"ptbench.data.shenzhen.default" "ptbench.data.shenzhen.default"
).datamodule ).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3 assert len(subset) == 3
...@@ -49,7 +50,8 @@ def test_protocol_consistency(): ...@@ -49,7 +50,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module( datamodule = importlib.import_module(
f"ptbench.data.shenzhen.fold_{str(f)}" f"ptbench.data.shenzhen.fold_{str(f)}"
).datamodule ).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3 assert len(subset) == 3
...@@ -83,7 +85,8 @@ def test_protocol_consistency(): ...@@ -83,7 +85,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module( datamodule = importlib.import_module(
f"ptbench.data.shenzhen.fold_{str(f)}" f"ptbench.data.shenzhen.fold_{str(f)}"
).datamodule ).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3 assert len(subset) == 3
...@@ -128,11 +131,11 @@ def test_loading(): ...@@ -128,11 +131,11 @@ def test_loading():
assert isinstance(data, torch.Tensor) assert isinstance(data, torch.Tensor)
assert data.size(0) == 3 # check 3 channels assert data.size(0) == 1 # check 1 channel
assert data.size(1) == data.size(2) # check square image assert data.size(1) == data.size(2) # check square image
assert ( assert (
torchvision.transforms.ToPILImage()(data).mode == "RGB" torchvision.transforms.ToPILImage()(data).mode == "L"
) # Check colors ) # Check colors
assert "label" in metadata assert "label" in metadata
...@@ -143,7 +146,7 @@ def test_loading(): ...@@ -143,7 +146,7 @@ def test_loading():
datamodule = importlib.import_module( datamodule = importlib.import_module(
"ptbench.data.shenzhen.default" "ptbench.data.shenzhen.default"
).datamodule ).datamodule
subset = datamodule.database_split.subsets subset = datamodule.database_split
raw_data_loader = datamodule.raw_data_loader raw_data_loader = datamodule.raw_data_loader
# Need to use private function so we can limit the number of samples to use # Need to use private function so we can limit the number of samples to use
...@@ -171,7 +174,7 @@ def test_check(): ...@@ -171,7 +174,7 @@ def test_check():
assert ( assert (
check_database_split_loading( check_database_split_loading(
database_split.subsets, raw_data_loader, limit=limit database_split, raw_data_loader, limit=limit
) )
== 0 == 0
) )
...@@ -186,7 +189,7 @@ def test_check(): ...@@ -186,7 +189,7 @@ def test_check():
assert ( assert (
check_database_split_loading( check_database_split_loading(
database_split.subsets, raw_data_loader, limit=limit database_split, raw_data_loader, limit=limit
) )
== 0 == 0
) )
...@@ -14,7 +14,8 @@ def test_protocol_consistency(): ...@@ -14,7 +14,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module( datamodule = importlib.import_module(
"ptbench.data.montgomery.default" "ptbench.data.montgomery.default"
).datamodule ).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3 assert len(subset) == 3
...@@ -48,7 +49,7 @@ def test_protocol_consistency(): ...@@ -48,7 +49,7 @@ def test_protocol_consistency():
datamodule = importlib.import_module( datamodule = importlib.import_module(
f"ptbench.data.montgomery.fold_{str(f)}" f"ptbench.data.montgomery.fold_{str(f)}"
).datamodule ).datamodule
subset = datamodule.database_split.subsets subset = datamodule.database_split
assert len(subset) == 3 assert len(subset) == 3
...@@ -82,7 +83,7 @@ def test_protocol_consistency(): ...@@ -82,7 +83,7 @@ def test_protocol_consistency():
datamodule = importlib.import_module( datamodule = importlib.import_module(
f"ptbench.data.montgomery.fold_{str(f)}" f"ptbench.data.montgomery.fold_{str(f)}"
).datamodule ).datamodule
subset = datamodule.database_split.subsets subset = datamodule.database_split
assert len(subset) == 3 assert len(subset) == 3
...@@ -120,6 +121,8 @@ def test_loading(): ...@@ -120,6 +121,8 @@ def test_loading():
from ptbench.data.datamodule import _DelayedLoadingDataset from ptbench.data.datamodule import _DelayedLoadingDataset
def _check_sample(s): def _check_sample(s):
assert len(s) == 2
data = s[0] data = s[0]
metadata = s[1] metadata = s[1]
...@@ -140,7 +143,7 @@ def test_loading(): ...@@ -140,7 +143,7 @@ def test_loading():
datamodule = importlib.import_module( datamodule = importlib.import_module(
"ptbench.data.montgomery.default" "ptbench.data.montgomery.default"
).datamodule ).datamodule
subset = datamodule.database_split.subsets subset = datamodule.database_split
raw_data_loader = datamodule.raw_data_loader raw_data_loader = datamodule.raw_data_loader
# Need to use private function so we can limit the number of samples to use # Need to use private function so we can limit the number of samples to use
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment