Skip to content
Snippets Groups Projects
Commit 513bed9d authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests.test_nih_cxr14] Adapt to new datamodule syntax

parent 1a6f015c
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
Pipeline #76737 failed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for NIH CXR14 dataset."""
"""Tests for NIH CXR-14 dataset."""
import pytest
import importlib


@pytest.mark.skip(reason="Test need to be updated")
def test_protocol_consistency():
    """Check the "default" protocol of the legacy ``nih_cxr14_re`` dataset.

    Verifies that the protocol exposes exactly three subsets (``train``,
    ``validation`` and ``test``) with the expected number of samples, that
    every sample key starts with ``images/000``, and that every label value
    is either ``0.0`` or ``1.0``.

    The test is currently skipped unconditionally (see the decorator).
    """
    # Imported lazily so that merely collecting this (skipped) test does not
    # require the legacy module to be importable.
    from ptbench.data.nih_cxr14_re import dataset

    # Default protocol
    subset = dataset.subsets("default")
    assert len(subset) == 3

    # Expected number of samples per subset, in the order they are checked.
    expected_sizes = {"train": 98637, "validation": 6350, "test": 4054}

    for name, size in expected_sizes.items():
        assert name in subset
        assert len(subset[name]) == size
        for s in subset[name]:
            assert s.key.startswith("images/000")

    # Check labels
    for name in expected_sizes:
        for s in subset[name]:
            for element in set(s.label):
                assert element in [0.0, 1.0]
import pytest


@pytest.mark.skip(reason="Test need to be updated")
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
def test_loading():
    """Load the first training samples of the legacy ``nih_cxr14_re`` dataset.

    For each of the first ``limit`` samples of the ``train`` subset of the
    "default" protocol, checks that the loaded data is a two-entry dictionary
    holding a 1024x1024 ``"data"`` entry in ``"RGB"`` mode and a ``"label"``
    entry of length 14.

    The test is currently skipped unconditionally (see the decorators).
    """
    # Imported lazily so that merely collecting this (skipped) test does not
    # require the legacy module to be importable.
    from ptbench.data.nih_cxr14_re import dataset

    def _check_size(size):
        """Return ``True`` when ``size`` is exactly ``(1024, 1024)``."""
        return size == (1024, 1024)

    def _check_sample(s):
        """Assert the structure of the data loaded for sample ``s``."""
        data = s.data
        assert isinstance(data, dict)
        assert len(data) == 2
        assert "data" in data
        # NOTE(review): ``.size``/``.mode`` suggest a PIL image - confirm.
        assert _check_size(data["data"].size)  # Check size
        assert data["data"].mode == "RGB"  # Check colors
        assert "label" in data
        assert len(data["label"]) == 14  # Check labels

    limit = 30  # use this to limit testing to first images only, else None
    subset = dataset.subsets("default")
    for s in subset["train"][:limit]:
        _check_sample(s)
def id_function(val: object) -> str:
    """Build the identifier pytest shows for a parametrized value.

    Dictionaries are rendered with ``str()``; any other value is rendered
    with ``repr()``.

    :param val: The parameter value handed over by pytest.
    :return: The textual identifier for ``val``.
    """
    if isinstance(val, dict):
        return str(val)
    return repr(val)
@pytest.mark.parametrize(
    "split,lenghts",
    [
        ("default", dict(train=98637, validation=6350, test=4054)),
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers, split: str, lenghts: dict[str, int]
):
    """Check a database split through ``database_checkers.check_split``.

    The split is built by ``make_split`` from ``"<split>.json.bz2"`` and is
    checked against the expected subset lengths, the ``images/000`` prefix
    and the possible labels ``0`` and ``1``.

    :param database_checkers: Fixture providing ``check_split``.
    :param split: Base name of the split file (without ``.json.bz2``).
    :param lenghts: Expected number of samples for each subset name.
    """
    # NOTE(review): "lenghts" is a misspelling of "lengths". It is kept
    # because it must match the argument name in the ``parametrize`` string
    # above; rename both together if this is ever cleaned up.
    from ptbench.data.nih_cxr14.datamodule import make_split

    database_checkers.check_split(
        make_split(f"{split}.json.bz2"),
        lengths=lenghts,
        prefixes=("images/000",),
        possible_labels=(0, 1),
    )
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
@pytest.mark.parametrize(
    "dataset",
    [
        "train",
        "validation",
        "test",
    ],
)
@pytest.mark.parametrize(
    "name",
    [
        "default",
    ],
)
def test_loading(database_checkers, name: str, dataset: str):
    """Load a few batches from a datamodule and check their contents.

    The ``datamodule`` object of ``ptbench.data.nih_cxr14.<name>`` is set up
    for prediction, and the first batches of the requested dataset are
    passed to ``database_checkers.check_loaded_batch``.

    :param database_checkers: Fixture providing ``check_loaded_batch``.
    :param name: Name of the submodule of ``ptbench.data.nih_cxr14`` that
        exposes the ``datamodule`` to test.
    :param dataset: Key of the dataloader to check within the dictionary
        returned by ``predict_dataloader()``.
    """
    from itertools import islice

    datamodule = importlib.import_module(
        f".{name}", "ptbench.data.nih_cxr14"
    ).datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    loader = datamodule.predict_dataloader()[dataset]

    limit = 3  # limit load checking
    # islice stops after ``limit`` batches without pulling an extra one from
    # the loader only to throw it away.
    for batch in islice(loader, limit):
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=1,
            prefixes=("images/000",),
            possible_labels=(0, 1),
        )

    # TODO: check size 1024x1024
    # TODO: check there are 14 binary labels (0, 1)
File moved
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment