Skip to content
Snippets Groups Projects
plot.py 8.66 KiB
Newer Older
André Anjos's avatar
André Anjos committed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import contextlib

from itertools import cycle

import matplotlib
import numpy

from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve as pr_curve
from sklearn.metrics import roc_curve as r_curve

matplotlib.use("agg")
import logging

import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def _precision_recall_canvas(title=None):
    """Generates a canvas to draw precision-recall curves.

    Works like a context manager, yielding a figure and an axes set in which
    the precision-recall curves should be added to.  The figure already
    contains F1-ISO lines and is preset to a 0-1 square region.  Once the
    context is finished, ``fig.tight_layout()`` is called.


    Parameters
    ----------

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Yields
    ------

    figure : matplotlib.figure.Figure
        The figure that should be finally returned to the user

    axes : matplotlib.figure.Axes
        An axis set where to precision-recall plots should be added to
    """
    fig, axes1 = plt.subplots(1)

    # Names and bounds
    axes1.set_xlabel("Recall")
    axes1.set_ylabel("Precision")
    axes1.set_xlim([0.0, 1.0])
    axes1.set_ylim([0.0, 1.0])

    if title is not None:
        axes1.set_title(title)

    axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    axes2 = axes1.twinx()

    # Annotates plot with F1-score iso-lines
    f_scores = numpy.linspace(0.1, 0.9, num=9)
    tick_locs = []
    tick_labels = []
    for f_score in f_scores:
        x = numpy.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        (l,) = plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1)
        tick_locs.append(y[-1])
        tick_labels.append("%.1f" % f_score)
    axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    axes2.set_ylabel("iso-F", color="green", alpha=0.3)
    axes2.set_ylim([0.0, 1.0])
    axes2.yaxis.set_label_coords(1.015, 0.97)
    axes2.set_yticks(tick_locs)  # notice these are invisible
    for k in axes2.set_yticklabels(tick_labels):
        k.set_color("green")
        k.set_alpha(0.3)
        k.set_size(8)

    # we should see some of axes 1 axes
    axes1.spines["right"].set_visible(False)
    axes1.spines["top"].set_visible(False)
    axes1.spines["left"].set_position(("data", -0.015))
    axes1.spines["bottom"].set_position(("data", -0.015))

    # we shouldn't see any of axes 2 axes
    axes2.spines["right"].set_visible(False)
    axes2.spines["top"].set_visible(False)
    axes2.spines["left"].set_visible(False)
    axes2.spines["bottom"].set_visible(False)

    # yield execution, lets user draw precision-recall plots, and the legend
    # before tighteneing the layout
    yield fig, axes1

    plt.tight_layout()


def precision_recall_f1iso(data):
    """Creates a precision-recall plot.

    This function creates and returns a Matplotlib figure with a
    precision-recall plot.  The plot will be annotated with F1-score
    iso-lines (in which the F1-score maintains the same value).


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our predictor engine containing
          the following columns: ``filename``, ``likelihood``,
          ``ground_truth``.

        * ``threshold``: :py:class:`list`

          A threshold for each set. Not used here.


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)
    """
    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    with _precision_recall_canvas(title=None) as (fig, axes):
        legend = []

        for name, value in data.items():
            df = value["df"]

            # plots Recall/Precision curve
            prec, recall, _ = pr_curve(df["ground_truth"], df["likelihood"])
            _auc = auc(recall, prec)
            label = f"{name} (AUC={_auc:.2f})"
            color = next(colorcycler)
            style = next(linecycler)

            (line,) = axes.plot(recall, prec, color=color, linestyle=style)
            legend.append((line, label))

        if len(label) > 1:
            axes.legend(
                [k[0] for k in legend],
                [k[1] for k in legend],
                loc="lower left",
                fancybox=True,
                framealpha=0.7,
            )

    return fig


def roc_curve(data, title=None):
    """Creates a ROC plot.

    This function creates and returns a Matplotlib figure with a
    ROC plot.


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our predictor engine containing
          the following columns: ``filename``, ``likelihood``,
          ``ground_truth``.

        * ``threshold``: :py:class:`list`

          A threshold for each set. Not used here.


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)
    """
    fig, axes = plt.subplots(1)

    # Names and bounds
    axes.set_xlabel("1 - specificity")
    axes.set_ylabel("Sensitivity")
    axes.set_xlim([0.0, 1.0])
    axes.set_ylim([0.0, 1.0])

    # we should see some of axes 1 axes
    axes.spines["right"].set_visible(False)
    axes.spines["top"].set_visible(False)
    axes.spines["left"].set_position(("data", -0.015))
    axes.spines["bottom"].set_position(("data", -0.015))

    if title is not None:
        axes.set_title(title)

    axes.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    plt.tight_layout()

    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    legend = []

    for name, value in data.items():
        df = value["df"]

        # plots roc curve
        fpr, tpr, _ = r_curve(df["ground_truth"], df["likelihood"])
        _auc = auc(fpr, tpr)
        label = f"{name} (AUC={_auc:.2f})"
        color = next(colorcycler)
        style = next(linecycler)

        (line,) = axes.plot(fpr, tpr, color=color, linestyle=style)
        legend.append((line, label))

    if len(label) > 1:
        axes.legend(
            [k[0] for k in legend],
            [k[1] for k in legend],
            loc="lower right",
            fancybox=True,
            framealpha=0.7,
        )

    return fig


def relevance_analysis_plot(data, title=None):
    """Create an histogram plot to show the relative importance of features.

    Parameters
    ----------

    data : :py:class:`list`
        The list of values (one for each feature)


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)
    """
    fig, axes = plt.subplots(1, 1, figsize=(6, 6))

    # Names and bounds
    axes.set_xlabel("Features")
    axes.set_ylabel("Importance")

    # we should see some of axes 1 axes
    axes.spines["right"].set_visible(False)
    axes.spines["top"].set_visible(False)

    if title is not None:
        axes.set_title(title)

    # 818C2E = likely
    # F2921D = could be
    # 8C3503 = unlikely

    labels = [
        "Cardiomegaly",
        "Emphysema",
        "Pleural effusion",
        "Hernia",
        "Infiltration",
        "Mass",
        "Nodule",
        "Atelectasis",
        "Pneumothorax",
        "Pleural thickening",
        "Pneumonia",
        "Fibrosis",
        "Edema",
        "Consolidation",
    ]
    bars = axes.bar(labels, data, color="#8C3503")

    bars[2].set_color("#818C2E")
    bars[4].set_color("#818C2E")
    bars[10].set_color("#818C2E")
    bars[5].set_color("#F2921D")
    bars[6].set_color("#F2921D")
    bars[7].set_color("#F2921D")
    bars[11].set_color("#F2921D")
    bars[13].set_color("#F2921D")

    for tick in axes.get_xticklabels():
        tick.set_rotation(90)

    fig.tight_layout()

    return fig