From 0aadc7eab6e04eaa3329efc99a0157a669dafd17 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 22 May 2024 13:58:40 +0200 Subject: [PATCH] [scripts.train_analysis] Add indication of lowest validation loss epochs on loss plots (closes #70) --- src/mednet/engine/callbacks.py | 6 +- src/mednet/scripts/experiment.py | 15 ++- src/mednet/scripts/train_analysis.py | 177 ++++++++++++++++++++------- 3 files changed, 142 insertions(+), 56 deletions(-) diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py index b1b16f65..c1f66d8c 100644 --- a/src/mednet/engine/callbacks.py +++ b/src/mednet/engine/callbacks.py @@ -40,9 +40,9 @@ class LoggingCallback(lightning.pytorch.Callback): super().__init__() # timers - self._start_training_time = 0.0 - self._start_training_epoch_time = 0.0 - self._start_validation_epoch_time = 0.0 + self._start_training_time = time.time() + self._start_training_epoch_time = time.time() + self._start_validation_epoch_time = time.time() # log accumulators for a single flush at each training cycle self._to_log: dict[str, float] = {} diff --git a/src/mednet/scripts/experiment.py b/src/mednet/scripts/experiment.py index 538ddf63..67e16b64 100644 --- a/src/mednet/scripts/experiment.py +++ b/src/mednet/scripts/experiment.py @@ -92,7 +92,7 @@ def experiment( ) train_stop_timestamp = datetime.now() - logger.info(f"Ended training in {train_stop_timestamp}") + logger.info(f"Ended training at {train_stop_timestamp}") logger.info( f"Training runtime: {train_stop_timestamp-train_start_timestamp}" ) @@ -100,11 +100,10 @@ def experiment( logger.info("Started train analysis") from .train_analysis import train_analysis - logdir = train_output_folder / "logs" ctx.invoke( train_analysis, - logdir=logdir, - output_folder=train_output_folder, + logdir=train_output_folder / "logs", + output=output_folder / "trainlog.pdf", ) logger.info("Ended train analysis") @@ -128,7 +127,7 @@ def experiment( ) predict_stop_timestamp = 
datetime.now() - logger.info(f"Ended prediction in {predict_stop_timestamp}") + logger.info(f"Ended prediction at {predict_stop_timestamp}") logger.info( f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}" ) @@ -146,7 +145,7 @@ def experiment( ) evaluation_stop_timestamp = datetime.now() - logger.info(f"Ended prediction in {evaluation_stop_timestamp}") + logger.info(f"Ended evaluation at {evaluation_stop_timestamp}") logger.info( f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}" ) @@ -170,7 +169,7 @@ def experiment( ) saliency_map_generation_stop_timestamp = datetime.now() logger.info( - f"Ended saliency map generation in {saliency_map_generation_stop_timestamp}" + f"Ended saliency map generation at {saliency_map_generation_stop_timestamp}" ) logger.info( f"Saliency map generation runtime: {saliency_map_generation_stop_timestamp-saliency_map_generation_start_timestamp}" ) @@ -195,7 +194,7 @@ def experiment( ) saliency_images_generation_stop_timestamp = datetime.now() logger.info( - f"Ended saliency images generation in {saliency_images_generation_stop_timestamp}" + f"Ended saliency images generation at {saliency_images_generation_stop_timestamp}" ) logger.info( f"Saliency images generation runtime: {saliency_images_generation_stop_timestamp-saliency_images_generation_start_timestamp}" ) diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py index 902bc48e..80e74e44 100644 --- a/src/mednet/scripts/train_analysis.py +++ b/src/mednet/scripts/train_analysis.py @@ -3,18 +3,128 @@ # SPDX-License-Identifier: GPL-3.0-or-later import pathlib +import typing import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup +from .click import ConfigCommand + # avoids X11/graphical desktop requirement when creating plots __import__("matplotlib").use("agg") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") -def create_figures( +def 
_create_generic_figure( + curves: dict[str, tuple[list[int], list[float]]], + group: str, +) -> tuple: + """Create a generic figure showing the evolution of a metric. + + This function will create a generic figure (one-size-fits-all kind of + style) of a given metric across epochs. + + Parameters + ---------- + curves + A dictionary where keys represent all scalar names, and values + correspond to a tuple that contains an array with epoch numbers (when + values were taken), and the monitored values themselves. These lists + are pre-sorted by epoch number. + group + A scalar glob present in the existing tensorboard data that + we are interested in for plotting. + + Returns + ------- + A matplotlib figure and its axes. + """ + + import matplotlib.pyplot as plt + from matplotlib.axes import Axes + from matplotlib.figure import Figure + from matplotlib.ticker import MaxNLocator + + fig, ax = plt.subplots(1, 1) + ax = typing.cast(Axes, ax) + fig = typing.cast(Figure, fig) + + if len(curves) == 1: + # there is only one curve, just plot it + title, (epochs, values) = next(iter(curves.items())) + ax.plot(epochs, values) + + else: + # this is an aggregate plot, name things consistently + labels = {k: k[len(group) - 1 :] for k in curves.keys()} + title = group.rstrip("*").rstrip("/") + for key, (epochs, values) in curves.items(): + ax.plot(epochs, values, label=labels[key]) + ax.legend(loc="best") + + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.set_title(title) + ax.set_xlabel("Epoch") + ax.set_ylabel(title) + + ax.grid(alpha=0.3) + fig.tight_layout() + + return fig, ax + + +def _create_loss_figure( + curves: dict[str, tuple[list[int], list[float]]], + group: str, +) -> tuple: + """Create a specific figure of the loss evolution. + + This function will create a specific, more detailed figure of the loss + evolution plot where the curves will be enriched with vertical lines + indicating the points where the lowest validation losses are detected. 
The + higher the point opacity, the lower the loss on the validation set. + + Parameters + ---------- + curves + A dictionary where keys represent all scalar names, and values + correspond to a tuple that contains an array with epoch numbers (when + values were taken), and the monitored values themselves. These lists + are pre-sorted by epoch number. + group + A scalar glob present in the existing tensorboard data that + we are interested in for plotting. + + Returns + ------- + A matplotlib figure and its axes. + """ + + fig, ax = _create_generic_figure(curves, group) + + if "loss/validation" in curves: + points = sorted(zip(*curves["loss/validation"]), key=lambda x: x[1])[:4] + + # create the highlights for each point, with fading colours + alpha_decay = 0.5 + alpha = 1.0 + for x, y in points: + ax.axvline( + x=x, + color="red", + alpha=alpha, + linestyle=":", + label=f"epoch {x} (val={y:.3g})", + ) + alpha *= 1 - alpha_decay + + ax.legend(loc="best") + return fig, ax + + +def _create_figures( data: dict[str, tuple[list[int], list[float]]], groups: list[str] = [ "loss/*", @@ -56,12 +166,6 @@ def create_figures( """ import fnmatch - import typing - - import matplotlib.pyplot as plt - from matplotlib.axes import Axes - from matplotlib.figure import Figure - from matplotlib.ticker import MaxNLocator figures = [] @@ -71,36 +175,18 @@ def create_figures( if len(curves) == 0: continue - fig, ax = plt.subplots(1, 1) - ax = typing.cast(Axes, ax) - fig = typing.cast(Figure, fig) - - if len(curves) == 1: - # there is only one curve, just plot it - title, (epochs, values) = next(iter(curves.items())) - ax.plot(epochs, values) - + if group == "loss/*": + fig, _ = _create_loss_figure(curves, group) + figures.append(fig) else: - # this is an aggregate plot, name things consistently - labels = {k: k[len(group) - 1 :] for k in curves.keys()} - title = group.rstrip("*").rstrip("/") - for key, (epochs, values) in curves.items(): - ax.plot(epochs, values, label=labels[key]) - 
ax.legend(loc="best") - - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) - ax.set_title(title) - ax.set_xlabel("Epoch") - ax.set_ylabel(title) - - ax.grid(alpha=0.3) - fig.tight_layout() - figures.append(fig) + fig, _ = _create_generic_figure(curves, group) + figures.append(fig) return figures @click.command( + cls=ConfigCommand, epilog="""Examples: \b @@ -108,7 +194,7 @@ def create_figures( .. code:: sh - mednet train-analysis -vv results/logs + mednet train-analysis -vv --log-dir=results/logs --output=trainlog.pdf """, ) @click.option( @@ -117,25 +203,29 @@ def create_figures( help="Path to the directory containing the Tensorboard training logs", required=True, type=click.Path(dir_okay=True, exists=True, path_type=pathlib.Path), + default="results/logs", + cls=ResourceOption, ) @click.option( - "--output-folder", + "--output", "-o", - help="Directory in which to store results (created if does not exist)", + help="""Path to a PDF file in which to store results. (leading + directories are created if they do not exist).""", required=True, + default="trainlog.pdf", + cls=ResourceOption, type=click.Path( - file_okay=False, - dir_okay=True, + file_okay=True, + dir_okay=False, writable=True, path_type=pathlib.Path, ), - default="results", - cls=ResourceOption, ) @verbosity_option(logger=logger, expose_value=False) def train_analysis( logdir: pathlib.Path, - output_folder: pathlib.Path, + output: pathlib.Path, + **_, ) -> None: # numpydoc ignore=PR01 """Create a plot for each metric in the training logs and saves them in a .pdf file.""" import matplotlib.pyplot as plt @@ -143,14 +233,11 @@ def train_analysis( from ..utils.tensorboard import scalars_to_dict - train_log_filename = "trainlog.pdf" - train_log_file = pathlib.Path(output_folder) / train_log_filename - data = scalars_to_dict(logdir) - train_log_file.parent.mkdir(parents=True, exist_ok=True) + output.parent.mkdir(parents=True, exist_ok=True) - with PdfPages(train_log_file) as pdf: - for figure in 
create_figures(data): + with PdfPages(output) as pdf: + for figure in _create_figures(data): pdf.savefig(figure) plt.close(figure) -- GitLab