diff --git a/bob/ip/binseg/engine/evaluator.py b/bob/ip/binseg/engine/evaluator.py
index 65f1dfc2bacbb668f38020befd01e788e729cacf..5b4f2a39b23d6af3ac2fa034586e3fad217c9767 100644
--- a/bob/ip/binseg/engine/evaluator.py
+++ b/bob/ip/binseg/engine/evaluator.py
@@ -216,7 +216,6 @@ def run(data_loader, predictions_folder, output_folder, overlayed_folder=None,
     """

-    logger.info("Start evaluation")
     logger.info(f"Output folder: {output_folder}")

     if not os.path.exists(output_folder):
diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
index a4cd2ce164346a3386fca42fbbe46d11ab573a2b..ca22cb3e341424f63c39bc8b783e0d9b78f8ed36 100644
--- a/bob/ip/binseg/engine/predictor.py
+++ b/bob/ip/binseg/engine/predictor.py
@@ -129,7 +129,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
     """

-    logger.info("Start prediction")
     logger.info(f"Output folder: {output_folder}")
     logger.info(f"Device: {device}")

@@ -176,8 +175,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
         if overlayed_folder is not None:
             _save_overlayed_png(stem, img, prob, overlayed_folder)

-    logger.info("End prediction")
-
     # report operational summary
     total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
     logger.info(f"Total time: {total_time}")
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 7bd4aba027cd70fd62286abfa154445516a3a48b..27c0089f4658171f5f0b87c5c8be6bba604f5521 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -16,6 +16,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)

+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+

 def sharpen(x, T):
     temp = x ** (1 / T)
@@ -209,7 +211,6 @@
     """

-    logger.info("Start SSL training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
@@ -251,7 +252,7 @@
     start_training_time = time.time()

     for epoch in range(start_epoch, max_epoch):
-        scheduler.step()
+        if not PYTORCH_GE_110: scheduler.step()
         losses = SmoothedValue(len(data_loader))
         labelled_loss = SmoothedValue(len(data_loader))
         unlabelled_loss = SmoothedValue(len(data_loader))
@@ -296,6 +297,8 @@
             unlabelled_loss.update(ul)
             logger.debug(f"batch loss: {loss.item()}")

+        if PYTORCH_GE_110: scheduler.step()
+
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 040e5651514a1b12d89f3a3c00b2f8b9b5eb90d8..dee2d6287dad0f4481b69ede9b25f0dfd52236b1 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -5,6 +5,7 @@ import os
 import csv
 import time
 import datetime
+import distutils.version

 import torch
 import pandas
@@ -16,6 +17,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)

+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+

 def run(
     model,
@@ -69,7 +72,6 @@
         output path
     """

-    logger.info("Start training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
@@ -108,7 +110,7 @@
     start_training_time = time.time()

     for epoch in range(start_epoch, max_epoch):
-        scheduler.step()
+        if not PYTORCH_GE_110: scheduler.step()
         losses = SmoothedValue(len(data_loader))
         epoch = epoch + 1
         arguments["epoch"] = epoch
@@ -139,6 +141,8 @@
             losses.update(loss)
             logger.debug(f"batch loss: {loss.item()}")

+        if PYTORCH_GE_110: scheduler.step()
+
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
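
Note on the scheduler change above (commentary, not part of the patch): starting with PyTorch 1.1.0 the documented call order is optimizer.step() before scheduler.step(); calling the scheduler first, as pre-1.1.0 code did, skips the first scheduled learning-rate value and triggers a warning on newer releases. The sketch below shows the same PYTORCH_GE_110 gate applied to a minimal, self-contained training loop; the Linear model, SGD optimizer and StepLR scheduler are illustrative assumptions and do not come from bob.ip.binseg.

# Minimal sketch of the version-gated scheduler.step() placement used in the
# patch.  Model/optimizer/scheduler choices here are illustrative only.
import distutils.version

import torch

# Same expression as the patch; assumes a release-style torch version string.
PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)

for epoch in range(10):
    if not PYTORCH_GE_110:
        scheduler.step()  # old (< 1.1.0) order: scheduler before the epoch's updates

    # one dummy optimization step per epoch
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).mean()
    loss.backward()
    optimizer.step()

    if PYTORCH_GE_110:
        scheduler.step()  # new (>= 1.1.0) order: scheduler after optimizer.step()

Stepping the scheduler once per epoch, after all of that epoch's optimizer updates, mirrors where the patch places the gated call inside run().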