From e2dde35917350b05cdb6e608ca48da4456682ebf Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 20 Apr 2020 19:14:30 +0200
Subject: [PATCH] [engine.trainer-s] Fix PyTorch>=1.1 scheduler stepping;
 Remove excessive logging
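
Since PyTorch 1.1.0, lr_scheduler.step() is expected to be called
after optimizer.step(); keeping the old call order skips the first
value of the learning-rate schedule and raises a UserWarning.  Both
trainers now detect the installed PyTorch version and, on 1.1.0 and
newer, step the scheduler after the batch loop instead of before it,
preserving the old behaviour on earlier releases.  The redundant
"Start ..."/"End ..." info messages are also dropped.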

---
 bob/ip/binseg/engine/evaluator.py  | 1 -
 bob/ip/binseg/engine/predictor.py  | 3 ---
 bob/ip/binseg/engine/ssltrainer.py | 7 +++++--
 bob/ip/binseg/engine/trainer.py    | 8 ++++++--
 4 files changed, 11 insertions(+), 8 deletions(-)
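
Note for reviewers: PyTorch 1.1.0 reversed the expected order of
optimizer.step() and scheduler.step(); calling the scheduler first now
skips the first scheduled learning rate and emits a UserWarning.  The
sketch below shows the version gate applied in both trainers here; the
model, optimizer, schedule and loop bounds are placeholders for
illustration only, not code from this package:

    import distutils.version

    import torch

    # True on PyTorch >= 1.1.0, where scheduler.step() must follow
    # optimizer.step().
    PYTORCH_GE_110 = (
        distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
    )

    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    for epoch in range(5):
        if not PYTORCH_GE_110:
            scheduler.step()  # pre-1.1 convention: step at epoch start
        loss = model(torch.ones(1, 2)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if PYTORCH_GE_110:
            scheduler.step()  # 1.1+ convention: step after the optimizer

Gating on the detected version, rather than unconditionally moving the
call, keeps the learning-rate schedule unchanged for users still on
PyTorch < 1.1.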

diff --git a/bob/ip/binseg/engine/evaluator.py b/bob/ip/binseg/engine/evaluator.py
index 65f1dfc2..5b4f2a39 100644
--- a/bob/ip/binseg/engine/evaluator.py
+++ b/bob/ip/binseg/engine/evaluator.py
@@ -216,7 +216,6 @@ def run(data_loader, predictions_folder, output_folder, overlayed_folder=None,
 
     """
 
-    logger.info("Start evaluation")
     logger.info(f"Output folder: {output_folder}")
 
     if not os.path.exists(output_folder):
diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
index a4cd2ce1..ca22cb3e 100644
--- a/bob/ip/binseg/engine/predictor.py
+++ b/bob/ip/binseg/engine/predictor.py
@@ -129,7 +129,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
 
     """
 
-    logger.info("Start prediction")
     logger.info(f"Output folder: {output_folder}")
     logger.info(f"Device: {device}")
 
@@ -176,8 +175,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
                 if overlayed_folder is not None:
                     _save_overlayed_png(stem, img, prob, overlayed_folder)
 
-    logger.info("End prediction")
-
     # report operational summary
     total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
     logger.info(f"Total time: {total_time}")
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 7bd4aba0..27c0089f 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -16,6 +16,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)
 
+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+
 
 def sharpen(x, T):
     temp = x ** (1 / T)
@@ -209,7 +211,6 @@ def run(
 
     """
 
-    logger.info("Start SSL training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
@@ -251,7 +252,7 @@ def run(
         start_training_time = time.time()
 
         for epoch in range(start_epoch, max_epoch):
-            scheduler.step()
+            if not PYTORCH_GE_110: scheduler.step()
             losses = SmoothedValue(len(data_loader))
             labelled_loss = SmoothedValue(len(data_loader))
             unlabelled_loss = SmoothedValue(len(data_loader))
@@ -296,6 +297,8 @@ def run(
                 unlabelled_loss.update(ul)
                 logger.debug(f"batch loss: {loss.item()}")
 
+            if PYTORCH_GE_110: scheduler.step()
+
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
 
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 040e5651..dee2d628 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -5,6 +5,7 @@ import os
 import csv
 import time
 import datetime
+import distutils.version
 
 import torch
 import pandas
@@ -16,6 +17,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)
 
+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+
 
 def run(
     model,
@@ -69,7 +72,6 @@ def run(
         output path
     """
 
-    logger.info("Start training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
@@ -108,7 +110,7 @@ def run(
         start_training_time = time.time()
 
         for epoch in range(start_epoch, max_epoch):
-            scheduler.step()
+            if not PYTORCH_GE_110: scheduler.step()
             losses = SmoothedValue(len(data_loader))
             epoch = epoch + 1
             arguments["epoch"] = epoch
@@ -139,6 +141,8 @@ def run(
                 losses.update(loss)
                 logger.debug(f"batch loss: {loss.item()}")
 
+            if PYTORCH_GE_110: scheduler.step()
+
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
 
-- 
GitLab