Skip to content
Snippets Groups Projects

Making use of LightningDataModule and simplification of data loading

Merged Daniel CARRON requested to merge add-datamodule into main
Compare and Show latest version
2 files
+ 18
12
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 11
8
@@ -15,7 +15,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module(
"ptbench.data.shenzhen.default"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3
@@ -49,7 +50,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module(
f"ptbench.data.shenzhen.fold_{str(f)}"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3
@@ -83,7 +85,8 @@ def test_protocol_consistency():
datamodule = importlib.import_module(
f"ptbench.data.shenzhen.fold_{str(f)}"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
assert len(subset) == 3
@@ -128,11 +131,11 @@ def test_loading():
assert isinstance(data, torch.Tensor)
assert data.size(0) == 3 # check 3 channels
assert data.size(0) == 1 # check 1 channel
assert data.size(1) == data.size(2) # check square image
assert (
torchvision.transforms.ToPILImage()(data).mode == "RGB"
torchvision.transforms.ToPILImage()(data).mode == "L"
) # Check colors
assert "label" in metadata
@@ -143,7 +146,7 @@ def test_loading():
datamodule = importlib.import_module(
"ptbench.data.shenzhen.default"
).datamodule
subset = datamodule.database_split.subsets
subset = datamodule.database_split
raw_data_loader = datamodule.raw_data_loader
# Need to use private function so we can limit the number of samples to use
@@ -171,7 +174,7 @@ def test_check():
assert (
check_database_split_loading(
database_split.subsets, raw_data_loader, limit=limit
database_split, raw_data_loader, limit=limit
)
== 0
)
@@ -186,7 +189,7 @@ def test_check():
assert (
check_database_split_loading(
database_split.subsets, raw_data_loader, limit=limit
database_split, raw_data_loader, limit=limit
)
== 0
)
Loading