Skip to content
Snippets Groups Projects
Commit 900f69d4 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Fix montgomery test; Make naming more explicit

parent 1dc638e6
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Montgomery dataset."""
import pytest
import torch
from ptbench.data.montgomery.datamodule import make_split
def _check_split(
    split_filename: str,
    lengths: dict[str, int],
    prefix: str = "CXR_png/MCUCXR_0",
    possible_labels: list[int] = [0, 1],
):
    """Runs a simple consistency check on the data split.

    Parameters
    ----------
    split_filename
        This is the split we will check.  It is passed as-is to
        ``make_split``.
    lengths
        A dictionary that contains keys matching those of the split (this will
        be checked).  The values of the dictionary should correspond to the
        sizes of each of the datasets in the split.
    prefix
        Each file named in a split should start with this prefix.
    possible_labels
        These are the list of possible labels contained in any split.  The
        default list is only read here, never mutated.

    Raises
    ------
    AssertionError
        If the split does not match any of the expectations above.
    """
    split = make_split(split_filename)

    # both sides must declare the same number of datasets
    assert len(split) == len(lengths)

    for name, expected_size in lengths.items():
        # dataset must have been declared
        assert name in split
        assert len(split[name]) == expected_size

        for sample in split[name]:
            # each sample is indexed as (filename, label)
            assert sample[0].startswith(prefix)
            assert sample[1] in possible_labels
def _check_loaded_batch(
batch,
size: int = 1,
prefix: str = "CXR_png/MCUCXR_0",
possible_labels: list[int] = [0, 1],
):
"""Checks the consistence of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
prefix
Each file named in a split should start with this prefix.
possible_labels
These are the list of possible labels contained in any split.
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == size # mini-batch size
assert batch[0].shape[1] == 1 # grayscale images
assert batch[0].shape[2] == batch[0].shape[3] # image is square
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name
assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]])
assert "name" in batch[1]
assert all([k.startswith(prefix) for k in batch[1]["name"]])
def test_protocol_consistency():
    """Checks dataset sizes of the default split and all 10 cross-validation
    folds."""
    _check_split(
        "default.json",
        lengths={"train": 88, "validation": 22, "test": 28},
    )

    # Cross-validation fold 0-7
    for fold in range(8):
        _check_split(
            f"fold_{fold}.json",
            lengths={"train": 99, "validation": 25, "test": 14},
        )

    # Cross-validation fold 8-9
    for fold in range(8, 10):
        _check_split(
            f"fold_{fold}.json",
            lengths={"train": 100, "validation": 25, "test": 13},
        )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loading():
    """Loads a few batches from every prediction data loader and checks each
    one with ``_check_loaded_batch``."""
    from itertools import islice

    from ptbench.data.montgomery.default import datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    limit = 5  # limit load checking (applies to each loader separately)
    for loader in datamodule.predict_dataloader().values():
        # islice stops after ``limit`` batches without fetching an extra one
        for batch in islice(loader, limit):
            _check_loaded_batch(batch)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment