diff --git a/.gitignore b/.gitignore
index 7fa19d5c769d8bf8bf3804ce2e794710a3fbbc12..a3459a35b7c043e706661283d797215df9e8f64a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,34 +1,33 @@
 *~
 *.swp
 *.pyc
+*.so
+*.dylib
 bin
 eggs
 parts
 .installed.cfg
 .mr.developer.cfg
 *.egg-info
-src
 develop-eggs
-sphinx
 dist
 .nfs*
 .gdb_history
 build
-.coverage
+*.egg
+src/
+doc/api
 record.txt
-miniconda.sh
-miniconda/
-.vscode
+core
+output_temp
+output
+*.DS_Store
 *.ipynb
 .ipynb_checkpoints
 */.ipynb_checkpoints/*
-*.DS_Store
-/test*.py
-.noseids
-/coverage.xml
-/nosetests.xml
-/results/*
-/predictions/*
-/evaluations/*
-/comparisons/*
-doc/api
\ No newline at end of file
+submitted.sql3
+logs/
+.coverage
+.envrc
+environment.yaml
+html/
diff --git a/bob/med/tb/data/transforms.py b/bob/med/tb/data/transforms.py
index c107f93ca2b5f8ff149fb643528fdb1306c369ee..0d18d28d0867b851fd0d966540960ba33593336b 100644
--- a/bob/med/tb/data/transforms.py
+++ b/bob/med/tb/data/transforms.py
@@ -15,10 +15,7 @@ import random
 
 import numpy
 import PIL.Image
-import torchvision.transforms
-import torchvision.transforms.functional
-from scipy.ndimage.filters import gaussian_filter
-from scipy.ndimage.interpolation import map_coordinates
+from scipy.ndimage import gaussian_filter, map_coordinates
 
 class SingleAutoLevel16to8:
     """Converts a 16-bit image to 8-bit representation using "auto-level"
@@ -74,7 +71,7 @@ class ElasticDeformation:
         if random.random() < self.p:
 
             img = numpy.asarray(img)
-            
+
             assert img.ndim == 2
 
             shape = img.shape
@@ -91,4 +88,4 @@ class ElasticDeformation:
                 img[:, :], indices, order=self.spline_order, mode=self.mode).reshape(shape)
             return PIL.Image.fromarray(result)
         else:
-            return img
\ No newline at end of file
+            return img
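
Note on the import change above: gaussian_filter and map_coordinates are now imported directly from scipy.ndimage, since the scipy.ndimage.filters and scipy.ndimage.interpolation sub-module paths are deprecated in recent SciPy releases.  A minimal usage sketch (array sizes are illustrative only, not part of the patch):

    import numpy
    from scipy.ndimage import gaussian_filter, map_coordinates

    img = numpy.random.rand(8, 8)
    smoothed = gaussian_filter(img, sigma=1.0)  # same call as before, new import path
    coords = numpy.indices(img.shape).reshape(2, -1)
    warped = map_coordinates(img, coords, order=1).reshape(img.shape)  # identity warp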
diff --git a/bob/med/tb/engine/trainer.py b/bob/med/tb/engine/trainer.py
index c66c1d3bcf45c84741f91b9812a07e1f3612fe7d..a46150c17eb556cb19dd894bd8717162522486ab 100644
--- a/bob/med/tb/engine/trainer.py
+++ b/bob/med/tb/engine/trainer.py
@@ -9,12 +9,19 @@ import shutil
 import datetime
 import contextlib
 
+import numpy
 import torch
 from tqdm import tqdm
 
 from ..utils.measure import SmoothedValue
 from ..utils.summary import summary
-from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
+
+# from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
+from ..utils.resources import (
+    ResourceMonitor,
+    cpu_constants,
+    gpu_constants,
+)
 
 import logging
 
@@ -33,14 +40,14 @@ def torch_evaluation(model):
     ----------
 
     model : :py:class:`torch.nn.Module`
-        Network (e.g. driu, hed, unet)
+        Network
 
 
     Yields
     ------
 
     model : :py:class:`torch.nn.Module`
-        Network (e.g. driu, hed, unet)
+        Network
 
     """
 
@@ -49,10 +56,451 @@ def torch_evaluation(model):
     model.train()
 
 
+def check_gpu(device):
+    """
+    Check the device type and the availability of a GPU.
+
+    Parameters
+    ----------
+
+    device : :py:class:`torch.device`
+        device to use
+
+    """
+    if device.type == "cuda":
+        # asserts we do have a GPU
+        assert bool(
+            gpu_constants()
+        ), f"Device set to '{device}', but nvidia-smi is not installed"
+
+
+def save_model_summary(output_folder, model):
+    """
+    Save a short summary of the model in a text file.
+
+    Parameters
+    ----------
+
+    output_folder : str
+        output path
+
+    model : :py:class:`torch.nn.Module`
+        Network
+
+    Returns
+    -------
+    r : str
+        The model summary in a text format.
+
+    n : int
+        The number of parameters of the model.
+
+    """
+    summary_path = os.path.join(output_folder, "model_summary.txt")
+    logger.info(f"Saving model summary at {summary_path}...")
+    with open(summary_path, "wt") as f:
+        r, n = summary(model)
+        logger.info(f"Model has {n} parameters...")
+        f.write(r)
+    return r, n
+
+
+def static_information_to_csv(static_logfile_name, device, n):
+    """
+    Save the static information in a csv file.
+
+    Parameters
+    ----------
+
+    static_logfile_name : str
+        The static file name, which is a join between the output folder and
+        "constants.csv"
+
+    device : :py:class:`torch.device`
+        device to use
+
+    n : int
+        The number of parameters of the model
+
+    """
+    if os.path.exists(static_logfile_name):
+        backup = static_logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(static_logfile_name, backup)
+    with open(static_logfile_name, "w", newline="") as f:
+        logdata = cpu_constants()
+        if device.type == "cuda":
+            logdata += gpu_constants()
+        logdata += (("model_size", n),)
+        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
+        logwriter.writeheader()
+        logwriter.writerow(dict(k for k in logdata))
+
+
+def check_exist_logfile(logfile_name, arguments):
+    """
+    Check the existence of the logfile (trainlog.csv).
+    If the logfile exists and training starts from epoch 0, it is backed up and replaced.
+
+    Parameters
+    ----------
+
+    logfile_name : str
+        The logfile_name which is a join between the output_folder and trainlog.csv
+
+    arguments : dict
+        start and end epochs
+
+    """
+    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
+        backup = logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(logfile_name, backup)
+
+
+def create_logfile_fields(valid_loader, extra_valid_loaders, device):
+    """
+    Creation of the logfile fields that will appear in the logfile.
+
+    Parameters
+    ----------
+
+    valid_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to validate the model and enable automatic checkpointing.
+        If set to ``None``, then do not validate it.
+
+    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
+        To be used to validate the model, however **does not affect** automatic
+        checkpointing. If set to ``None``, or empty, then does not log anything
+        else.  Otherwise, an extra column with the loss of every dataset in
+        this list is kept on the final training log.
+
+    device : :py:class:`torch.device`
+        device to use
+
+    Returns
+    -------
+
+    logfile_fields: tuple
+        The fields that will appear in trainlog.csv
+
+
+    """
+    logfile_fields = (
+        "epoch",
+        "total_time",
+        "eta",
+        "loss",
+        "learning_rate",
+    )
+    if valid_loader is not None:
+        logfile_fields += ("validation_loss",)
+    if extra_valid_loaders:
+        logfile_fields += ("extra_validation_losses",)
+    logfile_fields += tuple(
+        ResourceMonitor.monitored_keys(device.type == "cuda")
+    )
+    return logfile_fields
+
+
+def train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count):
+    """Trains the model for a single epoch (through all batches)
+
+    Parameters
+    ----------
+
+    loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to train the model
+
+    model : :py:class:`torch.nn.Module`
+        Network
+
+    optimizer : :py:mod:`torch.optim`
+
+    device : :py:class:`torch.device`
+        device to use
+
+    criterion : :py:class:`torch.nn.modules.loss._Loss`
+
+    batch_chunk_count : int
+        If this number is different from 1, then each batch will be divided in
+        this number of chunks.  Gradients will be accumulated to perform each
+        mini-batch.  This is particularly interesting when one has limited RAM
+        on the GPU, but would like to keep training with larger batches.  One
+        exchanges this for longer processing times.  To better understand
+        gradient accumulation, read
+        https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch.
+
+
+    Returns
+    -------
+
+    loss : float
+        A floating-point value corresponding to the weighted average of this
+        epoch's loss
+
+    """
+
+    losses_in_epoch = []
+    samples_in_epoch = []
+    losses_in_batch = []
+    samples_in_batch = []
+
+    # progress bar only on interactive jobs
+    for idx, samples in enumerate(
+        tqdm(loader, desc="train", leave=False, disable=None)
+    ):
+
+        images = samples[1].to(
+            device=device, non_blocking=torch.cuda.is_available()
+        )
+        labels = samples[2].to(
+            device=device, non_blocking=torch.cuda.is_available()
+        )
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = model(images)
+
+        loss = criterion(outputs, labels.double())
+
+        losses_in_batch.append(loss.item())
+        samples_in_batch.append(images.shape[0])  # number of samples in this iteration
+
+        # Normalize loss to account for batch accumulation
+        loss = loss / batch_chunk_count
+
+        # Accumulate gradients - does not update weights just yet...
+        loss.backward()
+
+        # Weight update on the network
+        if ((idx + 1) % batch_chunk_count == 0) or (idx + 1 == len(loader)):
+            # Advances optimizer to the "next" state and applies weight update
+            # over the whole model
+            optimizer.step()
+
+            # Zeroes gradients for the next batch
+            optimizer.zero_grad()
+
+            # Normalize loss for current batch
+            batch_loss = numpy.average(
+                losses_in_batch, weights=samples_in_batch
+            )
+            losses_in_epoch.append(batch_loss.item())
+            samples_in_epoch.append(sum(samples_in_batch))  # samples accumulated for this step
+
+            losses_in_batch.clear()
+            samples_in_batch.clear()
+            logger.debug(f"batch loss: {batch_loss.item()}")
+
+    return numpy.average(losses_in_epoch, weights=samples_in_epoch)
+
+
+def validate_epoch(loader, model, device, criterion, pbar_desc):
+    """
+    Processes input samples and returns loss (scalar)
+
+
+    Parameters
+    ----------
+
+    loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to validate the model
+
+    model : :py:class:`torch.nn.Module`
+        Network
+
+    device : :py:class:`torch.device`
+        device to use
+
+    criterion : :py:class:`torch.nn.modules.loss._Loss`
+        loss function
+
+    pbar_desc : str
+        A string for the progress bar descriptor
+
+
+    Returns
+    -------
+
+    loss : float
+        A floating-point value corresponding to the weighted average of this
+        epoch's loss
+
+    """
+
+    batch_losses = []
+    samples_in_batch = []
+
+    with torch.no_grad(), torch_evaluation(model):
+
+        for samples in tqdm(loader, desc=pbar_desc, leave=False, disable=None):
+            images = samples[1].to(
+                device=device,
+                non_blocking=torch.cuda.is_available(),
+            )
+            labels = samples[2].to(
+                device=device,
+                non_blocking=torch.cuda.is_available(),
+            )
+
+            # Increase label dimension if too low
+            # Allows single and multiclass usage
+            if labels.ndim == 1:
+                labels = torch.reshape(labels, (labels.shape[0], 1))
+
+            # data forwarding on the existing network
+            outputs = model(images)
+            loss = criterion(outputs, labels.double())
+
+            batch_losses.append(loss.item())
+            samples_in_batch.append(images.shape[0])  # number of samples in this batch
+
+    return numpy.average(batch_losses, weights=samples_in_batch)
+
+
+def checkpointer_process(
+    checkpointer,
+    checkpoint_period,
+    valid_loss,
+    lowest_validation_loss,
+    arguments,
+    epoch,
+    max_epoch,
+):
+    """
+    Process the checkpointer, save the final model and keep track of the best model.
+
+    Parameters
+    ----------
+
+    checkpointer : :py:class:`bob.med.tb.utils.checkpointer.Checkpointer`
+        checkpointer implementation
+
+    checkpoint_period : int
+        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
+        not save intermediary checkpoints
+
+    valid_loss : float
+        Current epoch validation loss
+
+    lowest_validation_loss : float
+        Keeps track of the best (lowest) validation loss
+
+    arguments : dict
+        start and end epochs
+
+    epoch : int
+        Current epoch
+
+    max_epoch : int
+        End epoch (maximum number of training epochs)
+
+    Returns
+    -------
+
+    lowest_validation_loss : float
+        The lowest validation loss currently observed
+
+
+    """
+    if checkpoint_period and (epoch % checkpoint_period == 0):
+        checkpointer.save("model_periodic_save", **arguments)
+
+    if valid_loss is not None and valid_loss < lowest_validation_loss:
+        lowest_validation_loss = valid_loss
+        logger.info(
+            f"Found new low on validation set:" f" {lowest_validation_loss:.6f}"
+        )
+        checkpointer.save("model_lowest_valid_loss", **arguments)
+
+    if epoch >= max_epoch:
+        checkpointer.save("model_final_epoch", **arguments)
+
+    return lowest_validation_loss
+
+
+def write_log_info(
+    epoch,
+    current_time,
+    eta_seconds,
+    loss,
+    valid_loss,
+    extra_valid_losses,
+    optimizer,
+    logwriter,
+    logfile,
+    resource_data,
+):
+    """
+    Write log info in trainlog.csv
+
+    Parameters
+    ----------
+
+    epoch : int
+        Current epoch
+
+    current_time : float
+        Current training time
+
+    eta_seconds : float
+        estimated time-of-arrival taking into consideration previous epoch performance
+
+    loss : float
+        Current epoch's training loss
+
+    valid_loss : :py:class:`float`, None
+        Current epoch's validation loss
+
+    extra_valid_losses : :py:class:`list` of :py:class:`float`
+        Validation losses from other validation datasets being currently
+        tracked
+
+    optimizer : :py:mod:`torch.optim`
+
+    logwriter : csv.DictWriter
+        Dictionary writer used to output rows to trainlog.csv
+
+    logfile : io.TextIOWrapper
+        The file object backing trainlog.csv (flushed after every write)
+
+    resource_data : tuple
+        Monitored resources at the machine (CPU and GPU)
+
+    """
+
+    logdata = (
+        ("epoch", f"{epoch}"),
+        (
+            "total_time",
+            f"{datetime.timedelta(seconds=int(current_time))}",
+        ),
+        ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
+        ("loss", f"{loss:.6f}"),
+        ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+    )
+
+    if valid_loss is not None:
+        logdata += (("validation_loss", f"{valid_loss:.6f}"),)
+
+    if extra_valid_losses:
+        entry = numpy.array_str(
+            numpy.array(extra_valid_losses),
+            max_line_width=sys.maxsize,
+            precision=6,
+        )
+        logdata += (("extra_validation_losses", entry),)
+
+    logdata += resource_data
+
+    logwriter.writerow(dict(k for k in logdata))
+    logfile.flush()
+    tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
+
+
 def run(
     model,
     data_loader,
     valid_loader,
+    extra_valid_loaders,
     optimizer,
     criterion,
     checkpointer,
@@ -60,7 +508,9 @@ def run(
     device,
     arguments,
     output_folder,
-    criterion_valid = None
+    monitoring_interval,
+    batch_chunk_count,
+    criterion_valid,
 ):
     """
     Fits a CNN model using supervised learning and save it to disk.
@@ -73,14 +523,20 @@ def run(
     ----------
 
     model : :py:class:`torch.nn.Module`
-        Network (e.g. pasa)
+        Network
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
         To be used to train the model
 
-    valid_loader : :py:class:`torch.utils.data.DataLoader`
+    valid_loader : :py:class:`torch.utils.data.DataLoader`
         To be used to validate the model and enable automatic checkpointing.
-        If set to ``None``, then do not validate it.
+        If ``None``, then do not validate it.
+
+    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
+        To be used to validate the model, however **does not affect** automatic
+        checkpointing. If empty, then does not log anything else.  Otherwise,
+        an extra column with the loss of every dataset in this list is kept on
+        the final training log.
 
     optimizer : :py:mod:`torch.optim`
 
@@ -94,8 +550,8 @@ def run(
         save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
         not save intermediary checkpoints
 
-    device : str
-        device to use ``'cpu'`` or ``cuda:0``
+    device : :py:class:`torch.device`
+        device to use
 
     arguments : dict
         start and end epochs
@@ -103,73 +559,53 @@ def run(
     output_folder : str
         output path
 
+    monitoring_interval : int, float
+        interval, in seconds (or fractions), through which we should monitor
+        resources during training.
+
+    batch_chunk_count : int
+        If this number is different from 1, then each batch will be divided in
+        this number of chunks.  Gradients will be accumulated to perform each
+        mini-batch.  This is particularly interesting when one has limited RAM
+        on the GPU, but would like to keep training with larger batches.  One
+        exchanges this for longer processing times.
+
     criterion_valid : :py:class:`torch.nn.modules.loss._Loss`
         specific loss function for the validation set
+
     """
 
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
-    if device != "cpu":
-        # asserts we do have a GPU
-        assert bool(gpu_constants()), (
-            f"Device set to '{device}', but cannot "
-            f"find a GPU (maybe nvidia-smi is not installed?)"
-        )
+    check_gpu(device)
 
     os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
-    summary_path = os.path.join(output_folder, "model_summary.txt")
-    logger.info(f"Saving model summary at {summary_path}...")
-    with open(summary_path, "wt") as f:
-        r, n = summary(model)
-        logger.info(f"Model has {n} parameters...")
-        f.write(r)
+    r, n = save_model_summary(output_folder, model)
 
     # write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
-    if os.path.exists(static_logfile_name):
-        backup = static_logfile_name + "~"
-        if os.path.exists(backup):
-            os.unlink(backup)
-        shutil.move(static_logfile_name, backup)
-    with open(static_logfile_name, "w", newline="") as f:
-        logdata = cpu_constants()
-        if device != "cpu":
-            logdata += gpu_constants()
-        logdata += (("model_size", n),)
-        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
-        logwriter.writeheader()
-        logwriter.writerow(dict(k for k in logdata))
+
+    static_information_to_csv(static_logfile_name, device, n)
 
     # Log continous information to (another) file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
 
-    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
-        backup = logfile_name + "~"
-        if os.path.exists(backup):
-            os.unlink(backup)
-        shutil.move(logfile_name, backup)
+    check_exist_logfile(logfile_name, arguments)
 
-    logfile_fields = (
-        "epoch",
-        "total_time",
-        "eta",
-        "average_loss",
-        "median_loss",
-        "learning_rate",
+    logfile_fields = create_logfile_fields(
+        valid_loader, extra_valid_loaders, device
     )
-    if valid_loader is not None:
-        logfile_fields += ("validation_average_loss", "validation_median_loss")
-    logfile_fields += tuple([k[0] for k in cpu_log()])
-    if device != "cpu":
-        logfile_fields += tuple([k[0] for k in gpu_log()])
 
     # the lowest validation loss obtained so far - this value is updated only
     # if a validation set is available
     lowest_validation_loss = sys.float_info.max
 
+    # set a specific validation criterion if the user has set one
+    criterion_valid = criterion_valid or criterion
+
     with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
 
@@ -193,92 +629,56 @@ def run(
             leave=False,
             disable=None,
         ):
-            losses = SmoothedValue(len(data_loader))
-            epoch = epoch + 1
-            arguments["epoch"] = epoch
-
-            # Epoch time
-            start_epoch_time = time.time()
-
-            # progress bar only on interactive jobs
-            for samples in tqdm(
-                data_loader, desc="batch", leave=False, disable=None
-            ):
 
-                # data forwarding on the existing network
-                images = samples[1].to(
-                    device=device, non_blocking=torch.cuda.is_available()
-                )
-                labels = samples[2].to(
-                    device=device, non_blocking=torch.cuda.is_available()
+            with ResourceMonitor(
+                interval=monitoring_interval,
+                has_gpu=(device.type == "cuda"),
+                main_pid=os.getpid(),
+                logging_level=logging.ERROR,
+            ) as resource_monitor:
+                epoch = epoch + 1
+                arguments["epoch"] = epoch
+
+                # Epoch time
+                start_epoch_time = time.time()
+
+                train_loss = train_epoch(
+                    data_loader,
+                    model,
+                    optimizer,
+                    device,
+                    criterion,
+                    batch_chunk_count,
                 )
 
-                # Increase labels dimension if too low
-                # Allows single and multiclass usage
-                if labels.ndim == 1:
-                    labels = torch.reshape(labels, (labels.shape[0], 1))
-
-                outputs = model(images)
-
-                # loss evaluation and learning (backward step)
-                loss = criterion(outputs, labels.double())
-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
-
-                losses.update(loss)
-                logger.debug(f"batch loss: {loss.item()}")
-
-
-            # calculates the validation loss if necessary
-            valid_losses = None
-            if valid_loader is not None:
-
-                with torch.no_grad(), torch_evaluation(model):
-
-                    valid_losses = SmoothedValue(len(valid_loader))
-                    for samples in tqdm(
-                        valid_loader, desc="valid", leave=False, disable=None
-                    ):
-                        # data forwarding on the existing network
-                        images = samples[1].to(
-                            device=device,
-                            non_blocking=torch.cuda.is_available(),
-                        )
-                        labels = samples[2].to(
-                            device=device,
-                            non_blocking=torch.cuda.is_available(),
-                        )
-
-                        # Increase labels dimension if too low
-                        # Allows single and multiclass usage
-                        if labels.ndim == 1:
-                            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-                        outputs = model(images)
-
-                        if criterion_valid is not None:
-                            loss = criterion_valid(outputs, labels.double())
-                        else:
-                            loss = criterion(outputs, labels.double())
-                        valid_losses.update(loss)
-
-            if checkpoint_period and (epoch % checkpoint_period == 0):
-                checkpointer.save(f"model_{epoch:03d}", **arguments)
-
-            if (
-                valid_losses is not None
-                and valid_losses.avg < lowest_validation_loss
-            ):
-                lowest_validation_loss = valid_losses.avg
-                logger.info(
-                    f"Found new low on validation set:"
-                    f" {lowest_validation_loss:.6f}"
+                valid_loss = (
+                    validate_epoch(
+                        valid_loader, model, device, criterion_valid, "valid"
+                    )
+                    if valid_loader is not None
+                    else None
                 )
-                checkpointer.save(f"model_lowest_valid_loss", **arguments)
 
-            if epoch >= max_epoch:
-                checkpointer.save("model_final", **arguments)
+                extra_valid_losses = []
+                for pos, extra_valid_loader in enumerate(extra_valid_loaders):
+                    loss = validate_epoch(
+                        extra_valid_loader,
+                        model,
+                        device,
+                        criterion_valid,
+                        f"xval@{pos+1}",
+                    )
+                    extra_valid_losses.append(loss)
+
+            lowest_validation_loss = checkpointer_process(
+                checkpointer,
+                checkpoint_period,
+                valid_loss,
+                lowest_validation_loss,
+                arguments,
+                epoch,
+                max_epoch,
+            )
 
             # computes ETA (estimated time-of-arrival; end of training) taking
             # into consideration previous epoch performance
@@ -286,29 +686,18 @@ def run(
             eta_seconds = epoch_time * (max_epoch - epoch)
             current_time = time.time() - start_training_time
 
-            logdata = (
-                ("epoch", f"{epoch}"),
-                (
-                    "total_time",
-                    f"{datetime.timedelta(seconds=int(current_time))}",
-                ),
-                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
-                ("average_loss", f"{losses.avg:.6f}"),
-                ("median_loss", f"{losses.median:.6f}"),
-                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+            write_log_info(
+                epoch,
+                current_time,
+                eta_seconds,
+                train_loss,
+                valid_loss,
+                extra_valid_losses,
+                optimizer,
+                logwriter,
+                logfile,
+                resource_monitor.data,
             )
-            if valid_losses is not None:
-                logdata += (
-                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
-                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
-                )
-            logdata += cpu_log()
-            if device != "cpu":
-                logdata += gpu_log()
-
-            logwriter.writerow(dict(k for k in logdata))
-            logfile.flush()
-            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
 
         total_training_time = time.time() - start_training_time
         logger.info(
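
For context, the train_epoch() function added above implements standard PyTorch gradient accumulation.  A minimal sketch of that pattern, assuming a generic loader that yields (images, labels) pairs (the names below are illustrative only):

    def train_with_accumulation(loader, model, optimizer, criterion, batch_chunk_count):
        for idx, (images, labels) in enumerate(loader):
            loss = criterion(model(images), labels)
            # scale so that accumulated gradients match a full-batch update
            (loss / batch_chunk_count).backward()
            # update weights once every batch_chunk_count iterations (or at the end)
            if (idx + 1) % batch_chunk_count == 0 or (idx + 1) == len(loader):
                optimizer.step()
                optimizer.zero_grad()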
diff --git a/bob/med/tb/models/alexnet.py b/bob/med/tb/models/alexnet.py
index afdfd46ef9b2d3974a07ce1042f5684805867fdf..9dbc3b0b575395f07635a4024d89fec463ccf350 100644
--- a/bob/med/tb/models/alexnet.py
+++ b/bob/med/tb/models/alexnet.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import torch
 import torch.nn as nn
 import torchvision.models as models
 from collections import OrderedDict
@@ -18,7 +17,8 @@ class Alexnet(nn.Module):
         super(Alexnet, self).__init__()
 
         # Load pretrained model
-        self.model_ft = models.alexnet(pretrained=pretrained)
+        weights = None if pretrained is False else models.AlexNet_Weights.DEFAULT
+        self.model_ft = models.alexnet(weights=weights)
 
         # Adapt output features
         self.model_ft.classifier[4] = nn.Linear(4096,512)
@@ -55,9 +55,9 @@ def build_alexnet(pretrained=False):
     """
 
     model = Alexnet(pretrained=pretrained)
-    model = [("normalizer", TorchVisionNormalizer()), 
+    model = [("normalizer", TorchVisionNormalizer()),
             ("model", model)]
     model = nn.Sequential(OrderedDict(model))
 
     model.name = "AlexNet"
-    return model
\ No newline at end of file
+    return model
diff --git a/bob/med/tb/models/densenet.py b/bob/med/tb/models/densenet.py
index 197df0aab8362b43d431dd39d58260303ab9824d..6a83e80c35f8ebac1690b825ba7f70b13da77b18 100644
--- a/bob/med/tb/models/densenet.py
+++ b/bob/med/tb/models/densenet.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import torch
 import torch.nn as nn
 import torchvision.models as models
 from collections import OrderedDict
@@ -18,11 +17,12 @@ class Densenet(nn.Module):
         super(Densenet, self).__init__()
 
         # Load pretrained model
-        self.model_ft = models.densenet121(pretrained=pretrained)
+        weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
+        self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
         self.model_ft.classifier = nn.Sequential(
-                                        nn.Linear(1024,256), 
+                                        nn.Linear(1024,256),
                                         nn.Linear(256,1)
                                         )
 
@@ -56,9 +56,9 @@ def build_densenet(pretrained=False, nb_channels=3):
     """
 
     model = Densenet(pretrained=pretrained)
-    model = [("normalizer", TorchVisionNormalizer(nb_channels=nb_channels)), 
+    model = [("normalizer", TorchVisionNormalizer(nb_channels=nb_channels)),
             ("model", model)]
     model = nn.Sequential(OrderedDict(model))
 
     model.name = "Densenet"
-    return model
\ No newline at end of file
+    return model
diff --git a/bob/med/tb/models/densenet_rs.py b/bob/med/tb/models/densenet_rs.py
index 1b7943986fccc05e3faedac333e01ad87b3abf5a..7fb7cb772f8032a134fac550f377a9244b154377 100644
--- a/bob/med/tb/models/densenet_rs.py
+++ b/bob/med/tb/models/densenet_rs.py
@@ -1,22 +1,25 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import torch
 import torch.nn as nn
 import torchvision.models as models
 from collections import OrderedDict
 from .normalizer import TorchVisionNormalizer
 
+
 class DensenetRS(nn.Module):
     """
     Densenet121 module for radiological extraction
 
     """
+
     def __init__(self):
         super(DensenetRS, self).__init__()
 
         # Load pretrained model
-        self.model_ft = models.densenet121(pretrained=True)
+        self.model_ft = models.densenet121(
+            weights=models.DenseNet121_Weights.DEFAULT
+        )
 
         # Adapt output features
         num_ftrs = self.model_ft.classifier.in_features
@@ -53,9 +56,8 @@ def build_densenetrs():
     """
 
     model = DensenetRS()
-    model = [("normalizer", TorchVisionNormalizer()), 
-            ("model", model)]
+    model = [("normalizer", TorchVisionNormalizer()), ("model", model)]
     model = nn.Sequential(OrderedDict(model))
 
     model.name = "DensenetRS"
-    return model
\ No newline at end of file
+    return model
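
The pretrained= flag used previously by the alexnet/densenet builders is deprecated in recent torchvision releases in favour of the weights= enums; the substitutions above amount to the following sketch (assuming torchvision >= 0.13):

    import torchvision.models as models

    # before: models.densenet121(pretrained=True)
    model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)

    # before: models.densenet121(pretrained=False)
    model = models.densenet121(weights=None)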
diff --git a/bob/med/tb/scripts/predict.py b/bob/med/tb/scripts/predict.py
index 95d5e6b4991a72b54d06f1130f74f518e481e0ed..d5ba053b2ca34ca6fe409506036648b9e9b5bcf4 100644
--- a/bob/med/tb/scripts/predict.py
+++ b/bob/med/tb/scripts/predict.py
@@ -110,12 +110,12 @@ logger = logging.getLogger(__name__)
     cls=ResourceOption,
 )
 @verbosity_option(cls=ResourceOption)
-def predict(output_folder, model, dataset, batch_size, device, weight, 
+def predict(output_folder, model, dataset, batch_size, device, weight,
             relevance_analysis, grad_cams, **kwargs):
     """Predicts Tuberculosis presence (probabilities) on input images"""
 
     dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
-    
+
     if weight.startswith("http"):
         logger.info(f"Temporarily downloading '{weight}'...")
         f = download_to_tempfile(weight, progress=True)
@@ -137,7 +137,7 @@ def predict(output_folder, model, dataset, batch_size, device, weight,
         os.makedirs(os.path.dirname(filepath), exist_ok=True)
         pdf = PdfPages(filepath)
         pdf.savefig(relevance_analysis_plot(
-            model_weights, 
+            model_weights,
             title="LogReg model weights"))
         pdf.close()
 
@@ -193,21 +193,21 @@ def predict(output_folder, model, dataset, batch_size, device, weight,
 
                     # Compute MSE between original and new predictions
                     all_mse.append(metrics.mean_squared_error(
-                        np.array(predictions)[:,1],
-                        np.array(predictions_with_mean)[:,1]
+                        np.array(predictions, dtype=object)[:,1],
+                        np.array(predictions_with_mean, dtype=object)[:,1]
                         ))
 
                     # Back to original values
                     v = v_original
-                
+
                 # Remove temporary folder
                 shutil.rmtree(output_folder + "_temp", ignore_errors=True)
-                
+
                 filepath = os.path.join(output_folder, k + "_RA.pdf")
                 logger.info(f"Creating and saving plot at {filepath}...")
                 os.makedirs(os.path.dirname(filepath), exist_ok=True)
                 pdf = PdfPages(filepath)
                 pdf.savefig(relevance_analysis_plot(
-                    all_mse, 
+                    all_mse,
                     title=k.capitalize() + " set relevance analysis"))
                 pdf.close()
diff --git a/bob/med/tb/scripts/train.py b/bob/med/tb/scripts/train.py
index 6ea1c4e1b33047863acb8487a77bed417afe9f6e..d7400429a3531bf8bc061395ac2e3ef3720c28ff 100644
--- a/bob/med/tb/scripts/train.py
+++ b/bob/med/tb/scripts/train.py
@@ -2,13 +2,19 @@
 # coding=utf-8
 
 import os
+import sys
+import random
+import multiprocessing
 
 import click
+import numpy
 import torch
 from torch.nn import BCEWithLogitsLoss
 from torch.utils.data import DataLoader, WeightedRandomSampler
+
 from ..configs.datasets import get_samples_weights, get_positive_weights
 
+
 from bob.extension.scripts.click_helper import (
     verbosity_option,
     ConfigCommand,
@@ -18,12 +24,106 @@ from bob.extension.scripts.click_helper import (
 from ..utils.checkpointer import Checkpointer
 from ..engine.trainer import run
 from .tb import download_to_tempfile
-from ..models.normalizer import TorchVisionNormalizer
 
 import logging
+
 logger = logging.getLogger(__name__)
 
 
+def setup_pytorch_device(name):
+    """Sets-up the pytorch device to use
+
+
+    Parameters
+    ----------
+
+    name : str
+        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
+        set a specific cuda device such as ``cuda:1``, then we'll make sure it
+        is currently set.
+
+
+    Returns
+    -------
+
+    device : :py:class:`torch.device`
+        The pytorch device to use, pre-configured (and checked)
+
+    """
+
+    if name.startswith("cuda:"):
+        # In case one has multiple devices, we must first set the one
+        # we would like to use so pytorch can find it.
+        logger.info(f"User set device to '{name}' - trying to force device...")
+        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"CUDA is not currently available, but "
+                f"you set device to '{name}'"
+            )
+        # Let pytorch auto-select from environment variable
+        return torch.device("cuda")
+
+    elif name.startswith("cuda"):  # use default device
+        logger.info(f"User set device to '{name}' - using default CUDA device")
+        assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None
+
+    # cuda or cpu
+    return torch.device(name)
+
+
+def set_seeds(value, all_gpus):
+    """Sets up all relevant random seeds (numpy, python, cuda)
+
+    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
+    ``True`` to force all GPU seeds to be initialized.
+
+    Reference: `PyTorch page for reproducibility
+    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
+
+
+    Parameters
+    ----------
+
+    value : int
+        The random seed value to use
+
+    all_gpus : :py:class:`bool`, Optional
+        If set, then reset the seed on all GPUs available at once.  This is
+        normally **not** what you want if running on a single GPU
+
+    """
+
+    random.seed(value)
+    numpy.random.seed(value)
+    torch.manual_seed(value)
+    torch.cuda.manual_seed(value)  # noop if cuda not available
+
+    # set seeds for all gpus
+    if all_gpus:
+        torch.cuda.manual_seed_all(value)  # noop if cuda not available
+
+
+def set_reproducible_cuda():
+    """Turns-off all CUDA optimizations that would affect reproducibility
+
+    For full reproducibility, also ensure not to use multiple (parallel) data
+    loaders.  That is, set ``num_workers=0``.
+
+    Reference: `PyTorch page for reproducibility
+    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
+
+
+    """
+
+    # ensure to use only optimization algos for cuda that are known to have
+    # a deterministic effect (not random)
+    torch.backends.cudnn.deterministic = True
+
+    # turns off any optimization tricks
+    torch.backends.cudnn.benchmark = False
+
+
 @click.command(
     entry_point_group="bob.med.tb.config",
     cls=ConfigCommand,
@@ -56,17 +156,22 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--dataset",
     "-d",
-    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
-    "to be used for training the model, possibly including all pre-processing "
-    "pipelines required or, optionally, a dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances.  At least one key "
-    "named ``train`` must be available.  This dataset will be used for "
+    help="A dictionary mapping string keys to "
+    "torch.utils.data.dataset.Dataset instances implementing datasets "
+    "to be used for training and validating the model, possibly including all "
+    "pre-processing pipelines required or, optionally, a dictionary mapping "
+    "string keys to torch.utils.data.dataset.Dataset instances.  At least "
+    "one key named ``train`` must be available.  This dataset will be used for "
     "training the network model.  The dataset description must include all "
     "required pre-processing, including eventual data augmentation.  If a "
     "dataset named ``__train__`` is available, it is used prioritarily for "
     "training instead of ``train``.  If a dataset named ``__valid__`` is "
-    "available, it is used for model validation (and automatic check-pointing) "
-    "at each epoch.",
+    "available, it is used for model validation (and automatic "
+    "check-pointing) at each epoch.  If a dataset list named "
+    "``__extra_valid__`` is available, then it will be tracked during the "
+    "validation process and its loss output at the training log as well, "
+    "in the format of an array occupying a single column.  All other keys "
+    "are considered test datasets and are ignored during training",
     required=True,
     cls=ResourceOption,
 )
@@ -84,7 +189,7 @@ logger = logging.getLogger(__name__)
     cls=ResourceOption,
 )
 @click.option(
-    "--criterion_valid",
+    "--criterion-valid",
     help="A specific loss function for the validation set to compute the CNN"
     "error for every sample respecting the PyTorch API for loss functions"
     "(see torch.nn.modules.loss)",
@@ -102,7 +207,27 @@ logger = logging.getLogger(__name__)
     "until there are no more new samples to feed (epoch is finished).  "
     "If the total number of training samples is not a multiple of the "
     "batch-size, the last batch will be smaller than the first, unless "
-    "--drop-incomplete--batch is set, in which case this batch is not used.",
+    "--drop-incomplete-batch is set, in which case this batch is not used.",
+    required=True,
+    show_default=True,
+    default=1,
+    type=click.IntRange(min=1),
+    cls=ResourceOption,
+)
+@click.option(
+    "--batch-chunk-count",
+    "-c",
+    help="Number of chunks in every batch (this parameter affects "
+    "memory requirements for the network). The number of samples "
+    "loaded for every iteration will be batch-size/batch-chunk-count. "
+    "batch-size needs to be divisible by batch-chunk-count, otherwise an "
+    "error will be raised. This parameter is used to reduce number of "
+    "samples loaded in each iteration, in order to reduce the memory usage "
+    "in exchange for processing time (more iterations).  This is specially "
+    "interesting whe one is running with GPUs with limited RAM. The "
+    "default of 1 forces the whole batch to be processed at once.  Otherwise "
+    "the batch is broken into batch-chunk-count pieces, and gradients are "
+    "accumulated to complete each batch.",
     required=True,
     show_default=True,
     default=1,
@@ -124,7 +249,9 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--epochs",
     "-e",
-    help="Number of epochs (complete training set passes) to train for",
+    help="Number of epochs (complete training set passes) to train for. "
+    "If continuing from a saved checkpoint, ensure to provide a greater "
+    "number of epochs than that saved on the checkpoint to be loaded. ",
     show_default=True,
     required=True,
     default=1000,
@@ -165,13 +292,16 @@ logger = logging.getLogger(__name__)
     cls=ResourceOption,
 )
 @click.option(
-    "--num_workers",
-    "-ns",
-    help="Number of parallel threads to use",
+    "--parallel",
+    "-P",
+    help="""Use multiprocessing for data loading: if set to -1 (default),
+    disables multiprocessing data loading.  Set to 0 to enable as many data
+    loading instances as there are processing cores in the system.  Set to
+    >= 1 to enable that many multiprocessing instances for data loading.""",
+    type=click.IntRange(min=-1),
     show_default=True,
-    required=False,
-    default=0,
-    type=click.IntRange(min=0),
+    required=True,
+    default=-1,
     cls=ResourceOption,
 )
 @click.option(
@@ -191,6 +321,22 @@ logger = logging.getLogger(__name__)
     default="none",
     cls=ResourceOption,
 )
+@click.option(
+    "--monitoring-interval",
+    "-I",
+    help="""Time between checks for the use of resources during each training
+    epoch.  An interval of 5 seconds, for example, will lead to CPU and GPU
+    resources being probed every 5 seconds during each training epoch.
+    Values registered in the training logs correspond to averages (or maxima)
+    observed through possibly many probes in each epoch.  Notice that setting a
+    very small value may cause the probing process to become extremely busy,
+    potentially biasing the overall perception of resource usage.""",
+    type=click.FloatRange(min=0.1),
+    show_default=True,
+    required=True,
+    default=5.0,
+    cls=ResourceOption,
+)
 @verbosity_option(cls=ResourceOption)
 def train(
     model,
@@ -198,6 +344,7 @@ def train(
     output_folder,
     epochs,
     batch_size,
+    batch_chunk_count,
     drop_incomplete_batch,
     criterion,
     criterion_valid,
@@ -205,9 +352,10 @@ def train(
     checkpoint_period,
     device,
     seed,
-    num_workers,
+    parallel,
     weight,
     normalization,
+    monitoring_interval,
     verbose,
     **kwargs,
 ):
@@ -220,10 +368,14 @@ def train(
     abruptly.
     """
 
-    torch.manual_seed(seed)
+    device = setup_pytorch_device(device)
+
+    set_seeds(seed, all_gpus=False)
 
     use_dataset = dataset
     validation_dataset = None
+    extra_validation_datasets = []
+
     if isinstance(dataset, dict):
         if "__train__" in dataset:
             logger.info("Found (dedicated) '__train__' set for training")
@@ -236,76 +388,142 @@ def train(
             logger.info("Will checkpoint lowest loss model on validation set")
             validation_dataset = dataset["__valid__"]
 
+        if "__extra_valid__" in dataset:
+            if not isinstance(dataset["__extra_valid__"], list):
+                raise RuntimeError(
+                    f"If present, dataset['__extra_valid__'] must be a list, "
+                    f"but you passed a {type(dataset['__extra_valid__'])}, "
+                    f"which is invalid."
+                )
+            logger.info(
+                f"Found {len(dataset['__extra_valid__'])} extra validation "
+                f"set(s) to be tracked during training"
+            )
+            logger.info(
+                "Extra validation sets are NOT used for model checkpointing!"
+            )
+            extra_validation_datasets = dataset["__extra_valid__"]
+
+    # PyTorch dataloader
+    multiproc_kwargs = dict()
+    if parallel < 0:
+        multiproc_kwargs["num_workers"] = 0
+    else:
+        multiproc_kwargs["num_workers"] = (
+            parallel or multiprocessing.cpu_count()
+        )
+
+    if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
+        multiproc_kwargs[
+            "multiprocessing_context"
+        ] = multiprocessing.get_context("spawn")
+
+    batch_chunk_size = batch_size
+    if batch_size % batch_chunk_count != 0:
+        # batch_size must be divisible by batch_chunk_count.
+        raise RuntimeError(
+            f"--batch-size ({batch_size}) must be divisible by "
+            f"--batch-chunk-size ({batch_chunk_count})."
+        )
+    else:
+        batch_chunk_size = batch_size // batch_chunk_count
+
     # Create weighted random sampler
     train_samples_weights = get_samples_weights(use_dataset)
     train_samples_weights = train_samples_weights.to(
-                    device=device, non_blocking=torch.cuda.is_available()
-                )
-    train_sampler = WeightedRandomSampler(train_samples_weights, len(train_samples_weights), replacement=True)
+        device=device, non_blocking=torch.cuda.is_available()
+    )
+    train_sampler = WeightedRandomSampler(
+        train_samples_weights, len(train_samples_weights), replacement=True
+    )
 
     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
         positive_weights = get_positive_weights(use_dataset)
         positive_weights = positive_weights.to(
-                        device=device, non_blocking=torch.cuda.is_available()
-                    )
+            device=device, non_blocking=torch.cuda.is_available()
+        )
         criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
         logger.warning("Weighted criterion not supported")
 
     # PyTorch dataloader
+
     data_loader = DataLoader(
         dataset=use_dataset,
-        batch_size=batch_size,
-        num_workers=num_workers,
+        batch_size=batch_chunk_size,
         drop_last=drop_incomplete_batch,
         pin_memory=torch.cuda.is_available(),
-        sampler=train_sampler
+        sampler=train_sampler,
+        **multiproc_kwargs,
     )
 
     valid_loader = None
     if validation_dataset is not None:
 
         # Redefine a weighted valid criterion if possible
-        if isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) or criterion_valid is None:
+        if (
+            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
+            or criterion_valid is None
+        ):
             positive_weights = get_positive_weights(validation_dataset)
             positive_weights = positive_weights.to(
-                            device=device, non_blocking=torch.cuda.is_available()
-                        )
+                device=device, non_blocking=torch.cuda.is_available()
+            )
             criterion_valid = BCEWithLogitsLoss(pos_weight=positive_weights)
         else:
             logger.warning("Weighted valid criterion not supported")
 
         valid_loader = DataLoader(
-                dataset=validation_dataset,
-                batch_size=batch_size,
-                num_workers=num_workers,
-                shuffle=False,
-                drop_last=False,
-                pin_memory=torch.cuda.is_available(),
-                )
+            dataset=validation_dataset,
+            batch_size=batch_chunk_size,
+            shuffle=False,
+            drop_last=False,
+            pin_memory=torch.cuda.is_available(),
+            **multiproc_kwargs,
+        )
+
+    extra_valid_loaders = [
+        DataLoader(
+            dataset=k,
+            batch_size=batch_chunk_size,
+            shuffle=False,
+            drop_last=False,
+            pin_memory=torch.cuda.is_available(),
+            **multiproc_kwargs,
+        )
+        for k in extra_validation_datasets
+    ]
 
     # Create z-normalization model layer if needed
     if normalization == "imagenet":
-        model.normalizer.set_mean_std([0.485, 0.456, 0.406],
-                                    [0.229, 0.224, 0.225])
+        model.normalizer.set_mean_std(
+            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+        )
         logger.info("Z-normalization with ImageNet mean and std")
     elif normalization == "current":
         # Compute mean/std of current train subset
-        temp_dl = DataLoader(
-            dataset=use_dataset,
-            batch_size=len(use_dataset)
-        )
+        temp_dl = DataLoader(dataset=use_dataset, batch_size=len(use_dataset))
 
         data = next(iter(temp_dl))
-        mean = data[1].mean(dim=[0,2,3])
-        std = data[1].std(dim=[0,2,3])
+        mean = data[1].mean(dim=[0, 2, 3])
+        std = data[1].std(dim=[0, 2, 3])
 
         model.normalizer.set_mean_std(mean, std)
 
         # Format mean and std for logging
-        mean = str([round(x, 3) for x in ((mean * 10**3).round() / (10**3)).tolist()])
-        std = str([round(x, 3) for x in ((std * 10**3).round() / (10**3)).tolist()])
+        mean = str(
+            [
+                round(x, 3)
+                for x in ((mean * 10**3).round() / (10**3)).tolist()
+            ]
+        )
+        std = str(
+            [
+                round(x, 3)
+                for x in ((std * 10**3).round() / (10**3)).tolist()
+            ]
+        )
         logger.info("Z-normalization with mean {} and std {}".format(mean, std))
 
     # Checkpointer
@@ -329,15 +547,18 @@ def train(
     logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
     run(
-        model,
-        data_loader,
-        valid_loader,
-        optimizer,
-        criterion,
-        checkpointer,
-        checkpoint_period,
-        device,
-        arguments,
-        output_folder,
-        criterion_valid,
-    )
\ No newline at end of file
+        model=model,
+        data_loader=data_loader,
+        valid_loader=valid_loader,
+        extra_valid_loaders=extra_valid_loaders,
+        optimizer=optimizer,
+        criterion=criterion,
+        checkpointer=checkpointer,
+        checkpoint_period=checkpoint_period,
+        device=device,
+        arguments=arguments,
+        output_folder=output_folder,
+        monitoring_interval=monitoring_interval,
+        batch_chunk_count=batch_chunk_count,
+        criterion_valid=criterion_valid,
+    )
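
As a reference for the --dataset help text above, a hypothetical dictionary layout showing the special keys handled by this command (the toy datasets below are placeholders, not the package's real configurations):

    import torch
    from torch.utils.data import TensorDataset

    def _toy(n=4):
        # stand-in for a real bob.med.tb dataset configuration (placeholder only)
        return TensorDataset(torch.zeros(n, 1), torch.zeros(n, 1))

    dataset = {
        "__train__": _toy(),                   # preferred for training (with augmentation)
        "train": _toy(),                       # used if "__train__" is absent
        "__valid__": _toy(),                   # drives automatic checkpointing
        "__extra_valid__": [_toy(), _toy()],   # losses logged, no checkpointing
        "test": _toy(),                        # other keys are ignored during training
    }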
diff --git a/bob/med/tb/scripts/train_analysis.py b/bob/med/tb/scripts/train_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..163a80bf7f8ae95d2971af6cb8348c910e5871af
--- /dev/null
+++ b/bob/med/tb/scripts/train_analysis.py
@@ -0,0 +1,202 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import logging
+import os
+
+import click
+import matplotlib.pyplot as plt
+import numpy
+import pandas
+
+from matplotlib.backends.backend_pdf import PdfPages
+
+logger = logging.getLogger(__name__)
+
+
+from bob.extension.scripts.click_helper import (
+    ConfigCommand,
+    ResourceOption,
+    verbosity_option,
+)
+
+
+def _loss_evolution(df):
+    """Plots the loss evolution over time (epochs)
+
+    Parameters
+    ----------
+
+    df : pandas.DataFrame
+        dataframe containing the training logs
+
+
+    Returns
+    -------
+
+    figure : matplotlib.figure.Figure
+        figure to be displayed or saved to file
+
+    """
+
+    figure = plt.figure()
+    axes = figure.gca()
+
+    axes.plot(df.epoch.values, df.loss.values, label="Training")
+    if "validation_loss" in df.columns:
+        axes.plot(
+            df.epoch.values, df.validation_loss.values, label="Validation"
+        )
+        # shows a red dot on the location with the minima on the validation set
+        lowest_index = numpy.argmin(df["validation_loss"])
+
+        axes.plot(
+            df.epoch.values[lowest_index],
+            df.validation_loss[lowest_index],
+            "mo",
+            label=f"Lowest validation ({df.validation_loss[lowest_index]:.3f}@{df.epoch[lowest_index]})",
+        )
+
+    if "extra_validation_losses" in df.columns:
+        # These losses are in array format. So, we read all rows, then create a
+        # 2d array.  We transpose the array to iterate over each column and
+        # plot the losses individually.  They are numbered from 1.
+        df["extra_validation_losses"] = df["extra_validation_losses"].apply(
+            lambda x: numpy.fromstring(x.strip("[]"), sep=" ")
+        )
+        losses = numpy.vstack(df.extra_validation_losses.values).T
+        for n, k in enumerate(losses):
+            axes.plot(df.epoch.values, k, label=f"Extra validation {n+1}")
+
+    axes.set_title("Loss over time")
+    axes.set_xlabel("Epoch")
+    axes.set_ylabel("Loss")
+
+    axes.legend(loc="best")
+    axes.grid(alpha=0.3)
+    figure.set_layout_engine("tight")
+
+    return figure
+
+
+def _hardware_utilisation(df, const):
+    """Plot the CPU utilisation over time (epochs).
+
+    Parameters
+    ----------
+
+    df : pandas.DataFrame
+        dataframe containing the training logs
+
+    const : dict
+        training and hardware constants
+
+
+    Returns
+    -------
+
+    figure : matplotlib.figure.Figure
+        figure to be displayed or saved to file
+
+    """
+    figure = plt.figure()
+    axes = figure.gca()
+
+    cpu_percent = df.cpu_percent.values / const["cpu_count"]
+    cpu_memory = 100 * df.cpu_rss / const["cpu_memory_total"]
+
+    axes.plot(
+        df.epoch.values,
+        cpu_percent,
+        label=f"CPU usage (cores: {const['cpu_count']})",
+    )
+    axes.plot(
+        df.epoch.values,
+        cpu_memory,
+        label=f"CPU memory (total: {const['cpu_memory_total']:.1f} Gb)",
+    )
+    if "gpu_percent" in df:
+        axes.plot(
+            df.epoch.values,
+            df.gpu_percent.values,
+            label=f"GPU usage (type: {const['gpu_name']})",
+        )
+    if "gpu_memory_percent" in df:
+        axes.plot(
+            df.epoch.values,
+            df.gpu_memory_percent.values,
+            label=f"GPU memory (total: {const['gpu_memory_total']:.1f} Gb)",
+        )
+    axes.set_title("Hardware utilisation over time")
+    axes.set_xlabel("Epoch")
+    axes.set_ylabel("Relative utilisation (%)")
+    axes.set_ylim([0, 100])
+
+    axes.legend(loc="best")
+    axes.grid(alpha=0.3)
+    figure.set_layout_engine("tight")
+
+    return figure
+
+
+def base_analysis(log, constants, output_pdf, verbose, **kwargs):
+    """Create base train_analysis function."""
+
+
+@click.command(
+    entry_point_group="bob.med.tb.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+\b
+    1. Analyzes a training log and produces various plots:
+
+       $ bob tb train-analysis -vv log.csv constants.csv
+
+""",
+)
+@click.argument(
+    "log",
+    type=click.Path(dir_okay=False, exists=True),
+)
+@click.argument(
+    "constants",
+    type=click.Path(dir_okay=False, exists=True),
+)
+@click.option(
+    "--output-pdf",
+    "-o",
+    help="Name of the output file to dump",
+    required=True,
+    show_default=True,
+    default="trainlog.pdf",
+)
+@verbosity_option(cls=ResourceOption)
+def train_analysis(
+    log,
+    constants,
+    output_pdf,
+    verbose,
+    **kwargs,
+):
+    """Analyze the training logs for loss evolution and resource utilisation."""
+
+    constants = pandas.read_csv(constants)
+    constants = dict(zip(constants.keys(), constants.values[0]))
+    data = pandas.read_csv(log)
+
+    # makes sure the directory to save the output PDF is there
+    dirname = os.path.dirname(os.path.realpath(output_pdf))
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+    # now, do the analysis
+    with PdfPages(output_pdf) as pdf:
+
+        figure = _loss_evolution(data)
+        pdf.savefig(figure)
+        plt.close(figure)
+
+        figure = _hardware_utilisation(data, constants)
+        pdf.savefig(figure)
+        plt.close(figure)
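+
+
+# NOTE: a minimal sketch of the expected inputs, inferred from the columns
+# referenced above (not an authoritative format specification):
+#
+#   constants.csv -- a single data row, e.g.:
+#       cpu_count,cpu_memory_total,gpu_name,gpu_memory_total
+#       8,16.0,Tesla V100,32.0
+#
+#   log.csv -- one row per logged epoch, containing at least "epoch", the
+#   loss columns plotted by _loss_evolution() and the "cpu_*"/"gpu_*"
+#   columns plotted by _hardware_utilisation()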
diff --git a/bob/med/tb/test/test_cli.py b/bob/med/tb/test/test_cli.py
index b62c2193fa9d2bfd942f6705f6e033592bd00474..d6d35205ec3f4692ededa036f413deb62bed5413 100644
--- a/bob/med/tb/test/test_cli.py
+++ b/bob/med/tb/test/test_cli.py
@@ -6,9 +6,10 @@
 import os
 import re
 import contextlib
 from bob.extension import rc
 import pkg_resources
 
+import pytest
 from click.testing import CliRunner
 
 from . import mock_dataset
@@ -19,7 +21,12 @@ montgomery_datadir = mock_dataset()
 _pasa_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_fpasa_checkpoint.pth"
 _signstotb_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_signstotb_checkpoint.pth"
 _logreg_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_logreg_checkpoint.pth"
-#_densenetrs_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_densenetrs_checkpoint.pth"
+# _densenetrs_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_densenetrs_checkpoint.pth"
+
+
+@pytest.fixture(scope="session")
+def temporary_basedir(tmp_path_factory):
+    return tmp_path_factory.mktemp("test-cli")
 
 
 @contextlib.contextmanager
@@ -37,7 +44,6 @@ def rc_context(**new_config):
 def stdout_logging():
 
     ## copy logging messages to std out
-    import sys
     import logging
     import io
 
@@ -197,7 +203,7 @@ def test_compare_help():
     _check_help(compare)
 
 
-def test_train_pasa_montgomery():
+def test_train_pasa_montgomery(temporary_basedir):
 
     # Temporarily modify Montgomery datadir
     new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
@@ -209,7 +215,7 @@ def test_train_pasa_montgomery():
 
         with stdout_logging() as buf:
 
-            output_folder = "results"
+            output_folder = str(temporary_basedir / "results")
             result = runner.invoke(
                 train,
                 [
@@ -225,7 +231,7 @@ def test_train_pasa_montgomery():
             _assert_exit_0(result)
 
             assert os.path.exists(
-                os.path.join(output_folder, "model_final.pth")
+                os.path.join(output_folder, "model_final_epoch.pth")
             )
             assert os.path.exists(
                 os.path.join(output_folder, "model_lowest_valid_loss.pth")
@@ -260,7 +266,7 @@ def test_train_pasa_montgomery():
                 )
 
 
-def test_predict_pasa_montgomery():
+def test_predict_pasa_montgomery(temporary_basedir):
 
     # Temporarily modify Montgomery datadir
     new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
@@ -272,7 +278,7 @@ def test_predict_pasa_montgomery():
 
         with stdout_logging() as buf:
 
-            output_folder = "predictions"
+            output_folder = str(temporary_basedir / "predictions")
             result = runner.invoke(
                 predict,
                 [
@@ -317,7 +323,7 @@ def test_predict_pasa_montgomery():
                 )
 
 
-def test_predtojson():
+def test_predtojson(temporary_basedir):
 
     # Temporarily modify Montgomery datadir
     new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
@@ -330,7 +336,7 @@ def test_predtojson():
         with stdout_logging() as buf:
 
             predictions = _data_file("test_predictions.csv")
-            output_folder = "pred_to_json"
+            output_folder = str(temporary_basedir / "pred_to_json")
             result = runner.invoke(
                 predtojson,
                 [
@@ -348,7 +354,7 @@ def test_predtojson():
             assert os.path.exists(os.path.join(output_folder, "dataset.json"))
 
             keywords = {
-                r"Output folder: pred_to_json": 1,
+                f"Output folder: {output_folder}": 1,
                 r"Saving JSON file...": 1,
                 r"^Loading predictions from.*$": 2,
             }
@@ -363,7 +369,7 @@ def test_predtojson():
                 )
 
 
-def test_evaluate_pasa_montgomery():
+def test_evaluate_pasa_montgomery(temporary_basedir):
 
     # Temporarily modify Montgomery datadir
     new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
@@ -375,8 +381,8 @@ def test_evaluate_pasa_montgomery():
 
         with stdout_logging() as buf:
 
-            prediction_folder = "predictions"
-            output_folder = "evaluations"
+            prediction_folder = str(temporary_basedir / "predictions")
+            output_folder = str(temporary_basedir / "evaluations")
             result = runner.invoke(
                 evaluate,
                 [
@@ -418,7 +424,7 @@ def test_evaluate_pasa_montgomery():
                 )
 
 
-def test_compare_pasa_montgomery():
+def test_compare_pasa_montgomery(temporary_basedir):
 
     # Temporarily modify Montgomery datadir
     new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
@@ -430,8 +436,8 @@ def test_compare_pasa_montgomery():
 
         with stdout_logging() as buf:
 
-            predictions_folder = "predictions"
-            output_folder = "comparisons"
+            predictions_folder = str(temporary_basedir / "predictions")
+            output_folder = str(temporary_basedir / "comparisons")
             result = runner.invoke(
                 compare,
                 [
@@ -467,7 +473,7 @@ def test_compare_pasa_montgomery():
                 )
 
 
-def test_train_signstotb_montgomery_rs():
+def test_train_signstotb_montgomery_rs(temporary_basedir):
 
     from ..scripts.train import train
 
@@ -475,7 +481,7 @@ def test_train_signstotb_montgomery_rs():
 
     with stdout_logging() as buf:
 
-        output_folder = "results"
+        output_folder = str(temporary_basedir / "results")
         result = runner.invoke(
             train,
             [
@@ -490,7 +496,9 @@ def test_train_signstotb_montgomery_rs():
         )
         _assert_exit_0(result)
 
-        assert os.path.exists(os.path.join(output_folder, "model_final.pth"))
+        assert os.path.exists(
+            os.path.join(output_folder, "model_final_epoch.pth")
+        )
         assert os.path.exists(
             os.path.join(output_folder, "model_lowest_valid_loss.pth")
         )
@@ -519,7 +527,7 @@ def test_train_signstotb_montgomery_rs():
             )
 
 
-def test_predict_signstotb_montgomery_rs():
+def test_predict_signstotb_montgomery_rs(temporary_basedir):
 
     from ..scripts.predict import predict
 
@@ -527,7 +535,7 @@ def test_predict_signstotb_montgomery_rs():
 
     with stdout_logging() as buf:
 
-        output_folder = "predictions"
+        output_folder = str(temporary_basedir / "predictions")
         result = runner.invoke(
             predict,
             [
@@ -569,7 +577,7 @@ def test_predict_signstotb_montgomery_rs():
             )
 
 
-def test_train_logreg_montgomery_rs():
+def test_train_logreg_montgomery_rs(temporary_basedir):
 
     from ..scripts.train import train
 
@@ -577,7 +585,7 @@ def test_train_logreg_montgomery_rs():
 
     with stdout_logging() as buf:
 
-        output_folder = "results"
+        output_folder = str(temporary_basedir / "results")
         result = runner.invoke(
             train,
             [
@@ -592,7 +600,9 @@ def test_train_logreg_montgomery_rs():
         )
         _assert_exit_0(result)
 
-        assert os.path.exists(os.path.join(output_folder, "model_final.pth"))
+        assert os.path.exists(
+            os.path.join(output_folder, "model_final_epoch.pth")
+        )
         assert os.path.exists(
             os.path.join(output_folder, "model_lowest_valid_loss.pth")
         )
@@ -621,7 +631,7 @@ def test_train_logreg_montgomery_rs():
             )
 
 
-def test_predict_logreg_montgomery_rs():
+def test_predict_logreg_montgomery_rs(temporary_basedir):
 
     from ..scripts.predict import predict
 
@@ -629,7 +639,7 @@ def test_predict_logreg_montgomery_rs():
 
     with stdout_logging() as buf:
 
-        output_folder = "predictions"
+        output_folder = str(temporary_basedir / "predictions")
         result = runner.invoke(
             predict,
             [
@@ -665,7 +675,7 @@ def test_predict_logreg_montgomery_rs():
             )
 
 
-def test_aggregpred():
+def test_aggregpred(temporary_basedir):
 
     # Temporarily modify Montgomery datadir
     new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
@@ -677,8 +687,10 @@ def test_aggregpred():
 
         with stdout_logging() as buf:
 
-            predictions = "predictions/train/predictions.csv"
-            output_folder = "aggregpred"
+            predictions = str(
+                temporary_basedir / "predictions" / "train" / "predictions.csv"
+            )
+            output_folder = str(temporary_basedir / "aggregpred")
             result = runner.invoke(
                 aggregpred,
                 [
@@ -694,7 +706,7 @@ def test_aggregpred():
             assert os.path.exists(os.path.join(output_folder, "aggregpred.csv"))
 
             keywords = {
-                r"Output folder: aggregpred": 1,
+                f"Output folder: {output_folder}": 1,
                 r"Saving aggregated CSV file...": 1,
                 r"^Loading predictions from.*$": 2,
             }
@@ -710,7 +722,7 @@ def test_aggregpred():
 
 
 # Not enough RAM available to do this test
-# def test_predict_densenetrs_montgomery():
+# def test_predict_densenetrs_montgomery(temporary_basedir):
 
 #     # Temporarily modify Montgomery datadir
 #     new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
@@ -722,7 +734,7 @@ def test_aggregpred():
 
 #         with stdout_logging() as buf:
 
-#             output_folder = "predictions"
+#             output_folder = str(temporary_basedir / "predictions")
 #             result = runner.invoke(
 #                 predict,
 #                 [
diff --git a/bob/med/tb/utils/resources.py b/bob/med/tb/utils/resources.py
index ea64657ca7f4f1dea3684fc0499eb269a8ffe2e4..7bbaa9afcce55b7925bf2eeac007da99375074fa 100644
--- a/bob/med/tb/utils/resources.py
+++ b/bob/med/tb/utils/resources.py
@@ -3,21 +3,24 @@
 
 """Tools for interacting with the running computer or GPU"""
 
+import logging
+import multiprocessing
 import os
-import subprocess
+import queue
 import shutil
+import subprocess
+import time
 
+import numpy
 import psutil
 
-import logging
-
 logger = logging.getLogger(__name__)
 
 _nvidia_smi = shutil.which("nvidia-smi")
 """Location of the nvidia-smi program, if one exists"""
 
 
-GB = float(2 ** 30)
+GB = float(2**30)
 """The number of bytes in a gigabyte"""
 
 
@@ -58,9 +61,14 @@ def run_nvidia_smi(query, rename=None):
         else:
             assert len(rename) == len(query)
 
+        # Query the GPU(s) selected via CUDA_VISIBLE_DEVICES (GPU 0 if unset).
         values = subprocess.getoutput(
-            "%s --query-gpu=%s --format=csv,noheader"
-            % (_nvidia_smi, ",".join(query))
+            "%s --query-gpu=%s --format=csv,noheader --id=%s"
+            % (
+                _nvidia_smi,
+                ",".join(query),
+                os.environ.get("CUDA_VISIBLE_DEVICES"),
+            )
         )
         values = [k.strip() for k in values.split(",")]
         t_values = []
@@ -70,7 +78,7 @@ def run_nvidia_smi(query, rename=None):
             elif k.endswith("MiB"):
                 t_values.append(float(k[:-3].strip()) / 1024)
             else:
-                t_values.append(k)  #unchanged
+                t_values.append(k)  # unchanged
         return tuple(zip(rename, t_values))
 
 
@@ -117,26 +125,35 @@ def gpu_log():
           :py:class:`float`)
         * ``memory.free``, as ``gpu_memory_free`` (transformed to gigabytes,
           :py:class:`float`)
-        * ``utilization.memory``, as ``gpu_memory_percent``,
+        * ``100*memory.used/memory.total``, as ``gpu_memory_percent``,
           (:py:class:`float`, in percent)
-        * ``utilization.gpu``, as ``gpu_utilization``,
+        * ``utilization.gpu``, as ``gpu_percent``,
           (:py:class:`float`, in percent)
 
     """
 
-    return run_nvidia_smi(
-        ("memory.used", "memory.free", "utilization.memory", "utilization.gpu"),
+    retval = run_nvidia_smi(
+        (
+            "memory.total",
+            "memory.used",
+            "memory.free",
+            "utilization.gpu",
+        ),
         (
+            "gpu_memory_total",
             "gpu_memory_used",
             "gpu_memory_free",
-            "gpu_memory_percent",
             "gpu_percent",
         ),
     )
 
-
-_CLUSTER = []
-"""List of processes currently being monitored"""
+    # re-compose the output into the expected key order, computing memory %
+    return (
+        retval[1],  # gpu_memory_used
+        retval[2],  # gpu_memory_free
+        ("gpu_memory_percent", 100 * (retval[1][1] / retval[0][1])),
+        retval[3],  # gpu_percent
+    )
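+
+
+# a typical gpu_log() return value (values are purely illustrative):
+#   (("gpu_memory_used", 2.5), ("gpu_memory_free", 29.5),
+#    ("gpu_memory_percent", 7.8), ("gpu_percent", 35.0))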
 
 
 def cpu_constants():
@@ -161,60 +178,270 @@ def cpu_constants():
     )
 
 
-def cpu_log():
-    """Returns process (+child) information using ``psutil``.
-
-    This call examines the current process plus any spawn child and returns the
-    combined resource usage summary for the process group.
+class CPULogger:
+    """Logs CPU information using :py:mod:`psutil`
 
 
-    Returns
-    -------
-
-    data : tuple
-        An ordered dictionary (organized as 2-tuples) containing these entries:
+    Parameters
+    ----------
 
-        0. ``cpu_memory_used`` (:py:class:`float`): total memory used from
-           the system, in gigabytes
-        1. ``cpu_rss`` (:py:class:`float`):  RAM currently used by
-           process and children, in gigabytes
-        2. ``cpu_vms`` (:py:class:`float`):  total memory (RAM + swap) currently
-           used by process and children, in gigabytes
-        3. ``cpu_percent`` (:py:class:`float`): percentage of the total CPU
-           used by this process and children (recursively) since last call
-           (first time called should be ignored).  This number depends on the
-           number of CPUs in the system and can be greater than 100%
-        4. ``cpu_processes`` (:py:class:`int`): total number of processes
-           including self and children (recursively)
-        5. ``cpu_open_files`` (:py:class:`int`): total number of open files by
-           self and children
+    pid : :py:class:`int`, Optional
+        Process identifier of the main process (parent process) to observe
 
     """
 
-    global _CLUSTER
-    if (not _CLUSTER) or (_CLUSTER[0] != psutil.Process()):  # initialization
-        this = psutil.Process()
-        _CLUSTER = [this] + this.children(recursive=True)
-        # touch cpu_percent() at least once for all
-        [k.cpu_percent(interval=None) for k in _CLUSTER]
-    else:
+    def __init__(self, pid=None):
+        this = psutil.Process(pid=pid)
+        self.cluster = [this] + this.children(recursive=True)
+        # touch cpu_percent() at least once for all processes in the cluster
+        [k.cpu_percent(interval=None) for k in self.cluster]
+
+    def log(self):
+        """Returns current process cluster information
+
+        Returns
+        -------
+
+        data : tuple
+            A tuple of 2-tuples containing these entries:
+
+            0. ``cpu_memory_used`` (:py:class:`float`): total memory used from
+               the system, in gigabytes
+            1. ``cpu_rss`` (:py:class:`float`):  RAM currently used by
+               process and children, in gigabytes
+            2. ``cpu_vms`` (:py:class:`float`):  total memory (RAM + swap) currently
+               used by process and children, in gigabytes
+            3. ``cpu_percent`` (:py:class:`float`): percentage of the total CPU
+               used by this process and children (recursively) since last call
+               (first time called should be ignored).  This number depends on the
+               number of CPUs in the system and can be greater than 100%
+            4. ``cpu_processes`` (:py:class:`int`): total number of processes
+               including self and children (recursively)
+            5. ``cpu_open_files`` (:py:class:`int`): total number of open files by
+               self and children
+
+        """
+
         # check all cluster components and update process list
         # done so we can keep the cpu_percent() initialization
-        children = _CLUSTER[0].children()
-        stored_children = set(_CLUSTER[1:])
-        current_children = set(_CLUSTER[0].children())
+        stored_children = set(self.cluster[1:])
+        current_children = set(self.cluster[0].children(recursive=True))
         keep_children = stored_children - current_children
         new_children = current_children - stored_children
-        [k.cpu_percent(interval=None) for k in new_children]
-        _CLUSTER = _CLUSTER[:1] + list(keep_children) + list(new_children)
+        gone = set()
+        for k in new_children:
+            try:
+                k.cpu_percent(interval=None)
+            except (psutil.ZombieProcess, psutil.NoSuchProcess):
+                # the child process disappeared in the meantime; drop it
+                # from the updated process list for this round
+                gone.add(k)
+        new_children = new_children - gone
+        self.cluster = (
+            self.cluster[:1] + list(keep_children) + list(new_children)
+        )
 
-    memory_info = [k.memory_info() for k in _CLUSTER]
+        memory_info = []
+        cpu_percent = []
+        open_files = []
+        gone = set()
+        for k in self.cluster:
+            try:
+                memory_info.append(k.memory_info())
+                cpu_percent.append(k.cpu_percent(interval=None))
+                open_files.append(len(k.open_files()))
+            except (psutil.ZombieProcess, psutil.NoSuchProcess):
+                # the child process disappeared in the meantime; it is too
+                # late to update the process list now, but we account for
+                # it when reporting process counts below
+                gone.add(k)
+
+        return (
+            ("cpu_memory_used", psutil.virtual_memory().used / GB),
+            ("cpu_rss", sum([k.rss for k in memory_info]) / GB),
+            ("cpu_vms", sum([k.vms for k in memory_info]) / GB),
+            ("cpu_percent", sum(cpu_percent)),
+            ("cpu_processes", len(self.cluster) - len(gone)),
+            ("cpu_open_files", sum(open_files)),
+        )
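+
+
+# a minimal CPULogger usage sketch (illustrative only): monitor the current
+# process and its children, then read one snapshot of the combined usage
+#
+#   cpu = CPULogger()        # pid=None -> observe the current process
+#   stats = dict(cpu.log())  # e.g. stats["cpu_rss"], in gigabytes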
 
-    return (
-        ("cpu_memory_used", psutil.virtual_memory().used / GB),
-        ("cpu_rss", sum([k.rss for k in memory_info]) / GB),
-        ("cpu_vms", sum([k.vms for k in memory_info]) / GB),
-        ("cpu_percent", sum(k.cpu_percent(interval=None) for k in _CLUSTER)),
-        ("cpu_processes", len(_CLUSTER)),
-        ("cpu_open_files", sum(len(k.open_files()) for k in _CLUSTER)),
-    )
+
+class _InformationGatherer:
+    """A container to store monitoring information
+
+    Parameters
+    ----------
+
+    has_gpu : bool
+        A flag indicating if we have a GPU installed on the platform or not
+
+    main_pid : int
+        The main process identifier to monitor
+
+    logger : logging.Logger
+        A logger to be used for logging messages
+
+    """
+
+    def __init__(self, has_gpu, main_pid, logger):
+        self.cpu_logger = CPULogger(main_pid)
+        self.keys = [k[0] for k in self.cpu_logger.log()]
+        self.cpu_keys_len = len(self.keys)
+        self.has_gpu = has_gpu
+        self.logger = logger
+        if self.has_gpu:
+            self.keys += [k[0] for k in gpu_log()]
+        self.data = [[] for _ in self.keys]
+
+    def acc(self):
+        """Accumulates another measurement"""
+        for i, k in enumerate(self.cpu_logger.log()):
+            self.data[i].append(k[1])
+        if self.has_gpu:
+            for i, k in enumerate(gpu_log()):
+                self.data[i + self.cpu_keys_len].append(k[1])
+
+    def summary(self):
+        """Returns the current data"""
+
+        if len(self.data[0]) == 0:
+            self.logger.error("CPU/GPU logger was not able to collect any data")
+        retval = []
+        for k, values in zip(self.keys, self.data):
+            retval.append((k, values))
+        return tuple(retval)
+
+
+def _monitor_worker(interval, has_gpu, main_pid, stop, queue, logging_level):
+    """A monitoring worker that measures resources and returns lists
+
+    Parameters
+    ==========
+
+    interval : int, float
+        Number of seconds to wait between each measurement (may be a floating
+        point number, as accepted by :py:func:`time.sleep`)
+
+    has_gpu : bool
+        A flag indicating if we have a GPU installed on the platform or not
+
+    main_pid : int
+        The main process identifier to monitor
+
+    stop : :py:class:`multiprocessing.Event`
+        Indicates if we should continue running or stop
+
+    queue : :py:class:`queue.Queue`
+        A queue, to send monitoring information back to the spawner
+
+    logging_level : int
+        The logging level to use for logging from launched processes
+
+    """
+
+    logger = multiprocessing.log_to_stderr(level=logging_level)
+    ra = _InformationGatherer(has_gpu, main_pid, logger)
+
+    while not stop.is_set():
+        try:
+            ra.acc()  # guarantees at least an entry will be available
+            time.sleep(interval)
+        except Exception:
+            logger.warning(
+                "Iterative CPU/GPU logging did not work properly " "this once",
+                exc_info=True,
+            )
+            time.sleep(0.5)  # wait half a second, and try again!
+
+    queue.put(ra.summary())
+
+
+class ResourceMonitor:
+    """An external, non-blocking CPU/GPU resource monitor
+
+    Parameters
+    ----------
+
+    interval : int, float
+        Number of seconds to wait between each measurement (may be a floating
+        point number, as accepted by :py:func:`time.sleep`)
+
+    has_gpu : bool
+        A flag indicating if we have a GPU installed on the platform or not
+
+    main_pid : int
+        The main process identifier to monitor
+
+    logging_level : int
+        The logging level to use for logging from launched processes
+
+    """
+
+    def __init__(self, interval, has_gpu, main_pid, logging_level):
+
+        self.interval = interval
+        self.has_gpu = has_gpu
+        self.main_pid = main_pid
+        self.event = multiprocessing.Event()
+        self.q = multiprocessing.Queue()
+        self.logging_level = logging_level
+
+        self.monitor = multiprocessing.Process(
+            target=_monitor_worker,
+            name="ResourceMonitorProcess",
+            args=(
+                self.interval,
+                self.has_gpu,
+                self.main_pid,
+                self.event,
+                self.q,
+                self.logging_level,
+            ),
+        )
+
+        self.data = None
+
+    @staticmethod
+    def monitored_keys(has_gpu):
+
+        return _InformationGatherer(has_gpu, None, logger).keys
+
+    def __enter__(self):
+        """Starts the monitoring process"""
+
+        self.monitor.start()
+        return self
+
+    def __exit__(self, *exc):
+        """Stops the monitoring process and returns the summary of observations"""
+
+        self.event.set()
+        self.monitor.join()
+        if self.monitor.exitcode != 0:
+            logger.error(
+                f"CPU/GPU resource monitor process exited with code "
+                f"{self.monitor.exitcode}.  Check logs for errors!"
+            )
+
+        try:
+            data = self.q.get(timeout=2 * self.interval)
+        except queue.Empty:
+            logger.warning(
+                f"CPU/GPU resource monitor did not provide anything when "
+                f"joined (even after a {2*self.interval}-second timeout) - "
+                "this is normally due to exceptions in the monitoring "
+                "process.  Check above for other exceptions."
+            )
+            self.data = None
+        else:
+            # summarize the returned data by creating means
+            summary = []
+            for k, values in data:
+                if values:
+                    if k in ("cpu_processes", "cpu_open_files"):
+                        summary.append((k, numpy.max(values)))
+                    else:
+                        summary.append((k, numpy.mean(values)))
+                else:
+                    summary.append((k, 0.0))
+            self.data = tuple(summary)
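+
+
+# a minimal ResourceMonitor usage sketch (parameter values are illustrative):
+#
+#   monitor = ResourceMonitor(
+#       interval=5.0,
+#       has_gpu=bool(gpu_constants()),
+#       main_pid=os.getpid(),
+#       logging_level=logging.WARNING,
+#   )
+#   with monitor:
+#       ...  # run the training loop here
+#   # monitor.data now holds (key, value) pairs: means over the observation
+#   # period, or maxima for "cpu_processes" and "cpu_open_files"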
diff --git a/doc/conf.py b/doc/conf.py
index a6acde2d93ffc5cc01c9c9d4f1380658a8cd9283..6af7632807452778399b32736c92f192c58126fb 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -23,14 +23,17 @@ extensions = [
     "sphinx.ext.napoleon",
     "sphinx.ext.viewcode",
     "sphinx.ext.mathjax",
-    'sphinxcontrib.programoutput',
-    #'matplotlib.sphinxext.plot_directive'
+    "sphinxcontrib.programoutput",
 ]
 
 # This allows sphinxcontrib-programoutput to work in buildout mode
-candidate_binpath = os.path.join(os.path.dirname(os.path.realpath(os.curdir)), 'bin')
+candidate_binpath = os.path.join(
+    os.path.dirname(os.path.realpath(os.curdir)), "bin"
+)
 if os.path.exists(candidate_binpath):
-    os.environ['PATH'] = candidate_binpath + os.pathsep + os.environ.get('PATH', '')
+    os.environ["PATH"] = (
+        candidate_binpath + os.pathsep + os.environ.get("PATH", "")
+    )
 
 # Be picky about warnings
 nitpicky = True
@@ -104,11 +107,11 @@ release = distribution.version
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 exclude_patterns = [
-        'links.rst',
-        'api/modules.rst',
-        'api/bob.rst',
-        'api/bob.med.rst',
-        ]
+    "links.rst",
+    "api/modules.rst",
+    "api/bob.rst",
+    "api/bob.med.rst",
+]
 
 # The reST default role (used for this markup: `text`) to use for all documents.
 # default_role = None
@@ -132,8 +135,8 @@ pygments_style = "sphinx"
 
 # Some variables which are useful for generated material
 project_variable = project.replace(".", "_")
-short_description = u"Active Tuberculosis Detection On CXR Package for Bob"
-owner = [u"Idiap Research Institute"]
+short_description = "Active Tuberculosis Detection On CXR Package for Bob"
+owner = ["Idiap Research Institute"]
 
 # -- Options for HTML output ---------------------------------------------------
 
@@ -170,7 +173,7 @@ html_favicon = "img/bob-favicon.ico"
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
 
 # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
 # using the given strftime format.
@@ -214,7 +217,7 @@ html_static_path = ['_static']
 # html_file_suffix = None
 
 # Output file base name for HTML help builder.
-htmlhelp_basename = project_variable + u"_doc"
+htmlhelp_basename = project_variable + "_doc"
 
 # -- Post configuration --------------------------------------------------------
 
@@ -232,9 +235,9 @@ rst_epilog = """
 autoclass_content = "class"
 autodoc_member_order = "bysource"
 autodoc_default_options = {
-  "members": True,
-  "undoc-members": True,
-  "show-inheritance": True,
+    "members": True,
+    "undoc-members": True,
+    "show-inheritance": True,
 }
 
 # For inter-documentation mapping: notice sphinx changes to the current
@@ -244,11 +247,17 @@ from bob.extension.utils import link_documentation
 if os.path.exists("requirements.txt"):
     # building on the CI, with a copy of requirements.txt
     intersphinx_mapping = link_documentation(
-        requirements_file="requirements.txt"
+        requirements_file="requirements.txt",
     )
 else:
     # building locally
     intersphinx_mapping = link_documentation()
 
+# Adds psutil
+intersphinx_mapping["psutil"] = (
+    "https://psutil.readthedocs.io/en/latest/",
+    None,
+)
+
 # Add our private index (for extras and fixes)
 intersphinx_mapping["extras"] = ("", "extras.inv")
diff --git a/setup.py b/setup.py
index 2f07c7f9579b150b078d2d15161c3287517b106c..95a96f3674d97121c5bd168fbe60536262f25333 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,7 @@ setup(
             "config = bob.med.tb.scripts.config:config",
             "dataset = bob.med.tb.scripts.dataset:dataset",
             "train = bob.med.tb.scripts.train:train",
+            "train-analysis = bob.med.tb.scripts.train_analysis:train_analysis",
             "predict = bob.med.tb.scripts.predict:predict",
             "evaluate = bob.med.tb.scripts.evaluate:evaluate",
             "predtojson = bob.med.tb.scripts.predtojson:predtojson",