Commit aaab33de authored by André Anjos

[engine.*trainer] Optimize validation during training with torch.no_grad() and model.eval()

parent 1fe18393
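This change wraps each validation loop in ``torch.no_grad()`` together with a new ``torch_evaluation`` context manager, so that no autograd graph is built and layers such as dropout and batch normalization run in inference mode while the validation loss is computed. As a minimal, self-contained sketch of that pattern (illustrative only, not part of the diff below; the ``validate`` helper and its arguments are placeholder names):

import torch
import torch.nn as nn


def validate(model: nn.Module, loader, criterion, device="cpu"):
    """Runs one validation pass without recording gradients."""
    was_training = model.training
    model.eval()  # dropout off, batch-norm uses running statistics
    total, batches = 0.0, 0
    with torch.no_grad():  # no autograd graph -> less memory, faster
        for images, targets in loader:
            outputs = model(images.to(device))
            total += criterion(outputs, targets.to(device)).item()
            batches += 1
    if was_training:
        model.train()  # restore the previous mode for the next epoch
    return total / max(batches, 1)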
@@ -16,12 +16,12 @@ from tqdm import tqdm
 from ..utils.measure import SmoothedValue
 from ..utils.summary import summary
 from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
+from .trainer import PYTORCH_GE_110, torch_evaluation
 
 import logging
 logger = logging.getLogger(__name__)
 
-PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
 
 def sharpen(x, T):
@@ -371,31 +371,34 @@ def run(
         # calculates the validation loss if necessary
         valid_losses = None
         if valid_loader is not None:
-            valid_losses = SmoothedValue(len(valid_loader))
-            for samples in tqdm(
-                valid_loader, desc="valid", leave=False, disable=None
-            ):
-                # labelled
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
-                unlabelled_images = samples[4].to(device)
-                # labelled outputs
-                outputs = model(images)
-                unlabelled_outputs = model(unlabelled_images)
-                # guessed unlabelled outputs
-                unlabelled_ground_truths = guess_labels(
-                    unlabelled_images, model
-                )
-                loss, ll, ul = criterion(
-                    outputs,
-                    ground_truths,
-                    unlabelled_outputs,
-                    unlabelled_ground_truths,
-                    ramp_up_factor,
-                )
-                valid_losses.update(loss)
+            with torch.no_grad(), torch_evaluation(model):
+                valid_losses = SmoothedValue(len(valid_loader))
+                for samples in tqdm(
+                    valid_loader, desc="valid", leave=False, disable=None
+                ):
+                    # labelled
+                    images = samples[1].to(device)
+                    ground_truths = samples[2].to(device)
+                    unlabelled_images = samples[4].to(device)
+                    # labelled outputs
+                    outputs = model(images)
+                    unlabelled_outputs = model(unlabelled_images)
+                    # guessed unlabelled outputs
+                    unlabelled_ground_truths = guess_labels(
+                        unlabelled_images, model
+                    )
+                    loss, ll, ul = criterion(
+                        outputs,
+                        ground_truths,
+                        unlabelled_outputs,
+                        unlabelled_ground_truths,
+                        ramp_up_factor,
+                    )
+                    valid_losses.update(loss)
 
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
@@ -7,6 +7,7 @@ import csv
 import time
 import shutil
 import datetime
+import contextlib
 import distutils.version
 import torch
@@ -23,6 +24,34 @@ logger = logging.getLogger(__name__)
 PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
 
+
+@contextlib.contextmanager
+def torch_evaluation(model):
+    """Context manager to turn ON/OFF model evaluation
+
+    This context manager will turn evaluation mode ON on entry and turn it OFF
+    when exiting the ``with`` statement block.
+
+
+    Parameters
+    ----------
+
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. driu, hed, unet)
+
+
+    Yields
+    ------
+
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. driu, hed, unet)
+
+    """
+
+    model.eval()
+    yield model
+    model.train()
+
+
 def run(
     model,
     data_loader,
@@ -203,21 +232,24 @@ def run(
         # calculates the validation loss if necessary
         valid_losses = None
         if valid_loader is not None:
-            valid_losses = SmoothedValue(len(valid_loader))
-            for samples in tqdm(
-                valid_loader, desc="valid", leave=False, disable=None
-            ):
-                # data forwarding on the existing network
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
-                masks = None
-                if len(samples) == 4:
-                    masks = samples[-1].to(device)
-                outputs = model(images)
-                loss = criterion(outputs, ground_truths, masks)
-                valid_losses.update(loss)
+            with torch.no_grad(), torch_evaluation(model):
+                valid_losses = SmoothedValue(len(valid_loader))
+                for samples in tqdm(
+                    valid_loader, desc="valid", leave=False, disable=None
+                ):
+                    # data forwarding on the existing network
+                    images = samples[1].to(device)
+                    ground_truths = samples[2].to(device)
+                    masks = None
+                    if len(samples) == 4:
+                        masks = samples[-1].to(device)
+                    outputs = model(images)
+                    loss = criterion(outputs, ground_truths, masks)
+                    valid_losses.update(loss)
 
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
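For reference, the new ``torch_evaluation`` context manager introduced above can be exercised on its own as in the sketch below (the toy two-layer network is a placeholder, not from the repository). Note that ``model.train()`` only runs when the ``yield`` returns normally; if restoring training mode after exceptions mattered, the ``yield`` could be wrapped in ``try``/``finally``.

import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def torch_evaluation(model):
    # same behaviour as the context manager added in this commit
    model.eval()
    yield model
    model.train()


model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(0.5))  # placeholder network
model.train()
with torch.no_grad(), torch_evaluation(model):
    assert not model.training          # evaluation mode inside the block
    _ = model(torch.zeros(1, 4))       # forward pass builds no autograd graph
assert model.training                  # training mode restored on exit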