# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import os
import pathlib
import tempfile
import typing
import zipfile

import pytest
import tomli_w
import torch

from mednet.data.typing import DatabaseSplit


@pytest.fixture
def datadir(request) -> pathlib.Path:
    """Returns the directory in which the test is sitting."""
    return pathlib.Path(request.module.__file__).parents[0] / "data"


def pytest_configure(config):
    """This function is run once for pytest setup."""
    config.addinivalue_line(
        "markers",
        "skip_if_rc_var_not_set(name): this mark skips the test if a certain "
        "~/.config/mednet.toml variable is not set",
    )

    config.addinivalue_line("markers", "slow: this mark indicates slow tests")

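# Illustrative usage of the markers registered above (the test below is a
# hypothetical example, not part of this suite):
#
#   @pytest.mark.slow
#   @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
#   def test_something_that_needs_the_montgomery_dataset():
#       ...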

def pytest_runtest_setup(item):
    """This function is run for every test candidate in this directory.

    The test is run if this function returns ``None``.  To skip a test,
    call ``pytest.skip()``, specifying a reason.
    """
    from mednet.utils.rc import load_rc

    rc = load_rc()

    # iterate over all markers for the item being examined, get the first
    # argument of each, and accumulate the names
    rc_names = [
        mark.args[0]
        for mark in item.iter_markers(name="skip_if_rc_var_not_set")
    ]

    # check that all names mentioned are set in ~/.config/mednet.toml;
    # otherwise, skip the test
    if rc_names:
        missing = [k for k in rc_names if rc.get(k) is None]
        if missing:
            pytest.skip(
                "Test skipped because the following variable(s) are **not** "
                f"set in ~/.config/mednet.toml: {', '.join(missing)}"
            )


def rc_variable_set(name):
    """Returns a ``skipif`` mark that skips the test if the named RC variable
    is not set in ~/.config/mednet.toml."""
    from mednet.utils.rc import load_rc

    rc = load_rc()
    return pytest.mark.skipif(
        name not in rc,
        reason=f"RC variable '{name}' is not set",
    )

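# Illustrative usage of ``rc_variable_set`` (the test below is hypothetical);
# the returned mark can be applied directly as a decorator:
#
#   @rc_variable_set("datadir.montgomery")
#   def test_montgomery_loading():
#       ...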

@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
    return tmp_path_factory.mktemp("test-cli")

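# Illustrative usage of the session-scoped fixture above (hypothetical test):
#
#   def test_cli_writes_output(temporary_basedir):
#       output = temporary_basedir / "results"
#       ...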

def pytest_sessionstart(session: pytest.Session) -> None:
    """Presets the session start to ensure the Montgomery dataset is always
    available."""

    from mednet.utils.rc import load_rc

    rc = load_rc()

    database_dir = rc.get("datadir.montgomery")
    if database_dir is not None:
        # if the user downloaded it, use that copy
        return

    # else, we must extract the LFS component (we are likely on the CI)
    archive = (
        pathlib.Path(__file__).parents[0] / "data" / "lfs" / "test-database.zip"
    )
    assert archive.exists(), (
        "Neither datadir.montgomery is set in the global configuration "
        "(typically ~/.config/mednet.toml), nor is it possible to detect "
        f"the presence of {archive} (did you run `git submodule update "
        "--init` on this submodule?)"
    )

    montgomery_tempdir = tempfile.TemporaryDirectory()
    rc.setdefault("datadir.montgomery", montgomery_tempdir.name)

    with zipfile.ZipFile(archive) as zf:
        zf.extractall(montgomery_tempdir.name)

    config_filename = "mednet.toml"
    with open(
        os.path.join(montgomery_tempdir.name, config_filename), "wb"
    ) as f:
        tomli_w.dump(rc.data, f)
        f.flush()

    os.environ["XDG_CONFIG_HOME"] = montgomery_tempdir.name

    # stash the newly created temporary directory so it is kept alive (and can
    # be erased) until the test session ends
    key = pytest.StashKey[tempfile.TemporaryDirectory]()
    session.stash[key] = montgomery_tempdir

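# A matching teardown hook could release the temporary directory once the
# session ends.  The sketch below is illustrative only: it assumes the stash
# key is shared between the two hooks (e.g. declared as a module-level
# constant such as ``_MONTGOMERY_TEMPDIR_KEY``) instead of being created
# locally as above:
#
#   def pytest_sessionfinish(session: pytest.Session) -> None:
#       tempdir = session.stash.get(_MONTGOMERY_TEMPDIR_KEY, None)
#       if tempdir is not None:
#           tempdir.cleanup()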

class DatabaseCheckers:
    """Helpers for database tests."""

    @staticmethod
    def check_split(
        split: DatabaseSplit,
        lengths: dict[str, int],
        prefixes: typing.Sequence[str],
        possible_labels: typing.Sequence[int],
    ):
        """Runs a simple consistence check on the data split.

        Parameters
        ----------

        make_split
            A database specific function that takes a split name and returns
            the loaded database split.

        split_filename
            This is the split we will check

        lenghts
            A dictionary that contains keys matching those of the split (this will
            be checked).  The values of the dictionary should correspond to the
            sizes of each of the datasets in the split.

        prefixes
            Each file named in a split should start with at least one of these
            prefixes.

        possible_labels
            These are the list of possible labels contained in any split.
        """

        assert len(split) == len(lengths)

        for k in lengths.keys():
            # dataset must have been declared
            assert k in split

            assert len(split[k]) == lengths[k]
            for s in split[k]:
                assert any(s[0].startswith(p) for p in prefixes), (
                    f"Sample with name {s[0]} does not start with any of the "
                    f"prefixes in {prefixes}"
                )
                if isinstance(s[1], list):
                    assert all(lbl in possible_labels for lbl in s[1])
                else:
                    assert s[1] in possible_labels

    @staticmethod
    def check_loaded_batch(
        batch,
        batch_size: int,
        color_planes: int,
        prefixes: typing.Sequence[str],
        possible_labels: typing.Sequence[int],
    ):
        """Checks the consistence of an individual (loaded) batch.

        Parameters
        ----------

        batch
            The loaded batch to be checked.

        size
            The mini-batch size

        prefixes
            Each file named in a split should start with at least one of these
            prefixes.

        possible_labels
            These are the list of possible labels contained in any split.
        """

        assert len(batch) == 2  # data, metadata

        assert isinstance(batch[0], torch.Tensor)
        assert batch[0].shape[0] == batch_size  # mini-batch size
        assert batch[0].shape[1] == color_planes  # number of color planes
        assert batch[0].shape[2] == batch[0].shape[3]  # image is square

        assert isinstance(batch[1], dict)  # metadata
        assert len(batch[1]) == 2  # label and name

        assert "label" in batch[1]
        assert all(lbl in possible_labels for lbl in batch[1]["label"])

        assert "name" in batch[1]
        assert all(
            any(name.startswith(p) for p in prefixes)
            for name in batch[1]["name"]
        )

        # use the code below to view generated images
        # from torchvision.transforms.functional import to_pil_image
        # to_pil_image(batch[0][0]).show()
        # __import__("pdb").set_trace()


@pytest.fixture
def database_checkers():
    return DatabaseCheckers
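

# Illustrative usage of the ``database_checkers`` fixture (the split object,
# lengths, prefixes, and labels below are hypothetical placeholders):
#
#   def test_split_consistency(database_checkers, some_loaded_split):
#       database_checkers.check_split(
#           some_loaded_split,
#           lengths={"train": 88, "validation": 22, "test": 28},
#           prefixes=("CXR_png/MCUCXR_0",),
#           possible_labels=(0, 1),
#       )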