From 7eb7def379d98d1b523acf6b2758a266db0e1d04 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Sun, 5 Apr 2020 11:45:10 +0200
Subject: [PATCH] [engine.ssltrainer] Re-sync with engine.trainer

---
 bob/ip/binseg/engine/ssltrainer.py | 205 ++++++++++++++---
 1 file changed, 99 insertions(+), 106 deletions(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index f03e01e4..542d3162 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-

 import os
+import csv
 import time
 import datetime
 import torch
@@ -21,7 +22,7 @@ def sharpen(x, T):
     return temp / temp.sum(dim=1, keepdim=True)


-def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
+def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
     """Applies mix up as described in [MIXMATCH_19].

     Parameters
@@ -32,7 +33,7 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):

     target : :py:class:`torch.Tensor`

-    unlabeled_input : :py:class:`torch.Tensor`
+    unlabelled_input : :py:class:`torch.Tensor`

     unlabled_target : :py:class:`torch.Tensor`

@@ -48,17 +49,17 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     l = np.random.beta(alpha, alpha)  # Eq (8)
     l = max(l, 1 - l)  # Eq (9)
     # Shuffle and concat. Alg. 1 Line: 12
-    w_inputs = torch.cat([input, unlabeled_input], 0)
+    w_inputs = torch.cat([input, unlabelled_input], 0)
     w_targets = torch.cat([target, unlabled_target], 0)
     idx = torch.randperm(w_inputs.size(0))  # get random index

-    # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
+    # Apply MixUp to labelled data and entries from W. Alg. 1 Line: 13
     input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input) :]]
     target_mixedup = l * target + (1 - l) * w_targets[idx[len(target) :]]

-    # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
-    unlabeled_input_mixedup = (
-        l * unlabeled_input + (1 - l) * w_inputs[idx[: len(unlabeled_input)]]
+    # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
+    unlabelled_input_mixedup = (
+        l * unlabelled_input + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
     )
     unlabled_target_mixedup = (
         l * unlabled_target + (1 - l) * w_targets[idx[: len(unlabled_target)]]
@@ -66,7 +67,7 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     return (
         input_mixedup,
         target_mixedup,
-        unlabeled_input_mixedup,
+        unlabelled_input_mixedup,
         unlabled_target_mixedup,
     )

@@ -122,14 +123,14 @@ def linear_rampup(current, rampup_length=16):
     return float(current)


-def guess_labels(unlabeled_images, model):
+def guess_labels(unlabelled_images, model):
     """
     Calculate the average predictions by 2 augmentations: horizontal and vertical flips

     Parameters
     ----------

-    unlabeled_images : :py:class:`torch.Tensor`
+    unlabelled_images : :py:class:`torch.Tensor`
         ``[n,c,h,w]``

     target : :py:class:`torch.Tensor`
@@ -142,12 +143,12 @@ ...
     """
     with torch.no_grad():
-        guess1 = torch.sigmoid(model(unlabeled_images)).unsqueeze(0)
+        guess1 = torch.sigmoid(model(unlabelled_images)).unsqueeze(0)
         # Horizontal flip and unsqueeze to work with batches (increase flip dimension by 1)
-        hflip = torch.sigmoid(model(unlabeled_images.flip(2))).unsqueeze(0)
+        hflip = torch.sigmoid(model(unlabelled_images.flip(2))).unsqueeze(0)
         guess2 = hflip.flip(3)
         # Vertical flip and unsqueeze to work with batches (increase flip dimension by 1)
-        vflip = torch.sigmoid(model(unlabeled_images.flip(3))).unsqueeze(0)
+        vflip = torch.sigmoid(model(unlabelled_images.flip(3))).unsqueeze(0)
         guess3 = vflip.flip(4)
         # Concat
         concat = torch.cat([guess1, guess2, guess3], 0)
@@ -169,13 +170,13 @@ def do_ssltrain(
     rampup_length,
 ):
     """
-    Train model and save to disk.
+    Trains model using semi-supervised learning and saves it to disk.

     Parameters
     ----------

     model : :py:class:`torch.nn.Module`
-        Network (e.g. DRIU, HED, UNet)
+        Network (e.g. driu, hed, unet)

     data_loader : :py:class:`torch.utils.data.DataLoader`

@@ -191,13 +192,14 @@ def do_ssltrain(
     checkpointer

     checkpoint_period : int
-        save a checkpoint every n epochs
+        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
+        not save intermediary checkpoints

     device : str
-        device to use ``'cpu'`` or ``'cuda'``
+        device to use ``'cpu'`` or ``cuda:0``

     arguments : dict
-        start end end epochs
+        start and end epochs

     output_folder : str
         output path
@@ -206,15 +208,35 @@ def do_ssltrain(
         rampup epochs
     """
-    logger.info("Start SSL training")
+    logger.info("Start SSL training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]

-    # Logg to file
-    with open(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+", 1
-    ) as outfile:
+    # Log to file
+    logfile_name = os.path.join(output_folder, "trainlog.csv")
+
+    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
+        logger.info(f"Truncating {logfile_name} - training is restarting...")
+        os.unlink(logfile_name)
+
+    logfile_fields = (
+        "epoch",
+        "total-time",
+        "eta",
+        "average-loss",
+        "median-loss",
+        "median-labelled-loss",
+        "median-unlabelled-loss",
+        "learning-rate",
+        "gpu-memory-megabytes",
+    )
+    with open(logfile_name, "a+", newline="") as logfile:
+        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
+
+        if arguments["epoch"] == 0:
+            logwriter.writeheader()
+
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
@@ -223,125 +245,96 @@ def do_ssltrain(
         model.train().to(device)
         # Total training timer
         start_training_time = time.time()
+
         for epoch in range(start_epoch, max_epoch):
             scheduler.step()
             losses = SmoothedValue(len(data_loader))
-            labeled_loss = SmoothedValue(len(data_loader))
-            unlabeled_loss = SmoothedValue(len(data_loader))
+            labelled_loss = SmoothedValue(len(data_loader))
+            unlabelled_loss = SmoothedValue(len(data_loader))
             epoch = epoch + 1
             arguments["epoch"] = epoch
             # Epoch time
             start_epoch_time = time.time()

-            for samples in tqdm(data_loader):
-                # labeled
+            for samples in tqdm(data_loader, desc="batches", leave=False,
+                    disable=None,):
+
+                # data forwarding on the existing network
+
+                # labelled
                 images = samples[1].to(device)
                 ground_truths = samples[2].to(device)
-                unlabeled_images = samples[4].to(device)
-                # labeled outputs
+                unlabelled_images = samples[4].to(device)
+                # labelled outputs
                 outputs = model(images)
-                unlabeled_outputs = model(unlabeled_images)
-                # guessed unlabeled outputs
-                unlabeled_ground_truths = guess_labels(unlabeled_images, model)
-                # unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
-                # images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
+                unlabelled_outputs = model(unlabelled_images)
+                # guessed unlabelled outputs
+                unlabelled_ground_truths = guess_labels(unlabelled_images, model)
+                # unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
+                # images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)
+
+                # loss evaluation and learning (backward step)
                 ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)

                 loss, ll, ul = criterion(
                     outputs,
                     ground_truths,
-                    unlabeled_outputs,
-                    unlabeled_ground_truths,
+                    unlabelled_outputs,
+                    unlabelled_ground_truths,
                     ramp_up_factor,
                 )
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
                 losses.update(loss)
-                labeled_loss.update(ll)
-                unlabeled_loss.update(ul)
-                logger.debug("batch loss: {}".format(loss.item()))
+                labelled_loss.update(ll)
+                unlabelled_loss.update(ul)
+                logger.debug(f"batch loss: {loss.item()}")

-            if epoch % checkpoint_period == 0:
-                checkpointer.save("model_{:03d}".format(epoch), **arguments)
+            if checkpoint_period and (epoch % checkpoint_period == 0):
+                checkpointer.save(f"model_{epoch:03d}", **arguments)

-            if epoch == max_epoch:
+            if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)

+            # computes ETA (estimated time-of-arrival; end of training) taking
+            # into consideration previous epoch performance
             epoch_time = time.time() - start_epoch_time
-
             eta_seconds = epoch_time * (max_epoch - epoch)
-            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+            current_time = time.time() - start_training_time

-            outfile.write(
+            logdata = (
+                ("epoch", f"{epoch}"),
                 (
-                    "{epoch}, "
-                    "{avg_loss:.6f}, "
-                    "{median_loss:.6f}, "
-                    "{median_labeled_loss},"
-                    "{median_unlabeled_loss},"
-                    "{lr:.6f}, "
-                    "{memory:.0f}"
-                    "\n"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    median_labeled_loss=labeled_loss.median,
-                    median_unlabeled_loss=unlabeled_loss.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
-                    if torch.cuda.is_available()
-                    else 0.0,
-                )
-            )
-            logger.info(
+                    "total-time",
+                    f"{datetime.timedelta(seconds=int(current_time))}",
+                ),
+                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
+                ("average-loss", f"{losses.avg:.6f}"),
+                ("median-loss", f"{losses.median:.6f}"),
+                ("median-labelled-loss", f"{labelled_loss.median:.6f}"),
+                ("median-unlabelled-loss", f"{unlabelled_loss.median:.6f}"),
+                ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
                 (
-                    "eta: {eta}, "
-                    "epoch: {epoch}, "
-                    "avg. loss: {avg_loss:.6f}, "
-                    "median loss: {median_loss:.6f}, "
-                    "labeled loss: {median_labeled_loss}, "
-                    "unlabeled loss: {median_unlabeled_loss}, "
-                    "lr: {lr:.6f}, "
-                    "max mem: {memory:.0f}"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    median_labeled_loss=labeled_loss.median,
-                    median_unlabeled_loss=unlabeled_loss.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+                    "gpu-memory-megabytes",
+                    f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
                     if torch.cuda.is_available()
-                    else 0.0,
-                )
+                    else "0.0",
+                ),
             )
+            logwriter.writerow(dict(k for k in logdata))
+            logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))
+
         logger.info("End of training")
         total_training_time = time.time() - start_training_time
-        total_time_str = str(datetime.timedelta(seconds=total_training_time))
         logger.info(
-            "Total training time: {} ({:.4f} s / epoch)".format(
-                total_time_str, total_training_time / (max_epoch)
-            )
+            f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
        )

-    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
-    logdf = pd.read_csv(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
-        header=None,
-        names=[
-            "avg. loss",
-            "median loss",
-            "labeled loss",
-            "unlabeled loss",
-            "lr",
-            "max memory",
-        ],
-    )
-    fig = loss_curve(logdf, output_folder)
-    logger.info("saving {}".format(log_plot_file))
-    fig.savefig(log_plot_file)
--
GitLab
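
For readers unfamiliar with the MixUp step the hunks above touch: Eqs. (8)-(9) of [MIXMATCH_19] draw an interpolation factor from a Beta distribution and fold it onto [0.5, 1], so every mixed sample stays dominated by its primary input. A minimal sketch of just that step; the batch shapes are illustrative, and alpha=0.75 only mirrors the commented-out mix_up(0.75, ...) call in the patch:

    import numpy as np
    import torch

    alpha = 0.75  # illustrative value, taken from the commented-out call above
    l = np.random.beta(alpha, alpha)  # Eq (8): l ~ Beta(alpha, alpha)
    l = max(l, 1 - l)                 # Eq (9): fold onto [0.5, 1]

    x1 = torch.rand(4, 3, 32, 32)  # primary batch (e.g. labelled inputs)
    x2 = torch.rand(4, 3, 32, 32)  # shuffled entries from the joint batch W
    mixed = l * x1 + (1 - l) * x2  # stays closer to x1, since l >= 0.5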
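The guess_labels() hunk keeps the same test-time-augmentation logic under the new names: predict on the image, on its horizontal flip, and on its vertical flip, un-flip each prediction so pixels stay aligned, then average. A self-contained sketch of that logic with a stand-in one-layer network (not one of the project's models); note how unsqueeze(0) shifts the flip dimensions by one:

    import torch
    import torch.nn as nn

    model = nn.Conv2d(3, 1, kernel_size=1)  # stand-in for DRIU/HED/U-Net
    x = torch.rand(2, 3, 8, 8)  # hypothetical unlabelled batch, [n,c,h,w]

    with torch.no_grad():
        guess1 = torch.sigmoid(model(x)).unsqueeze(0)  # [1,n,1,h,w]
        # flip the image, predict, then un-flip the prediction; dims 3/4 of
        # the 5D tensor correspond to dims 2/3 (h/w) of the 4D input
        guess2 = torch.sigmoid(model(x.flip(2))).unsqueeze(0).flip(3)
        guess3 = torch.sigmoid(model(x.flip(3))).unsqueeze(0).flip(4)
        averaged = torch.cat([guess1, guess2, guess3], 0).mean(0)  # [n,1,h,w]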
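Finally, because the trainlog now carries a proper CSV header row (written once, at epoch 0), downstream tools no longer need to hard-code column names or positions. A minimal sketch of consuming the new file; the path is hypothetical and stands for <output_folder>/trainlog.csv as written by this patch:

    import pandas as pd

    # one header row plus one row per completed epoch
    logdf = pd.read_csv("output_folder/trainlog.csv")  # hypothetical location
    print(logdf[["epoch", "average-loss", "median-labelled-loss",
                 "median-unlabelled-loss", "eta"]].tail())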