Skip to content
Snippets Groups Projects
Commit 4ae951f1 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.shenzhen] Make configuration variable configurable

parent a0c5ae9e
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
Pipeline #76715 failed
......@@ -55,5 +55,5 @@ class DataModule(CachingDataModule):
def __init__(self, split_filename: str):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
raw_data_loader=RawDataLoader(config_variable="datadir.indian"),
)
......@@ -34,9 +34,9 @@ class RawDataLoader(_BaseRawDataLoader):
datadir: str
def __init__(self):
def __init__(self, config_variable: str = "datadir.shenzhen"):
self.datadir = load_rc().get(
"datadir.shenzhen", os.path.realpath(os.curdir)
config_variable, os.path.realpath(os.curdir)
)
def sample(self, sample: tuple[str, int]) -> Sample:
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Indian dataset."""
import pytest
"""Tests for Indian (a.k.a.
database A/database B) dataset.
"""
@pytest.mark.skip(reason="Test need to be updated")
def test_protocol_consistency():
from ptbench.data.indian import dataset
import pytest
import torch
# Default protocol
subset = dataset.subsets("default")
assert len(subset) == 3
from ptbench.data.indian.datamodule import make_split
assert "train" in subset
assert len(subset["train"]) == 83
for s in subset["train"]:
assert s.key.startswith("DatasetA/Training/")
assert "validation" in subset
assert len(subset["validation"]) == 20
for s in subset["validation"]:
assert s.key.startswith("DatasetA/Training/")
def _check_split(
split_filename: str,
lengths: dict[str, int],
prefix: str = "Dataset",
possible_labels: list[int] = [0, 1],
):
"""Runs a simple consistence check on the data split.
assert "test" in subset
assert len(subset["test"]) == 52
for s in subset["test"]:
assert s.key.startswith("DatasetA/Testing/")
Parameters
----------
# Check labels
for s in subset["train"]:
assert s.label in [0.0, 1.0]
split_filename
This is the split we will check
for s in subset["validation"]:
assert s.label in [0.0, 1.0]
lengths
A dictionary that contains keys matching those of the split (this will
be checked). The values of the dictionary should correspond to the
sizes of each of the datasets in the split.
for s in subset["test"]:
assert s.label in [0.0, 1.0]
prefix
Each file named in a split should start with this prefix.
# Cross-validation fold 0-4
for f in range(5):
subset = dataset.subsets("fold_" + str(f))
assert len(subset) == 3
possible_labels
This is the list of possible labels contained in any split.
"""
assert "train" in subset
assert len(subset["train"]) == 111
for s in subset["train"]:
assert s.key.startswith("DatasetA")
split = make_split(split_filename)
assert "validation" in subset
assert len(subset["validation"]) == 28
for s in subset["validation"]:
assert s.key.startswith("DatasetA")
assert len(split) == len(lengths)
assert "test" in subset
assert len(subset["test"]) == 16
for s in subset["test"]:
assert s.key.startswith("DatasetA")
for k in lengths.keys():
# dataset must have been declared
assert k in split
# Check labels
for s in subset["train"]:
assert s.label in [0.0, 1.0]
assert len(split[k]) == lengths[k]
for s in split[k]:
assert s[0].startswith(prefix)
assert s[1] in possible_labels
for s in subset["validation"]:
assert s.label in [0.0, 1.0]
for s in subset["test"]:
assert s.label in [0.0, 1.0]
def _check_loaded_batch(
batch,
size: int = 1,
prefix: str = "Dataset",
possible_labels: list[int] = [0, 1],
):
"""Checks the consistence of an individual (loaded) batch.
# Cross-validation fold 5-9
for f in range(5, 10):
subset = dataset.subsets("fold_" + str(f))
assert len(subset) == 3
Parameters
----------
assert "train" in subset
assert len(subset["train"]) == 112
for s in subset["train"]:
assert s.key.startswith("DatasetA")
batch
The loaded batch to be checked.
assert "validation" in subset
assert len(subset["validation"]) == 28
for s in subset["validation"]:
assert s.key.startswith("DatasetA")
prefix
Each file named in a split should start with this prefix.
assert "test" in subset
assert len(subset["test"]) == 15
for s in subset["test"]:
assert s.key.startswith("DatasetA")
possible_labels
This is the list of possible labels contained in any split.
"""
# Check labels
for s in subset["train"]:
assert s.label in [0.0, 1.0]
assert len(batch) == 2 # data, metadata
for s in subset["validation"]:
assert s.label in [0.0, 1.0]
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == size # mini-batch size
assert batch[0].shape[1] == 1 # grayscale images
assert batch[0].shape[2] == batch[0].shape[3] # image is square
for s in subset["test"]:
assert s.label in [0.0, 1.0]
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name
assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]])
@pytest.mark.skip(reason="Test need to be updated")
@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
def test_loading():
from ptbench.data.indian import dataset
assert "name" in batch[1]
assert all([k.startswith(prefix) for k in batch[1]["name"]])
def _check_size(size):
if (
size[0] >= 1024
and size[0] <= 2320
and size[1] >= 1024
and size[1] <= 2828
):
return True
return False
def _check_sample(s):
    """Validate the payload carried by one loaded sample.

    The ``data`` attribute of the sample must be a dictionary with exactly
    two entries: an image-like object under ``"data"`` and a binary label
    under ``"label"``.
    """
    payload = s.data
    assert isinstance(payload, dict)
    assert len(payload) == 2

    assert "data" in payload
    assert _check_size(payload["data"].size)  # dimensions within bounds
    # mode "L": presumably a PIL grayscale image -- confirm against the loader
    assert payload["data"].mode == "L"

    assert "label" in payload
    assert payload["label"] in [0, 1]  # only binary labels are accepted
def test_protocol_consistency():
_check_split(
"default.json",
lengths=dict(train=83, validation=20, test=52),
)
limit = 30 # use this to limit testing to first images only, else None
# Cross-validation fold 0-4
for k in range(5):
_check_split(
f"fold-{k}.json",
lengths=dict(train=111, validation=28, test=16),
)
subset = dataset.subsets("default")
for s in subset["train"][:limit]:
_check_sample(s)
# Cross-validation fold 5-9
for k in range(5, 10):
_check_split(
f"fold-{k}.json",
lengths=dict(train=112, validation=28, test=15),
)
@pytest.mark.skip(reason="Test need to be updated")
@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
def test_check():
from ptbench.data.indian import dataset
assert dataset.check() == 0
def test_loading():
    """Check the first few batches of every prediction data loader."""
    from ptbench.data.indian.default import datamodule

    max_batches = 5  # cap on the number of batches checked per loader

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    for loader in datamodule.predict_dataloader().values():
        for count, batch in enumerate(loader):
            if count == max_batches:
                break
            _check_loaded_batch(batch)
......@@ -3,202 +3,122 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Shenzhen dataset."""
import importlib
import pytest
import torch
from ptbench.data.shenzhen.datamodule import make_split
def test_protocol_consistency():
# Default protocol
datamodule = getattr(
importlib.import_module("ptbench.data.shenzhen.datamodules"), "default"
)
subset = datamodule.splits
assert len(subset) == 3
assert "train" in subset
train_samples = subset["train"][0][0]
assert len(train_samples) == 422
for s in train_samples:
assert s[0].startswith("CXR_png/CHNCXR_0")
assert "validation" in subset
validation_samples = subset["validation"][0][0]
assert len(validation_samples) == 107
for s in validation_samples:
assert s[0].startswith("CXR_png/CHNCXR_0")
assert "test" in subset
test_samples = subset["test"][0][0]
assert len(test_samples) == 133
for s in test_samples:
assert s[0].startswith("CXR_png/CHNCXR_0")
# Check labels
for s in train_samples:
assert s[1] in [0.0, 1.0]
def _check_split(
split_filename: str,
lengths: dict[str, int],
prefix: str = "CXR_png/CHNCXR_0",
possible_labels: list[int] = [0, 1],
):
"""Runs a simple consistence check on the data split.
for s in validation_samples:
assert s[1] in [0.0, 1.0]
Parameters
----------
for s in test_samples:
assert s[1] in [0.0, 1.0]
split_filename
This is the split we will check
# Cross-validation folds 0-1
for f in range(2):
datamodule = getattr(
importlib.import_module("ptbench.data.shenzhen.datamodules"),
f"fold_{str(f)}",
)
subset = datamodule.splits
assert len(subset) == 3
lengths
A dictionary that contains keys matching those of the split (this will
be checked). The values of the dictionary should correspond to the
sizes of each of the datasets in the split.
assert "train" in subset
train_samples = subset["train"][0][0]
assert len(train_samples) == 476
for s in train_samples:
assert s[0].startswith("CXR_png/CHNCXR_0")
prefix
Each file named in a split should start with this prefix.
assert "validation" in subset
validation_samples = subset["validation"][0][0]
assert len(validation_samples) == 119
for s in validation_samples:
assert s[0].startswith("CXR_png/CHNCXR_0")
possible_labels
This is the list of possible labels contained in any split.
"""
assert "test" in subset
test_samples = subset["test"][0][0]
assert len(test_samples) == 67
for s in test_samples:
assert s[0].startswith("CXR_png/CHNCXR_0")
split = make_split(split_filename)
# Check labels
for s in train_samples:
assert s[1] in [0.0, 1.0]
assert len(split) == len(lengths)
for s in validation_samples:
assert s[1] in [0.0, 1.0]
for k in lengths.keys():
# dataset must have been declared
assert k in split
for s in test_samples:
assert s[1] in [0.0, 1.0]
# Cross-validation folds 2-9
for f in range(2, 10):
datamodule = getattr(
importlib.import_module("ptbench.data.shenzhen.datamodules"),
f"fold_{str(f)}",
)
assert len(split[k]) == lengths[k]
for s in split[k]:
assert s[0].startswith(prefix)
assert s[1] in possible_labels
subset = datamodule.splits
assert len(subset) == 3
def _check_loaded_batch(
batch,
size: int = 1,
prefix: str = "CXR_png/CHNCXR_0",
possible_labels: list[int] = [0, 1],
):
"""Checks the consistence of an individual (loaded) batch.
assert "train" in subset
train_samples = subset["train"][0][0]
assert len(train_samples) == 476
for s in train_samples:
assert s[0].startswith("CXR_png/CHNCXR_0")
Parameters
----------
assert "validation" in subset
validation_samples = subset["validation"][0][0]
assert len(validation_samples) == 120
for s in validation_samples:
assert s[0].startswith("CXR_png/CHNCXR_0")
batch
The loaded batch to be checked.
assert "test" in subset
test_samples = subset["test"][0][0]
assert len(test_samples) == 66
for s in test_samples:
assert s[0].startswith("CXR_png/CHNCXR_0")
prefix
Each file named in a split should start with this prefix.
# Check labels
for s in train_samples:
assert s[1] in [0.0, 1.0]
possible_labels
This is the list of possible labels contained in any split.
"""
for s in validation_samples:
assert s[1] in [0.0, 1.0]
assert len(batch) == 2 # data, metadata
for s in test_samples:
assert s[1] in [0.0, 1.0]
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_loading():
import torch
import torchvision.transforms
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == size # mini-batch size
assert batch[0].shape[1] == 1 # grayscale images
assert batch[0].shape[2] == batch[0].shape[3] # image is square
from ptbench.data.datamodule import _DelayedLoadingDataset
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name
def _check_sample(s):
assert len(s) == 2
assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]])
data = s[0]
metadata = s[1]
assert "name" in batch[1]
assert all([k.startswith(prefix) for k in batch[1]["name"]])
assert isinstance(data, torch.Tensor)
assert data.size(0) == 1 # check 1 channel
assert data.size(1) == data.size(2) # check square image
assert (
torchvision.transforms.ToPILImage()(data).mode == "L"
) # Check colors
assert "label" in metadata
assert metadata["label"] in [0, 1] # Check labels
limit = 30 # use this to limit testing to first images only, else None
module = importlib.import_module("ptbench.data.shenzhen.datamodules")
datamodule = getattr(module, "default")
raw_data_loader = module.RawDataLoader()
subset = datamodule.splits
# Need to use private function so we can limit the number of samples to use
dataset = _DelayedLoadingDataset(
subset["train"][0][0][:limit],
raw_data_loader,
def test_protocol_consistency():
_check_split(
"default.json",
lengths=dict(train=422, validation=107, test=133),
)
for s in dataset:
_check_sample(s)
@pytest.mark.skip(reason="Test need to be updated")
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_check():
from ptbench.data.split import check_database_split_loading
limit = 30 # use this to limit testing to first images only, else 0
# Default protocol
module = importlib.import_module("ptbench.data.shenzhen.datamodules")
datamodule = getattr(module, "default")
database_split = datamodule.splits
raw_data_loader = module.RawDataLoader()
assert (
check_database_split_loading(
database_split, raw_data_loader, limit=limit
# Cross-validation fold 0-1
for k in range(2):
_check_split(
f"fold-{k}.json",
lengths=dict(train=476, validation=119, test=67),
)
== 0
)
# Folds
for f in range(10):
module = importlib.import_module("ptbench.data.shenzhen.datamodules")
datamodule = getattr(module, f"fold_{f}")
# Cross-validation fold 2-9
for k in range(2, 10):
_check_split(
f"fold-{k}.json",
lengths=dict(train=476, validation=120, test=66),
)
database_split = datamodule.splits
raw_data_loader = module.RawDataLoader()
assert (
check_database_split_loading(
database_split, raw_data_loader, limit=limit
)
== 0
)
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_loading():
    """Check the first few batches of every prediction data loader."""
    from ptbench.data.shenzhen.default import datamodule

    max_batches = 5  # cap on the number of batches checked per loader

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    for loader in datamodule.predict_dataloader().values():
        for count, batch in enumerate(loader):
            if count == max_batches:
                break
            _check_loaded_batch(batch)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment