# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Montgomery dataset."""

import importlib
import json

import pytest

from PIL import Image


def id_function(val):
    """Return the text pytest displays as the ID of a parametrized value.

    Dictionaries are rendered with ``str()``; every other value is rendered
    with ``repr()``.
    """
    return str(val) if isinstance(val, dict) else repr(val)


@pytest.mark.parametrize(
    "split,lenghts",
    [
        ("default", {"train": 88, "validation": 22, "test": 28}),
        ("fold-0", {"train": 99, "validation": 25, "test": 14}),
        ("fold-1", {"train": 99, "validation": 25, "test": 14}),
        ("fold-2", {"train": 99, "validation": 25, "test": 14}),
        ("fold-3", {"train": 99, "validation": 25, "test": 14}),
        ("fold-4", {"train": 99, "validation": 25, "test": 14}),
        ("fold-5", {"train": 99, "validation": 25, "test": 14}),
        ("fold-6", {"train": 99, "validation": 25, "test": 14}),
        ("fold-7", {"train": 99, "validation": 25, "test": 14}),
        ("fold-8", {"train": 100, "validation": 25, "test": 13}),
        ("fold-9", {"train": 100, "validation": 25, "test": 13}),
    ],
    ids=id_function,  # only affects how pytest labels each case
)
def test_protocol_consistency(
    database_checkers, split: str, lenghts: dict[str, int]
):
    """Check one Montgomery split definition against its expected contents.

    The split named ``<split>.json`` is built with ``make_split`` and handed
    to ``database_checkers.check_split`` together with the expected number of
    samples per subset, the accepted filename prefix and the allowed labels.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_split`` helper.
    split
        Name of the split to load, without the ``.json`` extension.
    lenghts
        Expected sample count for each subset of the split.
    """
    from mednet.config.data.montgomery.datamodule import make_split

    database_split = make_split(f"{split}.json")
    database_checkers.check_split(
        database_split,
        lengths=lenghts,
        prefixes=("CXR_png/MCUCXR_0",),
        possible_labels=(0, 1),
    )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
@pytest.mark.parametrize("dataset", ["train", "validation", "test"])
@pytest.mark.parametrize(
    "name",
    ["default", *[f"fold_{fold}" for fold in range(10)]],
)
def test_loading(database_checkers, name: str, dataset: str):
    """Load a few batches from a Montgomery datamodule and validate them.

    The datamodule is imported from the ``mednet.config.data.montgomery``
    configuration module called ``name``. At most three batches of the
    requested ``dataset`` are passed to
    ``database_checkers.check_loaded_batch``.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_loaded_batch`` helper.
    name
        Name of the configuration module holding the datamodule.
    dataset
        Key of the prediction dataloader to iterate over.
    """
    config_module = importlib.import_module(
        f".{name}", "mednet.config.data.montgomery"
    )
    datamodule = config_module.datamodule

    # Model transforms have to be cleared before setup() is called.
    datamodule.model_transforms = []
    # The "predict" stage sets up all datasets.
    datamodule.setup("predict")

    loader = datamodule.predict_dataloader()[dataset]

    max_batches = 3  # limit load checking
    for index, batch in enumerate(loader):
        if index >= max_batches:
            break
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=1,
            prefixes=("CXR_png/MCUCXR_0",),
            possible_labels=(0, 1),
            expected_num_labels=1,
        )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loaded_image_quality(datadir):
    """Compare the first training image's histogram with a stored reference.

    The first sample of the ``train`` split of the default Montgomery
    datamodule is loaded through the split's raw-data loader and converted to
    a PIL image. Its histogram must equal the ``"histogram"`` entry of
    ``histogram_montgomery.json`` found in ``datadir``.

    Parameters
    ----------
    datadir
        Fixture pointing at the directory holding the reference JSON file.
    """
    config_module = importlib.import_module(
        ".default", "mednet.config.data.montgomery"
    )
    datamodule = config_module.datamodule

    datamodule.model_transforms = []
    datamodule.setup("predict")

    # Each entry of a split pairs a list of samples with its raw-data loader.
    train_entry = datamodule.splits["train"][0]
    raw_loader = train_entry[1]
    first_sample = train_entry[0][0]

    tensor = raw_loader.sample(first_sample)[0]
    # PIL expects grayscale data without any leading dimension.
    pixels = tensor.numpy()[0, :, :]
    histogram = Image.fromarray(pixels, mode="L").histogram()

    reference_path = str(datadir / "histogram_montgomery.json")
    with open(reference_path) as reference_file:
        reference = json.load(reference_file)
    ref_histogram = reference["histogram"]

    assert histogram == ref_histogram