Skip to content
Snippets Groups Projects
Commit e2dde359 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.trainer-s] Fix pytorch>=1.1 scheduler stepping; Remove excessive logging

parent 2d491569
No related branches found
No related tags found
1 merge request: !12 Streamlining
......@@ -216,7 +216,6 @@ def run(data_loader, predictions_folder, output_folder, overlayed_folder=None,
"""
logger.info("Start evaluation")
logger.info(f"Output folder: {output_folder}")
if not os.path.exists(output_folder):
......
......@@ -129,7 +129,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
"""
logger.info("Start prediction")
logger.info(f"Output folder: {output_folder}")
logger.info(f"Device: {device}")
......@@ -176,8 +175,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
if overlayed_folder is not None:
_save_overlayed_png(stem, img, prob, overlayed_folder)
logger.info("End prediction")
# report operational summary
total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
logger.info(f"Total time: {total_time}")
......
......@@ -16,6 +16,8 @@ from bob.ip.binseg.utils.plot import loss_curve
import logging
logger = logging.getLogger(__name__)
PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
def sharpen(x, T):
temp = x ** (1 / T)
......@@ -209,7 +211,6 @@ def run(
"""
logger.info("Start SSL training")
start_epoch = arguments["epoch"]
max_epoch = arguments["max_epoch"]
......@@ -251,7 +252,7 @@ def run(
start_training_time = time.time()
for epoch in range(start_epoch, max_epoch):
scheduler.step()
if not PYTORCH_GE_110: scheduler.step()
losses = SmoothedValue(len(data_loader))
labelled_loss = SmoothedValue(len(data_loader))
unlabelled_loss = SmoothedValue(len(data_loader))
......@@ -296,6 +297,8 @@ def run(
unlabelled_loss.update(ul)
logger.debug(f"batch loss: {loss.item()}")
if PYTORCH_GE_110: scheduler.step()
if checkpoint_period and (epoch % checkpoint_period == 0):
checkpointer.save(f"model_{epoch:03d}", **arguments)
......
......@@ -5,6 +5,7 @@ import os
import csv
import time
import datetime
import distutils.version
import torch
import pandas
......@@ -16,6 +17,8 @@ from bob.ip.binseg.utils.plot import loss_curve
import logging
logger = logging.getLogger(__name__)
PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
def run(
model,
......@@ -69,7 +72,6 @@ def run(
output path
"""
logger.info("Start training")
start_epoch = arguments["epoch"]
max_epoch = arguments["max_epoch"]
......@@ -108,7 +110,7 @@ def run(
start_training_time = time.time()
for epoch in range(start_epoch, max_epoch):
scheduler.step()
if not PYTORCH_GE_110: scheduler.step()
losses = SmoothedValue(len(data_loader))
epoch = epoch + 1
arguments["epoch"] = epoch
......@@ -139,6 +141,8 @@ def run(
losses.update(loss)
logger.debug(f"batch loss: {loss.item()}")
if PYTORCH_GE_110: scheduler.step()
if checkpoint_period and (epoch % checkpoint_period == 0):
checkpointer.save(f"model_{epoch:03d}", **arguments)
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment