# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for the cam_utils script."""

import numpy as np
import pytest

# from ptbench.utils.cam_utils import (
#     _calculate_stats_over_dataset,
#     calculate_metrics_avg_for_every_class,
#     draw_boxes_on_image,
#     draw_largest_component_bbox_on_image,
#     show_cam_on_image,
#     visualize_road_scores,
# )


def test_calculate_stats_over_dataset(datadir):
    """Compare the statistics computed from the sample CSV with known values.

    ``datadir`` is joined with ``"test_vis_metrics.csv"`` to locate the input
    file handed to ``_calculate_stats_over_dataset``.
    """
    computed = _calculate_stats_over_dataset(datadir / "test_vis_metrics.csv")

    # One row per metric column: (column, mean, median, Q1, Q3).
    reference_rows = (
        ("MoRF", 1.4, 1.0, 1.0, 2.0),
        ("LeRF", 2.4, 2.0, 2.0, 3.0),
        ("Combined Score ((LeRF-MoRF) / 2)", 3.4, 3.0, 3.0, 4.0),
        ("IoU", 4.4, 4.0, 4.0, 5.0),
        ("IoDA", 5.4, 5.0, 5.0, 6.0),
        ("propEnergy", 6.4, 6.0, 6.0, 7.0),
        ("ASF", 7.4, 7.0, 7.0, 8.0),
    )
    # Every column is expected to have the same standard deviation.
    shared_std_dev = 0.548

    for column_name, mean, median, first_quartile, third_quartile in reference_rows:
        reference = {
            "mean": mean,
            "median": median,
            "std_dev": shared_std_dev,
            "Q1": first_quartile,
            "Q3": third_quartile,
        }
        for stat_name, stat_value in reference.items():
            # NOTE(review): exact float equality -- presumably the helper
            # rounds its output to three decimals; confirm.
            assert computed[column_name][stat_name] == stat_value


def test_calculate_metrics_avg_for_every_class(temporary_basedir):
    """Check the summary CSV files written for a two-class input folder.

    The test creates ``camutils/gradcam/class1`` and ``.../class2`` under
    ``temporary_basedir``, stores the same sample table as ``file1.csv`` and
    ``file2.csv`` in each class directory, runs
    ``calculate_metrics_avg_for_every_class`` on the parent folder and then
    verifies that ``file1_summary.csv`` and ``file2_summary.csv`` exist there
    with the expected per-class statistics.
    """
    # pandas is not imported at module level, so bring it into scope here;
    # without this import every use of ``pd`` below raises NameError.
    import pandas as pd

    # Create a sample directory structure and CSV files
    input_folder = temporary_basedir / "camutils" / "gradcam"
    input_folder.mkdir(parents=True, exist_ok=True)
    class_names = ["class1", "class2"]
    class_dirs = [input_folder / class_name for class_name in class_names]
    for class_dir in class_dirs:
        class_dir.mkdir(parents=True, exist_ok=True)

    data = {
        "MoRF": [1, 2, 3],
        "LeRF": [2, 4, 6],
        "Combined Score ((LeRF-MoRF) / 2)": [1.5, 3, 4.5],
        "IoU": [1, 2, 3],
        "IoDA": [2, 4, 6],
        "propEnergy": [1.5, 3, 4.5],
        "ASF": [1, 2, 3],
    }
    score_types = list(data)
    sample = pd.DataFrame(data)

    file_names = ["file1", "file2"]
    for file_name in file_names:
        for class_dir in class_dirs:
            sample.to_csv(class_dir / f"{file_name}.csv", index=False)

    calculate_metrics_avg_for_every_class(input_folder)

    # Assert the existence of summary files
    for file_name in file_names:
        assert (input_folder / f"{file_name}_summary.csv").exists()

    # Expected statistics for one class, listed in ``score_types`` order.
    # Both classes hold identical data, so each list is repeated per class.
    mean_per_score = [2.0, 4.0, 3.0, 2.0, 4.0, 3.0, 2.0]
    std_per_score = [1.0, 2.0, 1.5, 1.0, 2.0, 1.5, 1.0]
    median_per_score = [2.0, 4.0, 3.0, 2.0, 4.0, 3.0, 2.0]
    q1_per_score = [1.5, 3.0, 2.25, 1.5, 3.0, 2.25, 1.5]
    q3_per_score = [2.5, 5.0, 3.75, 2.5, 5.0, 3.75, 2.5]

    n_classes = len(class_names)
    expected_data = {
        "Class Name": [
            class_name for class_name in class_names for _ in score_types
        ],
        "Score Type": score_types * n_classes,
        "Mean": mean_per_score * n_classes,
        "Standard Deviation": std_per_score * n_classes,
        "Median": median_per_score * n_classes,
        "Q1": q1_per_score * n_classes,
        "Q3": q3_per_score * n_classes,
    }

    # Assert the content of summary files
    for file_name in file_names:
        summary = pd.read_csv(input_folder / f"{file_name}_summary.csv")
        for column, expected_values in expected_data.items():
            assert list(summary[column]) == expected_values


def test_draw_boxes_on_image():
    """Draw two boxes on black images of three sizes and inspect the pixels.

    The same checks run for 200x200, 256x256 and 512x512 inputs: box pixels
    are green, regions away from the boxes stay black, the expected text
    regions are not entirely black, and the output keeps the input shape.
    """
    import torch

    black = np.array([0, 0, 0])
    green = np.array([0, 255, 0])

    # Sample images and bounding boxes (each box is a list of 0-d tensors)
    canvases = [
        np.zeros((side, side, 3), dtype=np.uint8) for side in (200, 256, 512)
    ]
    bboxes = [
        [torch.tensor(value) for value in (1, 10, 10, 20, 20)],
        [torch.tensor(value) for value in (1, 30, 30, 40, 40)],
    ]

    outputs = [draw_boxes_on_image(canvas, bboxes) for canvas in canvases]

    # The pixel at each box's first corner must be green
    for drawn in outputs:
        for corner in (10, 30):
            assert np.all(drawn[corner, corner] == green)

    # Nothing may be drawn outside the expected area
    for drawn in outputs:
        assert np.all(drawn[0:10, 0:10] == black)
        assert np.all(drawn[160:, 160:] == black)

    # Some pixels in the expected text regions aren't pure black
    # (text is greenish)
    for drawn in outputs:
        assert not np.all(drawn[30:40, 10:30] == black)
        assert not np.all(drawn[70:80, 30:50] == black)

    # An image of the same shape as the input is returned
    for canvas, drawn in zip(canvases, outputs):
        assert isinstance(drawn, np.ndarray)
        assert drawn.shape == canvas.shape


def test_draw_largest_component_bbox():
    """Draw one box at (10, 10) with size 20x20 on three black images.

    For 200x200, 256x256 and 512x512 inputs the test inspects the box corner
    pixel, the untouched top-left region, the expected text region and the
    type and shape of the returned image.
    """
    black = np.array([0, 0, 0])
    red_box_pixel = np.array([0, 0, 255])
    x, y, w, h = 10, 10, 20, 20

    canvases = [
        np.zeros((side, side, 3), dtype=np.uint8) for side in (200, 256, 512)
    ]
    outputs = [
        draw_largest_component_bbox_on_image(canvas, x, y, w, h)
        for canvas in canvases
    ]

    # The box corner pixel carries the box colour
    for drawn in outputs:
        assert np.all(drawn[10, 10] == red_box_pixel)

    # Nothing is drawn above/left of the box
    for drawn in outputs:
        assert np.all(drawn[0:10, 0:10] == black)

    # NOTE(review): these assertions require at least one non-black pixel in
    # rows/columns >= 40, although the original comment described this as a
    # "no bounding box outside expected area" check -- confirm the intent.
    for drawn in outputs:
        assert np.any(drawn[40:, 40:] != black)

    # Some pixels in the expected text region aren't pure black
    # (text is reddish)
    for drawn in outputs:
        assert not np.all(drawn[31:40, 10:30] == black)

    # An image of the same shape as the input is returned
    for canvas, drawn in zip(canvases, outputs):
        assert isinstance(drawn, np.ndarray)
        assert drawn.shape == canvas.shape


def test_show_cam_on_image():
    """Overlay an all-ones CAM on all-zero images of three sizes.

    Each result must be a ``uint8`` array of shape ``(side, side, 3)``; the
    value range is additionally checked on the first (200x200) result.
    """
    sides = (200, 256, 512)

    # Sample images and matching CAMs
    images = [np.zeros((side, side, 3), dtype=np.uint8) for side in sides]
    cams = [np.ones((side, side), dtype=np.float32) for side in sides]

    overlays = [show_cam_on_image(image, cam) for image, cam in zip(images, cams)]

    for side, overlay in zip(sides, overlays):
        assert isinstance(overlay, np.ndarray)
        assert overlay.dtype == np.uint8
        assert overlay.shape == (side, side, 3)

    first_overlay = overlays[0]
    assert np.all(first_overlay >= 0)
    assert np.all(first_overlay <= 255)


def test_show_cam_on_image_use_rgb():
    """With ``use_rgb=True`` the result must differ from the scaled input."""
    image = np.ones((2, 2, 3), dtype=np.float32)
    half_cam = np.full((2, 2), 0.5, dtype=np.float32)

    overlay = show_cam_on_image(image, half_cam, use_rgb=True)

    # The overlay differing from the plain 8-bit image shows that colour
    # mapping was applied.
    plain_uint8 = np.uint8(image * 255)
    assert not np.array_equal(overlay, plain_uint8)


def test_show_cam_on_image_colormap():
    """With an explicit colormap the result must differ from the scaled input.

    ``cv2.COLORMAP_JET`` is passed as the ``colormap`` keyword argument.
    """
    # OpenCV is not imported at module level, so bring it into scope here;
    # without this import ``cv2.COLORMAP_JET`` raises NameError.
    import cv2

    img = np.ones((2, 2, 3), dtype=np.float32)
    cam = np.ones((2, 2), dtype=np.float32) * 0.5

    result = show_cam_on_image(img, cam, colormap=cv2.COLORMAP_JET)

    # Assert that result and img are not the same (that color mapping was applied)
    assert not np.array_equal(result, np.uint8(img * 255))


def test_show_cam_on_image_raises_for_invalid_img():
    """An image whose values exceed 1 must be rejected with an error."""
    # Every pixel is 2.0, i.e. above 1
    out_of_range_image = np.full((200, 200, 3), 2, dtype=np.float32)
    cam = np.ones((200, 200), dtype=np.float32)

    expected_message = (
        r"The input image should np.float32 in the range \[0, 1\].*"
    )
    with pytest.raises(Exception, match=expected_message):
        show_cam_on_image(out_of_range_image, cam)


def test_show_cam_on_image_raises_for_invalid_image_weight():
    """Passing ``image_weight=1.5`` must be rejected with an error."""
    image = np.zeros((200, 200, 3), dtype=np.float32)
    cam = np.ones((200, 200), dtype=np.float32)

    expected_message = r"image_weight should be in the range \[0, 1\].*"
    with pytest.raises(Exception, match=expected_message):
        show_cam_on_image(image, cam, image_weight=1.5)


def test_show_cam_on_image_raises_for_mismatched_shapes():
    """A 100x100 CAM on a 200x200 image must raise ``ValueError``."""
    image = np.ones((200, 200, 3), dtype=np.float32)
    smaller_cam = np.ones((100, 100), dtype=np.float32)

    expected_message = (
        "The shape of the mask should be the same as the shape of the image."
    )
    with pytest.raises(ValueError, match=expected_message):
        show_cam_on_image(image, smaller_cam)


def test_visualize_road_scores():
    """Call ``visualize_road_scores`` on a black image and check the output.

    The result must be a NumPy array equal to the array that was passed in.
    """
    # Sample image and scores
    image = np.zeros((200, 200, 3), dtype=np.uint8)
    morf_score, lerf_score, combined_score = 0.12345, 0.23456, 0.34567
    label = "Road"
    percentile_pair = [25, 75]

    annotated = visualize_road_scores(
        image, morf_score, lerf_score, combined_score, label, percentile_pair
    )

    assert isinstance(annotated, np.ndarray)
    # NOTE(review): ``image`` is the very array handed to the function; if the
    # function draws in place and returns it, this comparison cannot fail --
    # confirm that is the intended check.
    assert np.array_equal(annotated, image)