diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index d8b66b69d5e729de87abafd4963a9fc71a4a87d9..9db427a7ebe4b3138948da9d8e7b35259049cbe1 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -4,21 +4,23 @@
 import os
 import csv
 import time
+import shutil
 import datetime
 import distutils.version
 
 import numpy
-import pandas
 import torch
 from tqdm import tqdm
 
 from ..utils.measure import SmoothedValue
-from ..utils.plot import loss_curve
+from ..utils.summary import summary
+from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
 
 import logging
+
 logger = logging.getLogger(__name__)
 
-PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
 
 
 def sharpen(x, T):
@@ -48,7 +50,7 @@ def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
     list
 
     """
-    # TODO:
+
     with torch.no_grad():
         l = numpy.random.beta(alpha, alpha)  # Eq (8)
         l = max(l, 1 - l)  # Eq (9)
@@ -63,10 +65,12 @@ def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
 
         # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
         unlabelled_input_mixedup = (
-            l * unlabelled_input + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
+            l * unlabelled_input
+            + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
         )
         unlabled_target_mixedup = (
-            l * unlabled_target + (1 - l) * w_targets[idx[: len(unlabled_target)]]
+            l * unlabled_target
+            + (1 - l) * w_targets[idx[: len(unlabled_target)]]
         )
         return (
             input_mixedup,
@@ -176,6 +180,11 @@ def run(
     """
     Fits an FCN model using semi-supervised learning and saves it to disk.
 
+
+    This method supports periodic checkpointing and the output of a
+    CSV-formatted log with the evolution of some figures during training.
+
+
     Parameters
     ----------
 
@@ -193,7 +202,7 @@ def run(
         learning rate scheduler
 
     checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
-        checkpointer
+        checkpointer implementation
 
     checkpoint_period : int
         save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
@@ -216,45 +225,85 @@ def run(
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
-    if not os.path.exists(output_folder):
-        logger.debug(f"Creating output directory '{output_folder}'...")
-        os.makedirs(output_folder)
+    if device != "cpu":
+        # asserts we do have a GPU
+        assert bool(gpu_constants()), (
+            f"Device set to '{device}', but cannot "
+            f"find a GPU (maybe nvidia-smi is not installed?)"
+        )
 
-    # Log to file
+    os.makedirs(output_folder, exist_ok=True)
+
+    # Save model summary
+    summary_path = os.path.join(output_folder, "model_summary.txt")
+    logger.info(f"Saving model summary at {summary_path}...")
+    with open(summary_path, "wt") as f:
+        r, n = summary(model)
+        logger.info(f"Model has {n} parameters...")
+        f.write(r)
+
+    # write static information to a CSV file
+    static_logfile_name = os.path.join(output_folder, "constants.csv")
+    if os.path.exists(static_logfile_name):
+        backup = static_logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(static_logfile_name, backup)
+    with open(static_logfile_name, "w", newline="") as f:
+        logdata = cpu_constants()
+        if device != "cpu":
+            logdata += gpu_constants()
+        logdata += (("model_size", n),)
+        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
+        logwriter.writeheader()
+        logwriter.writerow(dict(k for k in logdata))
+
+    # Log continuous information to (another) file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
 
     if arguments["epoch"] == 0 and os.path.exists(logfile_name):
-        logger.info(f"Truncating {logfile_name} - training is restarting...")
-        os.unlink(logfile_name)
+        backup = logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(logfile_name, backup)
 
     logfile_fields = (
         "epoch",
-        "total-time",
+        "total_time",
         "eta",
-        "average-loss",
-        "median-loss",
-        "median-labelled-loss",
-        "median-unlabelled-loss",
-        "learning-rate",
-        "gpu-memory-megabytes",
+        "average_loss",
+        "median_loss",
+        "median_labelled_loss",
+        "median_unlabelled_loss",
+        "learning_rate",
     )
+    logfile_fields += tuple([k[0] for k in cpu_log()])
+    if device != "cpu":
+        logfile_fields += tuple([k[0] for k in gpu_log()])
+
     with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
 
         if arguments["epoch"] == 0:
             logwriter.writeheader()
 
+        model.train().to(device)
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
                     state[k] = v.to(device)
 
-        model.train().to(device)
         # Total training timer
         start_training_time = time.time()
 
-        for epoch in range(start_epoch, max_epoch):
-            if not PYTORCH_GE_110: scheduler.step()
+        for epoch in tqdm(
+            range(start_epoch, max_epoch),
+            desc="epoch",
+            leave=False,
+            disable=None,
+        ):
+            if not PYTORCH_GE_110:
+                scheduler.step()
             losses = SmoothedValue(len(data_loader))
             labelled_loss = SmoothedValue(len(data_loader))
             unlabelled_loss = SmoothedValue(len(data_loader))
@@ -264,8 +313,9 @@ def run(
             # Epoch time
             start_epoch_time = time.time()
 
-            for samples in tqdm(data_loader, desc="batches", leave=False,
-                    disable=None,):
+            for samples in tqdm(
+                data_loader, desc="batch", leave=False, disable=None
+            ):
 
                 # data forwarding on the existing network
 
@@ -277,12 +327,16 @@ def run(
                 outputs = model(images)
                 unlabelled_outputs = model(unlabelled_images)
                 # guessed unlabelled outputs
-                unlabelled_ground_truths = guess_labels(unlabelled_images, model)
+                unlabelled_ground_truths = guess_labels(
+                    unlabelled_images, model
+                )
                 # unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
                 # images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)
 
                 # loss evaluation and learning (backward step)
-                ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)
+                ramp_up_factor = square_rampup(
+                    epoch, rampup_length=rampup_length
+                )
 
                 loss, ll, ul = criterion(
                     outputs,
@@ -299,7 +353,8 @@ def run(
                 unlabelled_loss.update(ul)
                 logger.debug(f"batch loss: {loss.item()}")
 
-            if PYTORCH_GE_110: scheduler.step()
+            if PYTORCH_GE_110:
+                scheduler.step()
 
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
@@ -316,33 +371,24 @@ def run(
             logdata = (
                 ("epoch", f"{epoch}"),
                 (
-                    "total-time",
+                    "total_time",
                     f"{datetime.timedelta(seconds=int(current_time))}",
                 ),
                 ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
-                ("average-loss", f"{losses.avg:.6f}"),
-                ("median-loss", f"{losses.median:.6f}"),
-                ("median-labelled-loss", f"{labelled_loss.median:.6f}"),
-                ("median-unlabelled-loss", f"{unlabelled_loss.median:.6f}"),
-                ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
-                (
-                    "gpu-memory-megabytes",
-                    f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
-                    if torch.cuda.is_available()
-                    else "0.0",
-                ),
-            )
+                ("average_loss", f"{losses.avg:.6f}"),
+                ("median_loss", f"{losses.median:.6f}"),
+                ("median_labelled_loss", f"{labelled_loss.median:.6f}"),
+                ("median_unlabelled_loss", f"{unlabelled_loss.median:.6f}"),
+                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+            ) + cpu_log()
+            if device != "cpu":
+                logdata += gpu_log()
+
             logwriter.writerow(dict(k for k in logdata))
-            logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))
+            logfile.flush()
+            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
 
         total_training_time = time.time() - start_training_time
         logger.info(
             f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
         )
-
-    # plots a version of the CSV trainlog into a PDF
-    logdf = pandas.read_csv(logfile_name, header=0, names=logfile_fields)
-    fig = loss_curve(logdf)
-    figurefile_name = os.path.join(output_folder, "trainlog.pdf")
-    logger.info(f"Saving {figurefile_name}")
-    fig.savefig(figurefile_name)