From f45e6ea5e2efd7c8d344ad4ea145c8f6dc7c7814 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Sun, 5 Apr 2020 10:56:19 +0200
Subject: [PATCH] [engine.trainer] Always remove tqdm bars; only show bars when
 a TTY is connected
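
The loop over training batches now relies on tqdm's own TTY
detection: with disable=None, tqdm probes the output stream and
silences itself when stderr is not attached to a terminal (e.g. on
batch/grid jobs whose logs are redirected to files), while
leave=False erases the bar once the loop completes, so interactive
runs keep a clean scrollback.  A minimal, self-contained sketch of
that behaviour (the sleep is just a stand-in for the real
forward/backward pass over a batch):

    import time

    from tqdm import tqdm

    # disable=None: show the bar only when the stream is a TTY
    # leave=False: remove the bar from the terminal once done
    for batch in tqdm(range(8), desc="batches", leave=False, disable=None):
        time.sleep(0.1)  # stand-in for per-batch work

While in the file, the final-model checkpoint test is relaxed from
== to >= as a defensive measure, and the construction of the
per-epoch log entry is reformatted for readability.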

---
 bob/ip/binseg/engine/trainer.py | 55 +++++++++++++++++++++------------
 1 file changed, 35 insertions(+), 20 deletions(-)
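
Note for reviewers (not part of the commit message): the epoch log is
still emitted through csv.DictWriter; the reformatted block only
restructures how the (key, value) pairs are built.  A self-contained
sketch of that row construction, with made-up values and sys.stdout
standing in for the real logfile handle:

    import csv
    import datetime
    import sys

    fields = ("epoch", "total-time", "average-loss")
    logdata = (
        ("epoch", "1"),
        ("total-time", f"{datetime.timedelta(seconds=42)}"),
        ("average-loss", f"{0.123456:.6f}"),
    )
    # one CSV row per epoch, plus a human-readable log line
    csv.DictWriter(sys.stdout, fieldnames=fields).writerow(dict(logdata))
    print("|".join(f"{k}: {v}" for k, v in logdata))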

diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 3201df17..b225f69c 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -80,8 +80,15 @@ def do_train(
         logger.info(f"Truncating {logfile_name} - training is restarting...")
         os.unlink(logfile_name)
 
-    logfile_fields = ("epoch", "total-time", "eta", "average-loss",
-            "median-loss", "learning-rate", "gpu-memory-megabytes")
+    logfile_fields = (
+        "epoch",
+        "total-time",
+        "eta",
+        "average-loss",
+        "median-loss",
+        "learning-rate",
+        "gpu-memory-megabytes",
+    )
     with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
 
@@ -105,7 +112,10 @@ def do_train(
             # Epoch time
             start_epoch_time = time.time()
 
-            for samples in tqdm(data_loader):
+            # progress bar only on interactive jobs
+            for samples in tqdm(
+                data_loader, desc="batches", leave=False, disable=None
+            ):
 
                 images = samples[1].to(device)
                 ground_truths = samples[2].to(device)
@@ -126,7 +136,7 @@ def do_train(
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save("model_{:03d}".format(epoch), **arguments)
 
-            if epoch == max_epoch:
+            if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)
 
             # computes ETA (estimated time-of-arrival; end of training) taking
@@ -136,26 +146,31 @@ def do_train(
             current_time = time.time() - start_training_time
 
             logdata = (
-                    ("epoch", f"{epoch}"),
-                    ("total-time",
-                        f"{datetime.timedelta(seconds=int(current_time))}"),
-                    ("eta",
-                        f"{datetime.timedelta(seconds=int(eta_seconds))}"),
-                    ("average-loss", f"{losses.avg:.6f}"),
-                    ("median-loss", f"{losses.median:.6f}"),
-                    ("learning-rate",
-                        f"{optimizer.param_groups[0]['lr']:.6f}"),
-                    ("gpu-memory-megabytes",
-                        f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}" \
-                        if torch.cuda.is_available() else "0.0"),
-                    )
+                ("epoch", f"{epoch}"),
+                (
+                    "total-time",
+                    f"{datetime.timedelta(seconds=int(current_time))}",
+                ),
+                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
+                ("average-loss", f"{losses.avg:.6f}"),
+                ("median-loss", f"{losses.median:.6f}"),
+                ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+                (
+                    "gpu-memory-megabytes",
+                    f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
+                    if torch.cuda.is_available()
+                    else "0.0",
+                ),
+            )
 
             logwriter.writerow(dict(k for k in logdata))
-            logger.info("|".join([f"{k}: {v}" for (k,v) in logdata]))
+            logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))
 
-        logger.info("End of training.")
+        logger.info("End of training")
         total_training_time = time.time() - start_training_time
-        logger.info(f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)")
+        logger.info(
+            f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
+        )
 
     # plots a version of the CSV trainlog into a PDF
     logdf = pd.read_csv(logfile_name, header=0, names=logfile_fields)
-- 
GitLab