Commit e2dde359 authored by André Anjos

[engine.trainer-s] Fix pytorch>=1.1 scheduler stepping; Remove excessive logging

parent 2d491569
1 merge request: !12 Streamlining
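
Context for the fix: starting with PyTorch 1.1, lr_scheduler.step() is expected to run after optimizer.step(); calling it first (the pre-1.1 convention) skips the first value of the learning-rate schedule and triggers a UserWarning. Below is a minimal, self-contained sketch of the version-gated stepping pattern this commit applies — the linear model, optimizer, scheduler, and data are illustrative stand-ins, not the repository's code:

import distutils.version

import torch
import torch.nn as nn
import torch.optim as optim

# Mirrors the commit's check: True when the installed PyTorch
# already uses the >=1.1 stepping order.
PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"

model = nn.Linear(10, 1)  # stand-in model for illustration
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
data = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(5)]

for epoch in range(3):
    if not PYTORCH_GE_110:
        scheduler.step()  # pre-1.1 semantics: step before the epoch's updates
    for x, y in data:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
    if PYTORCH_GE_110:
        scheduler.step()  # >=1.1 semantics: step after the optimizer has stepped

The single boolean computed at import time keeps both call sites in the epoch loop cheap, which is exactly the shape of the change in the two trainer modules below.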
@@ -216,7 +216,6 @@ def run(data_loader, predictions_folder, output_folder, overlayed_folder=None,
     """
-    logger.info("Start evaluation")
     logger.info(f"Output folder: {output_folder}")
     if not os.path.exists(output_folder):
...
@@ -129,7 +129,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
     """
-    logger.info("Start prediction")
     logger.info(f"Output folder: {output_folder}")
     logger.info(f"Device: {device}")
@@ -176,8 +175,6 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
         if overlayed_folder is not None:
             _save_overlayed_png(stem, img, prob, overlayed_folder)
-    logger.info("End prediction")
-
     # report operational summary
     total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
     logger.info(f"Total time: {total_time}")
...
@@ -16,6 +16,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)
+
+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")

 def sharpen(x, T):
     temp = x ** (1 / T)
@@ -209,7 +211,6 @@ def run(
     """
-    logger.info("Start SSL training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
@@ -251,7 +252,7 @@ def run(
     start_training_time = time.time()
     for epoch in range(start_epoch, max_epoch):
-        scheduler.step()
+        if not PYTORCH_GE_110: scheduler.step()
         losses = SmoothedValue(len(data_loader))
         labelled_loss = SmoothedValue(len(data_loader))
         unlabelled_loss = SmoothedValue(len(data_loader))
@@ -296,6 +297,8 @@ def run(
             unlabelled_loss.update(ul)
             logger.debug(f"batch loss: {loss.item()}")
+
+        if PYTORCH_GE_110: scheduler.step()
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
...
@@ -5,6 +5,7 @@ import os
 import csv
 import time
 import datetime
+import distutils.version
 import torch
 import pandas
@@ -16,6 +17,8 @@ from bob.ip.binseg.utils.plot import loss_curve
 import logging
 logger = logging.getLogger(__name__)
+
+PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")

 def run(
     model,
@@ -69,7 +72,6 @@ def run(
         output path
     """
-    logger.info("Start training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
@@ -108,7 +110,7 @@ def run(
     start_training_time = time.time()
     for epoch in range(start_epoch, max_epoch):
-        scheduler.step()
+        if not PYTORCH_GE_110: scheduler.step()
         losses = SmoothedValue(len(data_loader))
         epoch = epoch + 1
         arguments["epoch"] = epoch
@@ -139,6 +141,8 @@ def run(
             losses.update(loss)
             logger.debug(f"batch loss: {loss.item()}")
+
+        if PYTORCH_GE_110: scheduler.step()
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
...
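
On the logging side, the commit drops the fixed "Start ..."/"End ..." banners but keeps the per-batch values at DEBUG level, so verbosity stays opt-in through the standard logging module. A small sketch of how a caller could still surface those messages — the module logger name here is an assumption inferred from the commit title, not confirmed by the diff:

import logging

# Re-enable the per-batch messages the trainers emit via logger.debug():
logging.basicConfig(level=logging.DEBUG)

# Or narrow it to one module's logger (name is an assumption):
logging.getLogger("bob.ip.binseg.engine.trainer").setLevel(logging.DEBUG)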