diff --git a/bob/ip/binseg/engine/evaluator.py b/bob/ip/binseg/engine/evaluator.py
index 65f1dfc2bacbb668f38020befd01e788e729cacf..5b4f2a39b23d6af3ac2fa034586e3fad217c9767 100644
--- a/bob/ip/binseg/engine/evaluator.py
+++ b/bob/ip/binseg/engine/evaluator.py
@@ -216,7 +216,6 @@ def run(data_loader, predictions_folder, output_folder, overlayed_folder=None,
 
     """
 
-    logger.info("Start evaluation")
     logger.info(f"Output folder: {output_folder}")
 
     if not os.path.exists(output_folder):
diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
index a4cd2ce164346a3386fca42fbbe46d11ab573a2b..ca22cb3e341424f63c39bc8b783e0d9b78f8ed36 100644
--- a/bob/ip/binseg/engine/predictor.py
+++ b/bob/ip/binseg/engine/predictor.py
@@ -129,7 +129,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
 
     """
 
-    logger.info("Start prediction")
     logger.info(f"Output folder: {output_folder}")
     logger.info(f"Device: {device}")
 
@@ -176,8 +175,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
                 if overlayed_folder is not None:
                     _save_overlayed_png(stem, img, prob, overlayed_folder)
 
-    logger.info("End prediction")
-
     # report operational summary
     total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
     logger.info(f"Total time: {total_time}")
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 7bd4aba027cd70fd62286abfa154445516a3a48b..27c0089f4658171f5f0b87c5c8be6bba604f5521 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -16,6 +16,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)
 
+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+
 
 def sharpen(x, T):
     temp = x ** (1 / T)
@@ -209,7 +211,6 @@ def run(
 
     """
 
-    logger.info("Start SSL training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
@@ -251,7 +252,7 @@ def run(
         start_training_time = time.time()
 
         for epoch in range(start_epoch, max_epoch):
-            scheduler.step()
+            if not PYTORCH_GE_110: scheduler.step()
             losses = SmoothedValue(len(data_loader))
             labelled_loss = SmoothedValue(len(data_loader))
             unlabelled_loss = SmoothedValue(len(data_loader))
@@ -296,6 +297,8 @@ def run(
                 unlabelled_loss.update(ul)
                 logger.debug(f"batch loss: {loss.item()}")
 
+            if PYTORCH_GE_110: scheduler.step()
+
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
 
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 040e5651514a1b12d89f3a3c00b2f8b9b5eb90d8..dee2d6287dad0f4481b69ede9b25f0dfd52236b1 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -5,6 +5,7 @@ import os
 import csv
 import time
 import datetime
+import distutils.version
 
 import torch
 import pandas
@@ -16,6 +17,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)
 
+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+
 
 def run(
     model,
@@ -69,7 +72,6 @@ def run(
         output path
     """
 
-    logger.info("Start training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
@@ -108,7 +110,7 @@ def run(
         start_training_time = time.time()
 
         for epoch in range(start_epoch, max_epoch):
-            scheduler.step()
+            if not PYTORCH_GE_110: scheduler.step()
             losses = SmoothedValue(len(data_loader))
             epoch = epoch + 1
             arguments["epoch"] = epoch
@@ -139,6 +141,8 @@ def run(
                 losses.update(loss)
                 logger.debug(f"batch loss: {loss.item()}")
 
+            if PYTORCH_GE_110: scheduler.step()
+
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)