Skip to content
Snippets Groups Projects
Commit b280ec3a authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Make naming more explicit

parent 900f69d4
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
File moved
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Montgomery dataset."""
import importlib
import pytest
@pytest.mark.skip(reason="Test need to be updated")
def test_protocol_consistency():
    """Check that every Montgomery split exposes the expected subsets,
    subset sizes, sample path prefixes and binary labels.

    Covers the default protocol and the ten cross-validation folds.
    """

    def _check_split(split, train_size, validation_size, test_size):
        # One helper replaces the three near-identical blocks that were
        # previously copy-pasted for the default protocol and each fold group.
        assert len(split) == 3
        for name, size in (
            ("train", train_size),
            ("validation", validation_size),
            ("test", test_size),
        ):
            assert name in split
            assert len(split[name]) == size
            for s in split[name]:
                # All Montgomery sample paths share this prefix.
                assert s[0].startswith("CXR_png/MCUCXR_0")
                # Labels are binary floats.  NOTE(review): exact label
                # semantics (0 vs 1) assumed — confirm against dataset docs.
                assert s[1] in [0.0, 1.0]

    # Default protocol.
    datamodule = importlib.import_module(
        "ptbench.data.montgomery.datamodules.default"
    ).datamodule
    # FIX: the original read ``datamodule.splits`` here while every other
    # access in this file (including the fold branches below) reads
    # ``database_split`` — made consistent with the rest of the file.
    _check_split(datamodule.database_split, 88, 22, 28)

    # Cross-validation folds 0-7.
    for f in range(8):
        datamodule = importlib.import_module(
            f"ptbench.data.montgomery.datamodules.fold_{f}"
        ).datamodule
        _check_split(datamodule.database_split, 99, 25, 14)

    # Cross-validation folds 8-9 (one extra training sample, one fewer test).
    for f in range(8, 10):
        datamodule = importlib.import_module(
            f"ptbench.data.montgomery.datamodules.fold_{f}"
        ).datamodule
        _check_split(datamodule.database_split, 100, 25, 13)
@pytest.mark.skip(reason="Test need to be updated")
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loading():
    """Load the first few training samples of the default protocol and
    validate each one (tensor shape, grayscale mode, binary label)."""
    import torch
    import torchvision.transforms

    from ptbench.data.datamodule import _DelayedLoadingDataset

    def _check_sample(sample):
        # A sample is a (data-tensor, metadata-dict) pair.
        assert len(sample) == 2
        tensor, metadata = sample
        assert isinstance(tensor, torch.Tensor)
        assert tensor.size(0) == 1  # check single channel
        assert tensor.size(1) == tensor.size(2)  # check square image
        # Check colors: converting back to PIL must yield grayscale mode.
        assert torchvision.transforms.ToPILImage()(tensor).mode == "L"
        assert "label" in metadata
        assert metadata["label"] in [0, 1]  # Check labels

    limit = 30  # use this to limit testing to first images only, else None

    datamodule = importlib.import_module(
        "ptbench.data.montgomery.datamodules.default"
    ).datamodule

    # Need to use private function so we can limit the number of samples to use
    dataset = _DelayedLoadingDataset(
        datamodule.database_split["train"][:limit],
        datamodule.raw_data_loader,
    )
    for sample in dataset:
        _check_sample(sample)
@pytest.mark.skip(reason="Test need to be updated")
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_check():
    """Verify that every protocol's samples load without errors.

    Runs the loading check on the default protocol first, then on all ten
    cross-validation folds, expecting zero reported errors each time.
    """
    from ptbench.data.split import check_database_split_loading

    limit = 30  # use this to limit testing to first images only, else 0

    # Default protocol followed by folds 0-9, in the original order.
    module_names = ["ptbench.data.montgomery.datamodules.default"] + [
        f"ptbench.data.montgomery.datamodules.fold_{f}" for f in range(10)
    ]
    for module_name in module_names:
        datamodule = importlib.import_module(module_name).datamodule
        errors = check_database_split_loading(
            datamodule.database_split,
            datamodule.raw_data_loader,
            limit=limit,
        )
        assert errors == 0
File moved
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment