From 0aadc7eab6e04eaa3329efc99a0157a669dafd17 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 22 May 2024 13:58:40 +0200
Subject: [PATCH] [scripts.train_analysis] Add indication of lowest validation
 loss epochs on loss plots (closes #70)

---
 src/mednet/engine/callbacks.py       |   6 +-
 src/mednet/scripts/experiment.py     |  15 ++-
 src/mednet/scripts/train_analysis.py | 177 ++++++++++++++++++++-------
 3 files changed, 142 insertions(+), 56 deletions(-)

diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py
index b1b16f65..c1f66d8c 100644
--- a/src/mednet/engine/callbacks.py
+++ b/src/mednet/engine/callbacks.py
@@ -40,9 +40,9 @@ class LoggingCallback(lightning.pytorch.Callback):
         super().__init__()
 
         # timers
-        self._start_training_time = 0.0
-        self._start_training_epoch_time = 0.0
-        self._start_validation_epoch_time = 0.0
+        self._start_training_time = time.time()
+        self._start_training_epoch_time = time.time()
+        self._start_validation_epoch_time = time.time()
 
         # log accumulators for a single flush at each training cycle
         self._to_log: dict[str, float] = {}
diff --git a/src/mednet/scripts/experiment.py b/src/mednet/scripts/experiment.py
index 538ddf63..67e16b64 100644
--- a/src/mednet/scripts/experiment.py
+++ b/src/mednet/scripts/experiment.py
@@ -92,7 +92,7 @@ def experiment(
     )
     train_stop_timestamp = datetime.now()
 
-    logger.info(f"Ended training in {train_stop_timestamp}")
+    logger.info(f"Ended training at {train_stop_timestamp}")
     logger.info(
         f"Training runtime: {train_stop_timestamp-train_start_timestamp}"
     )
@@ -100,11 +100,10 @@ def experiment(
     logger.info("Started train analysis")
     from .train_analysis import train_analysis
 
-    logdir = train_output_folder / "logs"
     ctx.invoke(
         train_analysis,
-        logdir=logdir,
-        output_folder=train_output_folder,
+        logdir=train_output_folder / "logs",
+        output=output_folder / "trainlog.pdf",
     )
 
     logger.info("Ended train analysis")
@@ -128,7 +127,7 @@ def experiment(
     )
 
     predict_stop_timestamp = datetime.now()
-    logger.info(f"Ended prediction in {predict_stop_timestamp}")
+    logger.info(f"Ended prediction at {predict_stop_timestamp}")
     logger.info(
         f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}"
     )
@@ -146,7 +145,7 @@ def experiment(
     )
 
     evaluation_stop_timestamp = datetime.now()
-    logger.info(f"Ended prediction in {evaluation_stop_timestamp}")
+    logger.info(f"Ended evaluation at {evaluation_stop_timestamp}")
     logger.info(
         f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}"
     )
@@ -170,7 +169,7 @@ def experiment(
 
     saliency_map_generation_stop_timestamp = datetime.now()
     logger.info(
-        f"Ended saliency map generation in {saliency_map_generation_stop_timestamp}"
+        f"Ended saliency map generation at {saliency_map_generation_stop_timestamp}"
     )
     logger.info(
         f"Saliency map generation runtime: {saliency_map_generation_stop_timestamp-saliency_map_generation_start_timestamp}"
@@ -195,7 +194,7 @@ def experiment(
 
     saliency_images_generation_stop_timestamp = datetime.now()
     logger.info(
-        f"Ended saliency images generation in {saliency_images_generation_stop_timestamp}"
+        f"Ended saliency images generation at {saliency_images_generation_stop_timestamp}"
     )
     logger.info(
         f"Saliency images generation runtime: {saliency_images_generation_stop_timestamp-saliency_images_generation_start_timestamp}"
diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py
index 902bc48e..80e74e44 100644
--- a/src/mednet/scripts/train_analysis.py
+++ b/src/mednet/scripts/train_analysis.py
@@ -3,18 +3,128 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import pathlib
+import typing
 
 import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 
+from .click import ConfigCommand
+
 # avoids X11/graphical desktop requirement when creating plots
 __import__("matplotlib").use("agg")
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-def create_figures(
+def _create_generic_figure(
+    curves: dict[str, tuple[list[int], list[float]]],
+    group: str,
+) -> tuple:
+    """Create a generic figure showing the evolution of a metric.
+
+    This function will create a generic figure (one-size-fits-all kind of
+    style) of a given metric across epochs.
+
+    Parameters
+    ----------
+    curves
+        A dictionary where keys represent all scalar names, and values
+        correspond to a tuple that contains an array with epoch numbers (when
+        values were taken), and the monitored values themselves.  These lists
+        are pre-sorted by epoch number.
+    group
+        A scalar glob pattern present in the existing tensorboard data that
+        we are interested in for plotting.
+
+    Returns
+    -------
+        A matplotlib figure and its axes.
+    """
+
+    import matplotlib.pyplot as plt
+    from matplotlib.axes import Axes
+    from matplotlib.figure import Figure
+    from matplotlib.ticker import MaxNLocator
+
+    fig, ax = plt.subplots(1, 1)
+    ax = typing.cast(Axes, ax)
+    fig = typing.cast(Figure, fig)
+
+    if len(curves) == 1:
+        # there is only one curve, just plot it
+        title, (epochs, values) = next(iter(curves.items()))
+        ax.plot(epochs, values)
+
+    else:
+        # this is an aggregate plot, name things consistently
+        labels = {k: k[len(group) - 1 :] for k in curves.keys()}
+        title = group.rstrip("*").rstrip("/")
+        for key, (epochs, values) in curves.items():
+            ax.plot(epochs, values, label=labels[key])
+        ax.legend(loc="best")
+
+    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
+    ax.set_title(title)
+    ax.set_xlabel("Epoch")
+    ax.set_ylabel(title)
+
+    ax.grid(alpha=0.3)
+    fig.tight_layout()
+
+    return fig, ax
+
+
+def _create_loss_figure(
+    curves: dict[str, tuple[list[int], list[float]]],
+    group: str,
+) -> tuple:
+    """Create a specific figure of the loss evolution.
+
+    This function will create a specific, more detailed figure of the loss
+    evolution plot where the curves will be enriched with vertical lines
+    indicating the points where the lowest validation losses are detected.  The
+    higher the line opacity, the lower the loss on the validation set.
+
+    Parameters
+    ----------
+    curves
+        A dictionary where keys represent all scalar names, and values
+        correspond to a tuple that contains an array with epoch numbers (when
+        values were taken), and the monitored values themselves.  These lists
+        are pre-sorted by epoch number.
+    group
+        A scalar glob pattern present in the existing tensorboard data that
+        we are interested in for plotting.
+
+    Returns
+    -------
+        A matplotlib figure and its axes.
+    """
+
+    fig, ax = _create_generic_figure(curves, group)
+
+    if "loss/validation" in curves:
+        points = sorted(zip(*curves["loss/validation"]), key=lambda x: x[1])[:4]
+
+        # create the highlights for each point, with fading colours
+        alpha_decay = 0.5
+        alpha = 1.0
+        for x, y in points:
+            ax.axvline(
+                x=x,
+                color="red",
+                alpha=alpha,
+                linestyle=":",
+                label=f"epoch {x} (val={y:.3g})",
+            )
+            alpha *= 1 - alpha_decay
+
+    ax.legend(loc="best")
+    return fig, ax
+
+
+def _create_figures(
     data: dict[str, tuple[list[int], list[float]]],
     groups: list[str] = [
         "loss/*",
@@ -56,12 +166,6 @@ def create_figures(
     """
 
     import fnmatch
-    import typing
-
-    import matplotlib.pyplot as plt
-    from matplotlib.axes import Axes
-    from matplotlib.figure import Figure
-    from matplotlib.ticker import MaxNLocator
 
     figures = []
 
@@ -71,36 +175,18 @@ def create_figures(
         if len(curves) == 0:
             continue
 
-        fig, ax = plt.subplots(1, 1)
-        ax = typing.cast(Axes, ax)
-        fig = typing.cast(Figure, fig)
-
-        if len(curves) == 1:
-            # there is only one curve, just plot it
-            title, (epochs, values) = next(iter(curves.items()))
-            ax.plot(epochs, values)
-
+        if group == "loss/*":
+            fig, _ = _create_loss_figure(curves, group)
+            figures.append(fig)
         else:
-            # this is an aggregate plot, name things consistently
-            labels = {k: k[len(group) - 1 :] for k in curves.keys()}
-            title = group.rstrip("*").rstrip("/")
-            for key, (epochs, values) in curves.items():
-                ax.plot(epochs, values, label=labels[key])
-            ax.legend(loc="best")
-
-        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
-        ax.set_title(title)
-        ax.set_xlabel("Epoch")
-        ax.set_ylabel(title)
-
-        ax.grid(alpha=0.3)
-        fig.tight_layout()
-        figures.append(fig)
+            fig, _ = _create_generic_figure(curves, group)
+            figures.append(fig)
 
     return figures
 
 
 @click.command(
+    cls=ConfigCommand,
     epilog="""Examples:
 
 \b
@@ -108,7 +194,7 @@ def create_figures(
 
        .. code:: sh
 
-          mednet train-analysis -vv results/logs
+          mednet train-analysis -vv --logdir=results/logs --output=trainlog.pdf
 """,
 )
 @click.option(
@@ -117,25 +203,29 @@ def create_figures(
     help="Path to the directory containing the Tensorboard training logs",
     required=True,
     type=click.Path(dir_okay=True, exists=True, path_type=pathlib.Path),
+    default="results/logs",
+    cls=ResourceOption,
 )
 @click.option(
-    "--output-folder",
+    "--output",
     "-o",
-    help="Directory in which to store results (created if does not exist)",
+    help="""Path to a PDF file in which to store results (leading
+    directories are created if they do not exist).""",
     required=True,
+    default="trainlog.pdf",
+    cls=ResourceOption,
     type=click.Path(
-        file_okay=False,
-        dir_okay=True,
+        file_okay=True,
+        dir_okay=False,
         writable=True,
         path_type=pathlib.Path,
     ),
-    default="results",
-    cls=ResourceOption,
 )
 @verbosity_option(logger=logger, expose_value=False)
 def train_analysis(
     logdir: pathlib.Path,
-    output_folder: pathlib.Path,
+    output: pathlib.Path,
+    **_,
 ) -> None:  # numpydoc ignore=PR01
     """Create a plot for each metric in the training logs and saves them in a .pdf file."""
     import matplotlib.pyplot as plt
@@ -143,14 +233,11 @@ def train_analysis(
 
     from ..utils.tensorboard import scalars_to_dict
 
-    train_log_filename = "trainlog.pdf"
-    train_log_file = pathlib.Path(output_folder) / train_log_filename
-
     data = scalars_to_dict(logdir)
 
-    train_log_file.parent.mkdir(parents=True, exist_ok=True)
+    output.parent.mkdir(parents=True, exist_ok=True)
 
-    with PdfPages(train_log_file) as pdf:
-        for figure in create_figures(data):
+    with PdfPages(output) as pdf:
+        for figure in _create_figures(data):
             pdf.savefig(figure)
             plt.close(figure)
-- 
GitLab