# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib
import typing

import pytest
import torch
from torchvision.transforms.functional import to_pil_image

from mednet.data.split import JSONDatabaseSplit
from mednet.data.typing import DatabaseSplit


@pytest.fixture
def datadir(request) -> pathlib.Path:
    """Return the ``data`` directory sitting next to the requesting test module.

    Parameters
    ----------
    request
        Information of the requesting test function. Check the pytest
        documentation for more information.

    Returns
    -------
    pathlib.Path
        The ``data`` subdirectory of the directory in which the requesting
        test module is sitting.
    """
    return pathlib.Path(request.module.__file__).parent / "data"


def pytest_configure(config):
    """Register the custom markers used by the tests in this directory.

    This function is run once for pytest setup.

    Parameters
    ----------
    config
        Configuration values. Check the pytest documentation for more
        information.
    """
    # Marker consumed by ``pytest_runtest_setup`` to skip tests that depend
    # on a user configuration entry.
    config.addinivalue_line(
        "markers",
        "skip_if_rc_var_not_set(name): this mark skips the test if a certain "
        "~/.config/mednet.toml variable is not set",
    )

    config.addinivalue_line("markers", "slow: this mark indicates slow tests")


def pytest_runtest_setup(item):
    """Skip tests whose required RC variables are not configured.

    This function is run for every test candidate in this directory.

    The test is run if this function returns ``None``.  To skip a test,
    call ``pytest.skip()``, specifying a reason.

    Parameters
    ----------
    item
        A test invocation item. Check the pytest documentation for more
        information.
    """
    # iterates over all markers for the item being examined, get the first
    # argument and accumulate these names
    rc_names = [
        mark.args[0]
        for mark in item.iter_markers(name="skip_if_rc_var_not_set")
    ]

    # tests without the marker do not depend on the user configuration, so
    # there is no need to load it for them
    if not rc_names:
        return

    from mednet.utils.rc import load_rc

    rc = load_rc()

    # checks all names mentioned are set in ~/.config/mednet.toml, otherwise,
    # skip the test
    missing = [k for k in rc_names if rc.get(k) is None]
    if missing:
        pytest.skip(
            f"Test skipped because {', '.join(missing)} is **not** "
            f"set in ~/.config/mednet.toml"
        )


def rc_variable_set(name):
    """Build a ``skipif`` mark that triggers when an RC variable is not set.

    Parameters
    ----------
    name
        Name of the variable to look for in the configuration object returned
        by ``mednet.utils.rc.load_rc()``.

    Returns
    -------
        The ``pytest.mark.skipif`` mark, whose condition is true when ``name``
        is not present in the loaded configuration.  It can be applied to a
        test as a decorator.
    """
    from mednet.utils.rc import load_rc

    rc = load_rc()
    # The mark must be returned: building it without returning it has no
    # effect on any test.
    return pytest.mark.skipif(
        name not in rc,
        reason=f"RC variable '{name}' is not set",
    )


@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
    """Provide a session-scoped temporary directory named after ``test-cli``.

    Parameters
    ----------
    tmp_path_factory
        Factory object whose ``mktemp()`` method creates the directory
        (presumably pytest's built-in ``tmp_path_factory`` fixture).

    Returns
    -------
        The value returned by ``tmp_path_factory.mktemp("test-cli")``.
    """
    basedir = tmp_path_factory.mktemp("test-cli")
    return basedir
class DatabaseCheckers:
    """Helpers for database tests."""

    @staticmethod
    def check_split(
        split: DatabaseSplit,
        lengths: dict[str, int],
        prefixes: typing.Sequence[str],
        possible_labels: typing.Sequence[int],
    ):
        """Run a simple consistency check on the data split.

        Parameters
        ----------
        split
            An instance of DatabaseSplit.
        lengths
            A dictionary that contains keys matching those of the split (this will
            be checked).  The values of the dictionary should correspond to the
            sizes of each of the datasets in the split.
        prefixes
            Each file named in a split should start with at least one of these
            prefixes.
        possible_labels
            These are the list of possible labels contained in any split.
        """

        assert len(split) == len(lengths)

        # ``str.startswith`` accepts a tuple of alternatives
        valid_prefixes = tuple(prefixes)

        for name, expected_length in lengths.items():
            # dataset must have been declared
            assert name in split

            assert len(split[name]) == expected_length
            for sample in split[name]:
                # sample[0] is the sample name, sample[1] its label(s)
                assert sample[0].startswith(valid_prefixes), (
                    f"Sample with name {sample[0]} does not start with any of the "
                    f"prefixes in {prefixes}"
                )
                if isinstance(sample[1], list):
                    assert all(label in possible_labels for label in sample[1])
                else:
                    assert sample[1] in possible_labels

    @staticmethod
    def check_loaded_batch(
        batch,
        batch_size: int,
        color_planes: int,
        prefixes: typing.Sequence[str],
        possible_labels: typing.Sequence[int],
        expected_num_labels: int,
        expected_image_shape: typing.Optional[tuple[int, ...]] = None,
    ):
        """Check the consistency of an individual (loaded) batch.

        Parameters
        ----------
        batch
            The loaded batch to be checked.
        batch_size
            The mini-batch size.
        color_planes
            The number of color planes in the images.
        prefixes
            Each file named in a split should start with at least one of these
            prefixes.
        possible_labels
            These are the list of possible labels contained in any split.
        expected_num_labels
            The expected number of labels each sample should have.
        expected_image_shape
            The expected shape of the image (num_channels, width, height).
        """

        assert len(batch) == 2  # data, metadata

        assert isinstance(batch[0], torch.Tensor)
        assert batch[0].shape[0] == batch_size  # mini-batch size
        assert batch[0].shape[1] == color_planes
        if expected_image_shape:
            assert all(data.shape == expected_image_shape for data in batch[0])

        assert isinstance(batch[1], dict)  # metadata
        assert len(batch[1]) == 2  # label and name

        assert "label" in batch[1]
        assert all(k in possible_labels for k in batch[1]["label"])

        if expected_num_labels:
            assert len(batch[1]["label"]) == expected_num_labels

        assert "name" in batch[1]
        valid_prefixes = tuple(prefixes)
        assert all(k.startswith(valid_prefixes) for k in batch[1]["name"])

        # use the code below to view generated images
        # from torchvision.transforms.functional import to_pil_image
        # to_pil_image(batch[0][0]).show()
        # __import__("pdb").set_trace()

    @staticmethod
    def check_image_quality(datamodule, reference_histogram_file):
        """Compare image histograms of a datamodule against stored references.

        Parameters
        ----------
        datamodule
            Object exposing ``splits`` and ``_datasets``, both indexed by split
            name.  ``splits[name][0][0]`` is iterated as the raw samples of the
            split, the first entry of each raw sample being its name.
        reference_histogram_file
            File loaded through ``JSONDatabaseSplit``.  For each split it is
            iterated as pairs of (sample name, reference histogram).
        """
        ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)

        for split_name in ref_histogram_splits:
            raw_samples = datamodule.splits[split_name][0][0]

            # It is not possible to get a sample from a Dataset by name/path, only by index.
            # This creates a dict of sample name to dataset index.
            raw_samples_indices = {
                raw_sample[0]: idx for idx, raw_sample in enumerate(raw_samples)
            }

            for ref_hist_path, ref_hist_data in ref_histogram_splits[split_name]:
                # Get index in the dataset that will return the data corresponding to the specified sample name
                dataset_sample_index = raw_samples_indices[ref_hist_path]

                image_tensor = datamodule._datasets[split_name][
                    dataset_sample_index
                ][0]
                img = to_pil_image(image_tensor)
                histogram = img.histogram()

                assert histogram == ref_hist_data


@pytest.fixture
def database_checkers():
    """Give tests access to the ``DatabaseCheckers`` helper collection.

    Returns
    -------
        The ``DatabaseCheckers`` class itself (not an instance).
    """
    checkers = DatabaseCheckers
    return checkers