From e2dde35917350b05cdb6e608ca48da4456682ebf Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 20 Apr 2020 19:14:30 +0200
Subject: [PATCH] [engine.trainer-s] Fix pytorch>=1.1 scheduler stepping; Remove excessive logging

---
 bob/ip/binseg/engine/evaluator.py  | 1 -
 bob/ip/binseg/engine/predictor.py  | 3 ---
 bob/ip/binseg/engine/ssltrainer.py | 7 +++++--
 bob/ip/binseg/engine/trainer.py    | 8 ++++++--
 4 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/bob/ip/binseg/engine/evaluator.py b/bob/ip/binseg/engine/evaluator.py
index 65f1dfc2..5b4f2a39 100644
--- a/bob/ip/binseg/engine/evaluator.py
+++ b/bob/ip/binseg/engine/evaluator.py
@@ -216,7 +216,6 @@ def run(data_loader, predictions_folder, output_folder, overlayed_folder=None,
 
     """
 
-    logger.info("Start evaluation")
     logger.info(f"Output folder: {output_folder}")
 
     if not os.path.exists(output_folder):
diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
index a4cd2ce1..ca22cb3e 100644
--- a/bob/ip/binseg/engine/predictor.py
+++ b/bob/ip/binseg/engine/predictor.py
@@ -129,7 +129,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
 
     """
 
-    logger.info("Start prediction")
     logger.info(f"Output folder: {output_folder}")
     logger.info(f"Device: {device}")
 
@@ -176,8 +175,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
         if overlayed_folder is not None:
             _save_overlayed_png(stem, img, prob, overlayed_folder)
 
-    logger.info("End prediction")
-
     # report operational summary
     total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
     logger.info(f"Total time: {total_time}")
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 7bd4aba0..27c0089f 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -16,6 +16,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)
 
+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+
 
 def sharpen(x, T):
     temp = x ** (1 / T)
@@ -209,7 +211,6 @@ def run(
 
     """
 
-    logger.info("Start SSL training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
@@ -251,7 +252,7 @@ def run(
 
     start_training_time = time.time()
     for epoch in range(start_epoch, max_epoch):
-        scheduler.step()
+        if not PYTORCH_GE_110: scheduler.step()
         losses = SmoothedValue(len(data_loader))
         labelled_loss = SmoothedValue(len(data_loader))
         unlabelled_loss = SmoothedValue(len(data_loader))
@@ -296,6 +297,8 @@ def run(
             unlabelled_loss.update(ul)
             logger.debug(f"batch loss: {loss.item()}")
 
+        if PYTORCH_GE_110: scheduler.step()
+
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
 
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 040e5651..dee2d628 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -5,6 +5,7 @@ import os
 import csv
 import time
 import datetime
+import distutils.version
 
 import torch
 import pandas
@@ -16,6 +17,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)
 
+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+
 
 def run(
     model,
@@ -69,7 +72,6 @@ def run(
         output path
 
     """
-    logger.info("Start training")
 
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
@@ -108,7 +110,7 @@ def run(
 
     start_training_time = time.time()
     for epoch in range(start_epoch, max_epoch):
-        scheduler.step()
+        if not PYTORCH_GE_110: scheduler.step()
         losses = SmoothedValue(len(data_loader))
         epoch = epoch + 1
         arguments["epoch"] = epoch
@@ -139,6 +141,8 @@ def run(
             losses.update(loss)
             logger.debug(f"batch loss: {loss.item()}")
 
+        if PYTORCH_GE_110: scheduler.step()
+
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
 
-- 
GitLab