diff --git a/src/ptbench/scripts/experiment.py b/src/ptbench/scripts/experiment.py
index 3c43db6a286e7f5d46dc1197867b248aed87dbcd..bbfe5b86ed0a9aa1726cdedea904854915da2be5 100644
--- a/src/ptbench/scripts/experiment.py
+++ b/src/ptbench/scripts/experiment.py
@@ -297,8 +297,22 @@ def experiment(
         resume_from=resume_from,
         balance_classes=balance_classes,
     )
+
     logger.info("Ended training")
 
+    logger.info("Started train analysis")
+    from .train_analysis import train_analysis
+
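+    # plot the scalars collected during training (stored under the "logs"
+    # subdirectory of the training output) into a single PDF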
+    logdir = os.path.join(train_output_folder, "logs")
+    output_pdf = os.path.join(train_output_folder, "train_analysis.pdf")
+    ctx.invoke(
+        train_analysis,
+        logdir=logdir,
+        output_pdf=output_pdf,
+    )
+
+    logger.info("Ended train analysis")
+
     logger.info("Started predicting")
 
     from .predict import predict
diff --git a/src/ptbench/scripts/train_analysis.py b/src/ptbench/scripts/train_analysis.py
index f33cd9abfda8a3952281dd83b6eeeb0e66276ebc..4061164f7112f363730e6d0af5318b3141bd74f5 100644
--- a/src/ptbench/scripts/train_analysis.py
+++ b/src/ptbench/scripts/train_analysis.py
@@ -8,126 +8,79 @@ import os
 
 import click
 import matplotlib.pyplot as plt
+import pandas
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
+from matplotlib.ticker import MaxNLocator
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-def _loss_evolution(df):
-    """Plots the loss evolution over time (epochs)
+def create_figures(df: pandas.DataFrame) -> list[plt.Figure]:
+    """Generates one figure per metric in the dataframe.
+
+    Each row of the dataframe corresponds to an epoch and each column to a
+    metric.  Metric names may be of the form ``<metric>/<subset>``; all
+    subsets of a given metric are displayed on the same figure.
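+    For example, the columns ``loss/train`` and ``loss/validation`` are
+    drawn as two curves on a single figure titled ``loss``, while a column
+    whose name contains no ``/`` separator gets a figure of its own.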
 
     Parameters
     ----------
 
-        df : pandas.DataFrame
-            dataframe containing the training logs
-
+    df:
+        Pandas dataframe containing the data to plot.
 
     Returns
     -------
 
-        matplotlib.figure.Figure: Figure to be displayed or saved to file
+    figures:
+        List of matplotlib figures, one per metric.
     """
-    import numpy
-
-    figure = plt.figure()
-    axes = figure.gca()
-
-    axes.plot(df.epoch.values, df.loss.values, label="Training")
-    if "validation_loss" in df.columns:
-        axes.plot(
-            df.epoch.values, df.validation_loss.values, label="Validation"
-        )
-        # shows a red dot on the location with the minima on the validation set
-        lowest_index = numpy.argmin(df["validation_loss"])
-
-        axes.plot(
-            df.epoch.values[lowest_index],
-            df.validation_loss[lowest_index],
-            "mo",
-            label=f"Lowest validation ({df.validation_loss[lowest_index]:.3f}@{df.epoch[lowest_index]})",
-        )
-
-    if "extra_validation_losses" in df.columns:
-        # These losses are in array format. So, we read all rows, then create a
-        # 2d array.  We transpose the array to iterate over each column and
-        # plot the losses individually.  They are numbered from 1.
-        df["extra_validation_losses"] = df["extra_validation_losses"].apply(
-            lambda x: numpy.fromstring(x.strip("[]"), sep=" ")
-        )
-        losses = numpy.vstack(df.extra_validation_losses.values).T
-        for n, k in enumerate(losses):
-            axes.plot(df.epoch.values, k, label=f"Extra validation {n+1}")
-
-    axes.set_title("Loss over time")
-    axes.set_xlabel("Epoch")
-    axes.set_ylabel("Loss")
-
-    axes.legend(loc="best")
-    axes.grid(alpha=0.3)
-    figure.set_layout_engine("tight")
-
-    return figure
-
-
-def _hardware_utilisation(df, const):
-    """Plot the CPU utilisation over time (epochs).
 
-    Parameters
-    ----------
+    figures = []
 
-        df : pandas.DataFrame
-            dataframe containing the training logs
+    # the "step" column is used as the x-axis and is not itself a metric
+    labels = sorted(label for label in df.columns if label != "step")
 
-        const : dict
-            training and hardware constants
+    # Maps each metric name to the list of its subsets (a subset is None
+    # when the column name contains no "/" separator).
+    metrics_groups = defaultdict(list)
 
+    for label in labels:
+        # Separate the name of the subset from the metric
+        split_label = label.rsplit("/", 1)
+        metric = split_label[0]
+        subset = split_label[1] if len(split_label) > 1 else None
+        metrics_groups[metric].append(subset)
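+    # e.g. columns "loss/train", "loss/validation" and "lr" produce
+    # {"loss": ["train", "validation"], "lr": [None]}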
 
-    Returns
-    -------
+    for metric, subsets in metrics_groups.items():
+        figure = plt.figure()
+        axes = figure.gca()
 
-        matplotlib.figure.Figure: figure to be displayed or saved to file
-    """
-    figure = plt.figure()
-    axes = figure.gca()
-
-    cpu_percent = df.cpu_percent.values / const["cpu_count"]
-    cpu_memory = 100 * df.cpu_rss / const["cpu_memory_total"]
-
-    axes.plot(
-        df.epoch.values,
-        cpu_percent,
-        label=f"CPU usage (cores: {const['cpu_count']})",
-    )
-    axes.plot(
-        df.epoch.values,
-        cpu_memory,
-        label=f"CPU memory (total: {const['cpu_memory_total']:.1f} Gb)",
-    )
-    if "gpu_percent" in df:
-        axes.plot(
-            df.epoch.values,
-            df.gpu_percent.values,
-            label=f"GPU usage (type: {const['gpu_name']})",
-        )
-    if "gpu_memory_percent" in df:
-        axes.plot(
-            df.epoch.values,
-            df.gpu_memory_percent.values,
-            label=f"GPU memory (total: {const['gpu_memory_total']:.1f} Gb)",
-        )
-    axes.set_title("Hardware utilisation over time")
-    axes.set_xlabel("Epoch")
-    axes.set_ylabel("Relative utilisation (%)")
-    axes.set_ylim([0, 100])
-
-    axes.legend(loc="best")
-    axes.grid(alpha=0.3)
-    figure.set_layout_engine("tight")
-
-    return figure
+        for subset in subsets:
+            if subset is None:
+                axes.plot(
+                    df["step"].values,
+                    df[metric],
+                    label=metric,
+                )
+            else:
+                axes.plot(
+                    df["step"].values,
+                    df[metric + "/" + subset],
+                    label=subset,
+                )
+
+        axes.xaxis.set_major_locator(MaxNLocator(integer=True))
+        axes.set_title(metric)
+        axes.set_xlabel("Epoch")
+
+        axes.legend(loc="best")
+        axes.grid(alpha=0.3)
+        figure.set_layout_engine("tight")
+
+        figures.append(figure)
+
+    return figures
 
 
 @click.command(
@@ -140,17 +93,12 @@ def _hardware_utilisation(df, const):
 
        .. code:: sh
 
-          ptbench train-analysis -vv log.csv constants.csv
-
+          ptbench train-analysis -vv results/logs
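+
+          # hypothetical: write the PDF to an explicit location
+          ptbench train-analysis -vv results/logs --output-pdf=analysis.pdf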
 """,
 )
 @click.argument(
-    "log",
-    type=click.Path(dir_okay=False, exists=True),
-)
-@click.argument(
-    "constants",
-    type=click.Path(dir_okay=False, exists=True),
+    "logdir",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True),
 )
 @click.option(
     "--output-pdf",
@@ -162,33 +110,35 @@ def _hardware_utilisation(df, const):
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train_analysis(
-    log,
-    constants,
-    output_pdf,
+    logdir: str,
+    output_pdf: str,
     **_,
-):
-    """Analyzes the training logs for loss evolution and resource
-    utilisation."""
+) -> None:
+    """Creates a plot for each metric in the training logs and saves them in a
+    pdf file.
+
+    Parameters
+    ----------
+
+    logdir:
+        Directory containing tensorboard event files.
 
-    import pandas
+    output_pdf:
+        The PDF file in which to save the plots.
+    """
 
     from matplotlib.backends.backend_pdf import PdfPages
 
-    constants = pandas.read_csv(constants)
-    constants = dict(zip(constants.keys(), constants.values[0]))
-    data = pandas.read_csv(log)
+    from ..utils.tensorboard import get_scalars
+
+    data = get_scalars(logdir)
 
     # makes sure the directory to save the output PDF is there
     dirname = os.path.dirname(os.path.realpath(output_pdf))
     if not os.path.exists(dirname):
         os.makedirs(dirname)
 
-    # now, do the analysis
     with PdfPages(output_pdf) as pdf:
-        figure = _loss_evolution(data)
-        pdf.savefig(figure)
-        plt.close(figure)
-
-        figure = _hardware_utilisation(data, constants)
-        pdf.savefig(figure)
-        plt.close(figure)
+        for figure in create_figures(data):
+            pdf.savefig(figure)
+            plt.close(figure)
diff --git a/src/ptbench/utils/tensorboard.py b/src/ptbench/utils/tensorboard.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cd0d1e68b1313e44e95374ff938071b8d0b1dae
--- /dev/null
+++ b/src/ptbench/utils/tensorboard.py
@@ -0,0 +1,55 @@
+import glob
+import os
+
+from collections import defaultdict
+from typing import Any
+
+import pandas
+
+from tensorboard.backend.event_processing.event_accumulator import (
+    EventAccumulator,
+)
+
+
+def get_scalars(logdir: str) -> pandas.DataFrame:
+    """Returns scalars stored in tensorboard event files.
+
+    Parameters
+    ----------
+
+    logdir:
+        Directory containing the event files.
+
+    Returns
+    -------
+
+    data:
+        Pandas dataframe containing the results.  Each row corresponds to a
+        step (epoch), and each column to a metric.
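+
+    A minimal usage sketch (``results/logs`` is a hypothetical directory
+    produced by a training run):
+
+    .. code:: python
+
+       from ptbench.utils.tensorboard import get_scalars
+
+       df = get_scalars("results/logs")
+       # columns are e.g. ["step", "loss/train", "loss/validation"]
+       df.plot(x="step", y="loss/train")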
+    """
+    tensorboard_logs = sorted(
+        glob.glob(os.path.join(logdir, "events.out.tfevents.*"))
+    )
+
+    # maps a step number to the {tag: value} entries collected at that step
+    data: dict[int, dict[str, Any]] = defaultdict(dict)
+
+    for logfile in tensorboard_logs:
+        event_accumulator = EventAccumulator(logfile)
+        event_accumulator.Reload()
+
+        tags = event_accumulator.Tags()
+        # NOTE: all event files are assumed to share the same set of scalar
+        # tags; steps missing a tag end up as NaN in the final dataframe.
+
+        for scalar_tag in tags["scalars"]:
+            headers.add(scalar_tag)
+            tag_list = event_accumulator.Scalars(scalar_tag)
+            for tag_data in tag_list:
+                _ = tag_data.wall_time
+                step = tag_data.step
+                value = tag_data.value
+
+                data[step]["step"] = step
+                data[step][scalar_tag] = value
+
+    data = pandas.DataFrame.from_dict(data, orient="index")
+    return data