Skip to content
Snippets Groups Projects
Commit d354388c authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Merge branch 'issue-70' into 'main'

Add indication of lowest validation loss epochs on loss plots

Closes #70

See merge request biosignal/software/mednet!39
parents dad878f3 41f1f593
No related branches found
No related tags found
1 merge request!39Add indication of lowest validation loss epochs on loss plots
Pipeline #87559 canceled
......@@ -40,9 +40,9 @@ class LoggingCallback(lightning.pytorch.Callback):
super().__init__()
# timers
self._start_training_time = 0.0
self._start_training_epoch_time = 0.0
self._start_validation_epoch_time = 0.0
self._start_training_time = time.time()
self._start_training_epoch_time = time.time()
self._start_validation_epoch_time = time.time()
# log accumulators for a single flush at each training cycle
self._to_log: dict[str, float] = {}
......@@ -255,7 +255,7 @@ class LoggingCallback(lightning.pytorch.Callback):
f"interval to a suitable value, so it allows some measures "
f"to be performed. Note this is only possible if the time "
f"to log a single measurement point is smaller than the "
f"time it takes to train a single epoch."
f"time it takes to **train** a single epoch."
)
else:
for metric_name, metric_value in aggregate(metrics).items():
......@@ -308,7 +308,7 @@ class LoggingCallback(lightning.pytorch.Callback):
f"interval to a suitable value, so it allows some measures "
f"to be performed. Note this is only possible if the time "
f"to log a single measurement point is smaller than the "
f"time it takes to train a single epoch."
f"time it takes to **validate** a single epoch."
)
else:
for metric_name, metric_value in aggregate(metrics).items():
......
......@@ -92,7 +92,7 @@ def experiment(
)
train_stop_timestamp = datetime.now()
logger.info(f"Ended training in {train_stop_timestamp}")
logger.info(f"Ended training at {train_stop_timestamp}")
logger.info(
f"Training runtime: {train_stop_timestamp-train_start_timestamp}"
)
......@@ -100,11 +100,10 @@ def experiment(
logger.info("Started train analysis")
from .train_analysis import train_analysis
logdir = train_output_folder / "logs"
ctx.invoke(
train_analysis,
logdir=logdir,
output_folder=train_output_folder,
logdir=train_output_folder / "logs",
output=output_folder / "trainlog.pdf",
)
logger.info("Ended train analysis")
......@@ -128,7 +127,7 @@ def experiment(
)
predict_stop_timestamp = datetime.now()
logger.info(f"Ended prediction in {predict_stop_timestamp}")
logger.info(f"Ended prediction at {predict_stop_timestamp}")
logger.info(
f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}"
)
......@@ -146,7 +145,7 @@ def experiment(
)
evaluation_stop_timestamp = datetime.now()
logger.info(f"Ended prediction in {evaluation_stop_timestamp}")
logger.info(f"Ended prediction at {evaluation_stop_timestamp}")
logger.info(
f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}"
)
......@@ -170,7 +169,7 @@ def experiment(
saliency_map_generation_stop_timestamp = datetime.now()
logger.info(
f"Ended saliency map generation in {saliency_map_generation_stop_timestamp}"
f"Ended saliency map generation at {saliency_map_generation_stop_timestamp}"
)
logger.info(
f"Saliency map generation runtime: {saliency_map_generation_stop_timestamp-saliency_map_generation_start_timestamp}"
......@@ -195,7 +194,7 @@ def experiment(
saliency_images_generation_stop_timestamp = datetime.now()
logger.info(
f"Ended saliency images generation in {saliency_images_generation_stop_timestamp}"
f"Ended saliency images generation at {saliency_images_generation_stop_timestamp}"
)
logger.info(
f"Saliency images generation runtime: {saliency_images_generation_stop_timestamp-saliency_images_generation_start_timestamp}"
......
......@@ -3,18 +3,128 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import typing
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from .click import ConfigCommand
# avoids X11/graphical desktop requirement when creating plots
__import__("matplotlib").use("agg")
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def create_figures(
def _create_generic_figure(
    curves: dict[str, tuple[list[int], list[float]]],
    group: str,
) -> tuple:
    """Plot the evolution of one metric group across training epochs.

    Produces a one-size-fits-all figure for the scalar curves selected by a
    single glob.  With a single curve, the plot is unlabelled and titled
    after the scalar name; with several curves, each line is labelled with
    its suffix (the part of the scalar name past the glob prefix) and a
    legend is added.

    Parameters
    ----------
    curves
        Maps every scalar name to a tuple ``(epochs, values)``, where
        ``epochs`` lists the epoch numbers at which measurements were taken
        and ``values`` holds the monitored values themselves.  Entries are
        assumed pre-sorted by epoch number.
    group
        The scalar glob (e.g. ``"loss/*"``) present in the existing
        tensorboard data that selected these curves; used to derive the
        plot title and per-curve labels.

    Returns
    -------
    A matplotlib figure and its axes.
    """
    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    from matplotlib.ticker import MaxNLocator

    fig, ax = plt.subplots(1, 1)
    ax = typing.cast(Axes, ax)
    fig = typing.cast(Figure, fig)

    if len(curves) == 1:
        # single curve: no legend required, title after the scalar name
        title, (epochs, values) = next(iter(curves.items()))
        ax.plot(epochs, values)
    else:
        # aggregate plot: label each curve with its name suffix past the
        # glob prefix, and title the plot after the cleaned-up glob
        suffix_start = len(group) - 1
        title = group.rstrip("*").rstrip("/")
        for name, (epochs, values) in curves.items():
            ax.plot(epochs, values, label=name[suffix_start:])
        ax.legend(loc="best")

    # epoch numbers are integers; avoid fractional x-axis ticks
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(title)
    ax.grid(alpha=0.3)
    fig.tight_layout()

    return fig, ax
def _create_loss_figure(
    curves: dict[str, tuple[list[int], list[float]]],
    group: str,
) -> tuple:
    """Plot the loss evolution, highlighting the best validation epochs.

    Builds on the generic metric figure and, when a validation loss curve
    is present, enriches it with vertical dotted lines marking the epochs
    with the four lowest validation losses.  The higher the line opacity,
    the lower the loss on the validation set.

    Parameters
    ----------
    curves
        Maps every scalar name to a tuple ``(epochs, values)``, where
        ``epochs`` lists the epoch numbers at which measurements were taken
        and ``values`` holds the monitored values themselves.  Entries are
        assumed pre-sorted by epoch number.
    group
        The scalar glob present in the existing tensorboard data that
        selected these curves (forwarded to the generic plotter for
        titling).

    Returns
    -------
    A matplotlib figure and its axes.
    """
    fig, ax = _create_generic_figure(curves, group)

    if "loss/validation" in curves:
        # the four (epoch, loss) pairs with the smallest validation losses
        best = sorted(zip(*curves["loss/validation"]), key=lambda p: p[1])[:4]

        # highlight each such epoch with a vertical line whose opacity
        # fades as the validation loss grows
        alpha_decay = 0.5
        alpha = 1.0
        for epoch, loss in best:
            ax.axvline(
                x=epoch,
                color="red",
                alpha=alpha,
                linestyle=":",
                label=f"epoch {epoch} (val={loss:.3g})",
            )
            alpha *= 1 - alpha_decay

        ax.legend(loc="best")

    return fig, ax
def _create_figures(
data: dict[str, tuple[list[int], list[float]]],
groups: list[str] = [
"loss/*",
......@@ -56,12 +166,6 @@ def create_figures(
"""
import fnmatch
import typing
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
figures = []
......@@ -71,36 +175,18 @@ def create_figures(
if len(curves) == 0:
continue
fig, ax = plt.subplots(1, 1)
ax = typing.cast(Axes, ax)
fig = typing.cast(Figure, fig)
if len(curves) == 1:
# there is only one curve, just plot it
title, (epochs, values) = next(iter(curves.items()))
ax.plot(epochs, values)
if group == "loss/*":
fig, _ = _create_loss_figure(curves, group)
figures.append(fig)
else:
# this is an aggregate plot, name things consistently
labels = {k: k[len(group) - 1 :] for k in curves.keys()}
title = group.rstrip("*").rstrip("/")
for key, (epochs, values) in curves.items():
ax.plot(epochs, values, label=labels[key])
ax.legend(loc="best")
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_title(title)
ax.set_xlabel("Epoch")
ax.set_ylabel(title)
ax.grid(alpha=0.3)
fig.tight_layout()
figures.append(fig)
fig, _ = _create_generic_figure(curves, group)
figures.append(fig)
return figures
@click.command(
cls=ConfigCommand,
epilog="""Examples:
\b
......@@ -108,7 +194,7 @@ def create_figures(
.. code:: sh
mednet train-analysis -vv results/logs
mednet train-analysis -vv --log-dir=results/logs --output=trainlog.pdf
""",
)
@click.option(
......@@ -117,25 +203,29 @@ def create_figures(
help="Path to the directory containing the Tensorboard training logs",
required=True,
type=click.Path(dir_okay=True, exists=True, path_type=pathlib.Path),
default="results/logs",
cls=ResourceOption,
)
@click.option(
"--output-folder",
"--output",
"-o",
help="Directory in which to store results (created if does not exist)",
help="""Path to a PDF file in which to store results. (leading
directories are created if they do not exist).""",
required=True,
default="trainlog.pdf",
cls=ResourceOption,
type=click.Path(
file_okay=False,
dir_okay=True,
file_okay=True,
dir_okay=False,
writable=True,
path_type=pathlib.Path,
),
default="results",
cls=ResourceOption,
)
@verbosity_option(logger=logger, expose_value=False)
def train_analysis(
logdir: pathlib.Path,
output_folder: pathlib.Path,
output: pathlib.Path,
**_,
) -> None: # numpydoc ignore=PR01
"""Create a plot for each metric in the training logs and saves them in a .pdf file."""
import matplotlib.pyplot as plt
......@@ -143,14 +233,11 @@ def train_analysis(
from ..utils.tensorboard import scalars_to_dict
train_log_filename = "trainlog.pdf"
train_log_file = pathlib.Path(output_folder) / train_log_filename
data = scalars_to_dict(logdir)
train_log_file.parent.mkdir(parents=True, exist_ok=True)
output.parent.mkdir(parents=True, exist_ok=True)
with PdfPages(train_log_file) as pdf:
for figure in create_figures(data):
with PdfPages(output) as pdf:
for figure in _create_figures(data):
pdf.savefig(figure)
plt.close(figure)
......@@ -478,7 +478,7 @@ def test_experiment(temporary_basedir):
)
== 1
)
assert (output_folder / "model" / "trainlog.pdf").exists()
assert (output_folder / "trainlog.pdf").exists()
assert (
len(
list(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment