diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py
index 6966f6feeec2d3fa11af1b12acd950b9968ffafe..cba08495a6cf3a75a9c801d65aada006d4782787 100644
--- a/src/mednet/engine/callbacks.py
+++ b/src/mednet/engine/callbacks.py
@@ -29,12 +29,20 @@ class LoggingCallback(lightning.pytorch.Callback):
     Parameters
     ----------

-    resource_monitor
+    train_resource_monitor
+        A monitor that watches resource usage (CPU/GPU) in a separate process
+        and totally asynchronously with the code execution.
+
+    validation_resource_monitor
         A monitor that watches resource usage (CPU/GPU) in a separate process
         and totally asynchronously with the code execution.
     """

-    def __init__(self, resource_monitor: ResourceMonitor):
+    def __init__(
+        self,
+        train_resource_monitor: ResourceMonitor,
+        validation_resource_monitor: ResourceMonitor,
+    ):
         super().__init__()

         # timers
@@ -46,7 +54,8 @@ class LoggingCallback(lightning.pytorch.Callback):
         self._to_log: dict[str, float] = {}

         # helpers for CPU and GPU utilisation
-        self._resource_monitor = resource_monitor
+        self._train_resource_monitor = train_resource_monitor
+        self._validation_resource_monitor = validation_resource_monitor
         self._max_queue_retries = 2

     def on_train_start(
@@ -97,6 +106,10 @@
         pl_module
             The lightning module that is being trained
         """
+        # summarizes resource usage since the last checkpoint
+        # clears internal buffers and starts accumulating again.
+        self._train_resource_monitor.checkpoint(remove_last_n=-1)
+
         self._start_training_epoch_time = time.time()

     def on_train_epoch_end(
@@ -121,17 +134,13 @@
             The lightning module that is being trained
         """

-        # summarizes resource usage since the last checkpoint
-        # clears internal buffers and starts accumulating again.
-        self._resource_monitor.checkpoint()
-
         # evaluates this training epoch total time, and log it
         epoch_time = time.time() - self._start_training_epoch_time

         self._to_log["epoch-duration-seconds/train"] = epoch_time
         self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]  # type: ignore

-        metrics = self._resource_monitor.data
+        metrics = self._train_resource_monitor.data
         if metrics is not None:
             for metric_name, metric_value in metrics.items():
                 self._to_log[f"{metric_name}/train"] = float(metric_value)
@@ -235,6 +244,14 @@
         pl_module
             The lightning module that is being trained
         """
+        # required because the validation epoch is started **within** the
+        # training epoch START/END.
+        #
+        # summarizes resource usage since the last checkpoint
+        # clears internal buffers and starts accumulating again.
+        self._train_resource_monitor.checkpoint(remove_last_n=-1)
+        self._validation_resource_monitor.checkpoint(remove_last_n=-1)
+
         self._start_validation_epoch_time = time.time()

     def on_validation_epoch_end(
@@ -261,12 +278,12 @@

         # summarizes resource usage since the last checkpoint
         # clears internal buffers and starts accumulating again.
-        self._resource_monitor.checkpoint()
+        self._validation_resource_monitor.checkpoint(remove_last_n=-1)

         epoch_time = time.time() - self._start_validation_epoch_time
         self._to_log["epoch-duration-seconds/validation"] = epoch_time

-        metrics = self._resource_monitor.data
+        metrics = self._validation_resource_monitor.data
         if metrics is not None:
             for metric_name, metric_value in metrics.items():
                 self._to_log[f"{metric_name}/validation"] = float(metric_value)
diff --git a/src/mednet/engine/device.py b/src/mednet/engine/device.py
index 2eeef34a96156083df564a20746e447f2e577afe..d11aacffb0f14960a4c3434beccb6e283bd6c940 100644
--- a/src/mednet/engine/device.py
+++ b/src/mednet/engine/device.py
@@ -4,6 +4,7 @@

 import logging
 import os
+import typing

 import torch
 import torch.backends
@@ -11,6 +12,14 @@
 logger = logging.getLogger(__name__)


+SupportedPytorchDevice: typing.TypeAlias = typing.Literal[
+    "cpu",
+    "cuda",
+    "mps",
+]
+"""List of pytorch device types supported by this library."""
+
+
 def _split_int_list(s: str) -> list[int]:
     """Splits a list of integers encoded in a string (e.g. "1,2,3") into a
     Python list of integers (e.g. ``[1, 2, 3]``)."""
@@ -43,11 +52,16 @@ class DeviceManager:
     current process.
     """

-    SUPPORTED = ("cpu", "cuda", "mps")
-
-    def __init__(self, name: str):
+    def __init__(self, name: SupportedPytorchDevice):
         parts = name.split(":", 1)
-        self.device_type = parts[0]
+
+        # coerce the device type to the right Python type
+        if parts[0] not in typing.get_args(SupportedPytorchDevice):
+            raise ValueError(f"Unsupported device-type `{parts[0]}`")
+        self.device_type: SupportedPytorchDevice = typing.cast(
+            SupportedPytorchDevice, parts[0]
+        )
+
         self.device_ids: list[int] = []
         if len(parts) > 1:
             self.device_ids = _split_int_list(parts[1])
@@ -70,10 +84,11 @@ class DeviceManager:
                 [str(k) for k in self.device_ids]
             )

-        if self.device_type not in DeviceManager.SUPPORTED:
+        if self.device_type not in typing.get_args(SupportedPytorchDevice):
             raise RuntimeError(
                 f"Unsupported device type `{self.device_type}`. "
-                f"Supported devices types are `{', '.join(DeviceManager.SUPPORTED)}`"
+                f"Supported device types are "
+                f"`{', '.join(typing.get_args(SupportedPytorchDevice))}`"
             )

         if self.device_ids and self.device_type in ("cpu", "mps"):
diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py
index dee1ad980867e3ff485be1d2b3358e1dac9f9c3b..1b9bfb0b0918c01e895a4dc0a8f15af44c9249d1 100644
--- a/src/mednet/engine/trainer.py
+++ b/src/mednet/engine/trainer.py
@@ -14,9 +14,9 @@ import lightning.pytorch.loggers
 import torch.nn

 from ..utils.checkpointer import CHECKPOINT_ALIASES
-from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
+from ..utils.resources import ResourceMonitor, cpu_constants, cuda_constants
 from .callbacks import LoggingCallback
-from .device import DeviceManager
+from .device import DeviceManager, SupportedPytorchDevice

 logger = logging.getLogger(__name__)

@@ -62,7 +62,7 @@ def save_model_summary(

 def static_information_to_csv(
     static_logfile_name: pathlib.Path,
-    device_type: str,
+    device_type: SupportedPytorchDevice,
     model_size: int,
 ) -> None:
     """Saves the static information in a CSV file.
@@ -87,10 +87,19 @@ def static_information_to_csv(
     with static_logfile_name.open("w", newline="") as f:
         logdata: dict[str, int | float | str] = {}
         logdata.update(cpu_constants())
-        if device_type == "cuda":
-            results = gpu_constants()
-            if results is not None:
-                logdata.update(results)
+
+        match device_type:
+            case "cpu":
+                pass
+            case "cuda":
+                results = cuda_constants()
+                if results is not None:
+                    logdata.update(results)
+            case "mps":
+                pass
+            case _:
+                pass
+
         logdata["model_size"] = model_size
         logwriter = csv.DictWriter(f, fieldnames=logdata.keys())
         logwriter.writeheader()
@@ -177,9 +186,16 @@
             f"Then, open a browser on the printed address."
         )

-    resource_monitor = ResourceMonitor(
+    train_resource_monitor = ResourceMonitor(
+        interval=monitoring_interval,
+        device_type=device_manager.device_type,
+        main_pid=os.getpid(),
+        logging_level=logging.ERROR,
+    )
+
+    validation_resource_monitor = ResourceMonitor(
         interval=monitoring_interval,
-        has_gpu=device_manager.device_type == "cuda",
+        device_type=device_manager.device_type,
         main_pid=os.getpid(),
         logging_level=logging.ERROR,
     )
@@ -208,7 +224,7 @@
         no_of_parameters,
     )

-    with resource_monitor:
+    with train_resource_monitor, validation_resource_monitor:
         accelerator, devices = device_manager.lightning_accelerator()
         trainer = lightning.pytorch.Trainer(
             accelerator=accelerator,
@@ -219,7 +235,9 @@
             check_val_every_n_epoch=validation_period,
             log_every_n_steps=len(datamodule.train_dataloader()),
             callbacks=[
-                LoggingCallback(resource_monitor),
+                LoggingCallback(
+                    train_resource_monitor, validation_resource_monitor
+                ),
                 checkpoint_minvalloss_callback,
             ],
         )
diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py
index 94867b6adf66443ff4adb902573c7482eca9c1da..7736d2cb9b362b6dbc5ba1d920bae9ff084a2a2e 100644
--- a/src/mednet/scripts/train_analysis.py
+++ b/src/mednet/scripts/train_analysis.py
@@ -66,6 +66,8 @@ def create_figures(

     import matplotlib.pyplot as plt

+    from matplotlib.axes import Axes
+    from matplotlib.figure import Figure
     from matplotlib.ticker import MaxNLocator

     figures = []
@@ -77,8 +79,8 @@ def create_figures(
             continue

         fig, ax = plt.subplots(1, 1)
-        ax = typing.cast(plt.Axes, ax)
-        fig = typing.cast(plt.Figure, fig)
+        ax = typing.cast(Axes, ax)
+        fig = typing.cast(Figure, fig)

         if len(curves) == 1:
             # there is only one curve, just plot it
diff --git a/src/mednet/utils/resources.py b/src/mednet/utils/resources.py
index ca0f5efa3727e5cc53ac6c722f2371e0ccc4d00d..7babefde5a8ec8ad16b83503ba29a4022c838f96 100644
--- a/src/mednet/utils/resources.py
+++ b/src/mednet/utils/resources.py
@@ -7,6 +7,7 @@ import logging
 import multiprocessing
 import multiprocessing.synchronize
 import os
+import plistlib
 import queue
 import shutil
 import subprocess
@@ -16,6 +17,8 @@ import typing
 import numpy
 import psutil

+from ..engine.device import SupportedPytorchDevice
+
 logger = logging.getLogger(__name__)

 _nvidia_smi = shutil.which("nvidia-smi")
@@ -74,7 +77,7 @@
     return retval


-def gpu_constants() -> dict[str, str | int | float] | None:
+def cuda_constants() -> dict[str, str | int | float] | None:
     """Returns GPU (static) information using nvidia-smi.

     See :py:func:`run_nvidia_smi` for operational details.
@@ -102,7 +105,7 @@
     return retval


-def gpu_log() -> dict[str, float] | None:
+def cuda_log() -> dict[str, float] | None:
     """Returns GPU information about current non-static status using nvidia-
     smi.

@@ -116,13 +119,13 @@
         return a dictionary containing the following ``nvidia-smi`` query
         information, in this order:

-        * ``memory.used``, as ``gpu_memory_used`` (transformed to gigabytes,
+        * ``memory.used``, as ``memory-used-GB/gpu`` (transformed to gigabytes,
          :py:class:`float`)
-        * ``memory.free``, as ``gpu_memory_free`` (transformed to gigabytes,
+        * ``memory.free``, as ``memory-free-GB/gpu`` (transformed to gigabytes,
          :py:class:`float`)
-        * ``100*memory.used/memory.total``, as ``gpu_memory_percent``,
+        * ``100*memory.used/memory.total``, as ``memory-percent/gpu``,
          (:py:class:`float`, in percent)
-        * ``utilization.gpu``, as ``gpu_percent``,
+        * ``utilization.gpu``, as ``percent-usage/gpu``,
          (:py:class:`float`, in percent)
     """

@@ -143,6 +146,66 @@
     return retval


+def mps_log() -> dict[str, float] | None:
+    """Returns GPU information about current non-static status using ``sudo
+    powermetrics``.
+
+    Returns
+    -------
+
+    data
+        If ``sudo powermetrics`` is not executable (or is not configured for
+        passwordless execution), returns ``None``, otherwise, we return a
+        dictionary containing the following query information, in this order:
+
+        * ``freq_hz`` as ``frequency-MHz/gpu``
+        * 100 * (1 - ``idle_ratio``), as ``percent-usage/gpu``,
+          (:py:class:`float`, in percent)
+    """
+
+    cmd = [
+        "sudo",
+        "-n",
+        "/usr/bin/powermetrics",
+        "--samplers",
+        "gpu_power",
+        "-i500",
+        "-n1",
+        "-fplist",
+    ]
+
+    try:
+        raw_bytes = subprocess.check_output(cmd)
+        data = plistlib.loads(raw_bytes)
+        return {
+            "frequency-MHz/gpu": float(data["gpu"]["freq_hz"]),
+            "percent-usage/gpu": 100 * (1 - data["gpu"]["idle_ratio"]),
+        }
+
+    except subprocess.CalledProcessError:
+        import inspect
+        import warnings
+
+        warnings.warn(
+            inspect.cleandoc(
+                f"""Cannot run `sudo powermetrics` without a password. Probably,
+                you did not set up sudo to execute the macOS CLI application
+                `powermetrics` passwordlessly and therefore this warning is
+                being issued. This does not affect the running of your model
+                training, only the ability of the resource monitor to gather
+                GPU usage information on macOS while using the MPS compute
+                backend. To configure this properly and get rid of this
+                warning, execute `sudo visudo` and add the following line where
+                suitable: `yourusername ALL=(ALL) NOPASSWD:SETENV:
+                /usr/bin/powermetrics`. Replace `yourusername` with your actual
+                username on the machine. Test the setup by running the command
+                `{' '.join(cmd)}` by hand."""
+            )
+        )
+
+        return None
+
+
 def cpu_constants() -> dict[str, int | float]:
     """Returns static CPU information about the current system.

@@ -179,7 +242,7 @@ class CPULogger:
         [k.cpu_percent(interval=None) for k in self.cluster]

     def log(self) -> dict[str, int | float]:
-        """Returns current process cluster iformation.
+        """Returns current process cluster information.

         Returns
         -------
@@ -252,8 +315,9 @@ class _InformationGatherer:
     Parameters
     ----------

-    has_gpu
-        A flag indicating if we have a GPU installed on the platform or not
+    device_type
+        String representation of one of the supported pytorch device types
+        used to select the correct resource-usage readout.

     main_pid
         The main process identifier to monitor
@@ -263,27 +327,58 @@
     """

     def __init__(
-        self, has_gpu: bool, main_pid: int | None, logger: logging.Logger
+        self,
+        device_type: SupportedPytorchDevice,
+        main_pid: int | None,
+        logger: logging.Logger,
     ):
         self.logger: logging.Logger = logger
         self.cpu_logger = CPULogger(main_pid)
         keys: list[str] = list(self.cpu_logger.log().keys())
-        self.has_gpu: bool = has_gpu
-        if self.has_gpu:
-            example = gpu_log()
-            if example is not None:
-                keys += list(example.keys())
+        self.device_type = device_type
+
+        match self.device_type:
+            case "cpu":
+                logger.info(
+                    f"Pytorch device-type `{device_type}`: "
+                    f"no GPU logging will be performed"
+                )
+            case "cuda":
+                example = cuda_log()
+                if example is not None:
+                    keys += list(example.keys())
+            case "mps":
+                example = mps_log()
+                if example is not None:
+                    keys += list(example.keys())
+            case _:
+                logger.warning(
+                    f"Unsupported device-type `{device_type}`: "
+                    f"no GPU logging will be performed"
+                )
+
         self.data: dict[str, list[int | float]] = {k: [] for k in keys}

     def acc(self) -> None:
         """Accumulates another measurement."""
         for k, v in self.cpu_logger.log().items():
             self.data[k].append(v)
-        if self.has_gpu:
-            sample = gpu_log()
-            if sample is not None:
-                for k, v in sample.items():
-                    self.data[k].append(v)
+
+        match self.device_type:
+            case "cpu":
+                pass
+            case "cuda":
+                sample = cuda_log()
+                if sample is not None:
+                    for k, v in sample.items():
+                        self.data[k].append(v)
+            case "mps":
+                sample = mps_log()
+                if sample is not None:
+                    for k, v in sample.items():
+                        self.data[k].append(v)
+            case _:
+                pass

     def clear(self) -> None:
         """Clears accumulated data."""
@@ -299,7 +394,7 @@

 def _monitor_worker(
     interval: int | float,
-    has_gpu: bool,
+    device_type: SupportedPytorchDevice,
     main_pid: int,
     stop: multiprocessing.synchronize.Event,
     summary_event: multiprocessing.synchronize.Event,
@@ -315,8 +410,9 @@
         Number of seconds to wait between each measurement (maybe a floating
         point number as accepted by :py:func:`time.sleep`)

-    has_gpu
-        A flag indicating if we have a GPU installed on the platform or not
+    device_type
+        String representation of one of the supported pytorch device types
+        used to select the correct resource-usage readout.

     main_pid
         The main process identifier to monitor
@@ -334,7 +430,7 @@
         The logging level to use for logging from launched processes
     """
     logger = multiprocessing.log_to_stderr(level=logging_level)
-    ra = _InformationGatherer(has_gpu, main_pid, logger)
+    ra = _InformationGatherer(device_type, main_pid, logger)

     while not stop.is_set():
         try:
@@ -364,8 +460,9 @@
         Number of seconds to wait between each measurement (maybe a floating
         point number as accepted by :py:func:`time.sleep`)

-    has_gpu
-        A flag indicating if we have a GPU installed on the platform or not
+    device_type
+        String representation of one of the supported pytorch device types
+        used to select the correct resource-usage readout.

     main_pid
         The main process identifier to monitor
@@ -377,12 +474,12 @@
     def __init__(
         self,
         interval: int | float,
-        has_gpu: bool,
+        device_type: SupportedPytorchDevice,
         main_pid: int,
         logging_level: int,
     ):
         self.interval = interval
-        self.has_gpu = has_gpu
+        self.device_type = device_type
         self.main_pid = main_pid
         self.stop_event = multiprocessing.Event()
         self.summary_event = multiprocessing.Event()
@@ -396,7 +493,7 @@
             name="ResourceMonitorProcess",
             args=(
                 self.interval,
-                self.has_gpu,
+                self.device_type,
                 self.main_pid,
                 self.stop_event,
                 self.summary_event,
@@ -411,9 +508,17 @@
         """Starts the monitoring process."""
         self.monitor.start()

-    def checkpoint(self) -> None:
+    def checkpoint(self, remove_last_n: int | None = None) -> None:
         """Forces the monitoring process to yield data and clear the internal
-        accumlator."""
+        accumulator.
+
+        Parameters
+        ----------
+        remove_last_n
+            If set, trim trailing observations via ``values[:remove_last_n]``
+            (e.g. ``-1`` drops the last sample), keeping at least one sample.
+            Useful to discard spurious observations at the end of a period.
+        """
         self.summary_event.set()
         try:
             data: dict[str, list[int | float]] = self.q.get(
@@ -432,7 +537,9 @@
         self.data = {}
         for k, values in data.items():
             if values:
-                if k in ("cpu_processes", "cpu_open_files"):
+                if values[:remove_last_n]:
+                    values = values[:remove_last_n]
+                if k in ("num-processes/cpu", "num-open-files/cpu"):
                     self.data[k] = numpy.max(values)
                 else:
                     self.data[k] = float(numpy.mean(values))
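
Reviewer note, not part of the patch: the sketch below restates the validation idiom that `SupportedPytorchDevice` enables in `DeviceManager.__init__` — a single `typing.Literal` alias serves both the static type checker and the runtime check through `typing.get_args`. The `parse_device` helper is hypothetical and only illustrates the pattern.

import typing

SupportedPytorchDevice: typing.TypeAlias = typing.Literal["cpu", "cuda", "mps"]


def parse_device(name: str) -> tuple[SupportedPytorchDevice, list[int]]:
    # split an optional device-id suffix, e.g. "cuda:0,1" -> ("cuda", [0, 1])
    parts = name.split(":", 1)
    if parts[0] not in typing.get_args(SupportedPytorchDevice):
        raise ValueError(f"Unsupported device-type `{parts[0]}`")
    device_type = typing.cast(SupportedPytorchDevice, parts[0])
    device_ids = [int(k) for k in parts[1].split(",")] if len(parts) > 1 else []
    return device_type, device_ids


print(parse_device("cuda:0,1"))  # ('cuda', [0, 1])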
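Reviewer note, not part of the patch: a minimal usage sketch of the split train/validation monitors introduced here, assuming the module path `mednet.utils.resources` implied by the file layout. The 5.0-second interval and the explicit `checkpoint()` call are arbitrary choices for illustration; in the patch, `LoggingCallback` drives checkpointing at epoch boundaries.

import logging
import os

from mednet.utils.resources import ResourceMonitor

# one monitor per phase, so validation samples no longer dilute training ones
train_monitor = ResourceMonitor(
    interval=5.0,
    device_type="cpu",
    main_pid=os.getpid(),
    logging_level=logging.ERROR,
)
validation_monitor = ResourceMonitor(
    interval=5.0,
    device_type="cpu",
    main_pid=os.getpid(),
    logging_level=logging.ERROR,
)

with train_monitor, validation_monitor:
    # ... run a training epoch here ...
    train_monitor.checkpoint(remove_last_n=-1)  # summarize and reset buffers
    print(train_monitor.data)  # figures accumulated since the previous checkpoint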
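Reviewer note, not part of the patch: `remove_last_n` is applied as a slice end rather than a count, so `-1` discards only the final sample, and the trim is skipped when it would empty the series. The helper below is a plain-Python restatement of the logic added to `ResourceMonitor.checkpoint`, for illustration only.

def trim(values: list[float], remove_last_n: int | None = None) -> list[float]:
    # values[:None] keeps everything; values[:-1] drops the trailing sample,
    # but only when at least one observation would survive the cut
    if values and values[:remove_last_n]:
        return values[:remove_last_n]
    return values


assert trim([1.0, 2.0, 3.0], remove_last_n=-1) == [1.0, 2.0]
assert trim([1.0], remove_last_n=-1) == [1.0]  # would be emptied -> keep as-is
assert trim([1.0, 2.0]) == [1.0, 2.0]  # remove_last_n=None -> no trimming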