Skip to content
Snippets Groups Projects
Commit aaab33de authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.*trainer] Optimize validation during training with torch.no_grad() and model.eval()

parent 1fe18393
No related branches found
No related tags found
No related merge requests found
...@@ -16,12 +16,12 @@ from tqdm import tqdm ...@@ -16,12 +16,12 @@ from tqdm import tqdm
from ..utils.measure import SmoothedValue from ..utils.measure import SmoothedValue
from ..utils.summary import summary from ..utils.summary import summary
from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
from .trainer import PYTORCH_GE_110, torch_evaluation
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
def sharpen(x, T): def sharpen(x, T):
...@@ -371,31 +371,34 @@ def run( ...@@ -371,31 +371,34 @@ def run(
# calculates the validation loss if necessary # calculates the validation loss if necessary
valid_losses = None valid_losses = None
if valid_loader is not None: if valid_loader is not None:
valid_losses = SmoothedValue(len(valid_loader))
for samples in tqdm( with torch.no_grad(), torch_evaluation(model):
valid_loader, desc="valid", leave=False, disable=None
): valid_losses = SmoothedValue(len(valid_loader))
for samples in tqdm(
# labelled valid_loader, desc="valid", leave=False, disable=None
images = samples[1].to(device) ):
ground_truths = samples[2].to(device)
unlabelled_images = samples[4].to(device) # labelled
# labelled outputs images = samples[1].to(device)
outputs = model(images) ground_truths = samples[2].to(device)
unlabelled_outputs = model(unlabelled_images) unlabelled_images = samples[4].to(device)
# guessed unlabelled outputs # labelled outputs
unlabelled_ground_truths = guess_labels( outputs = model(images)
unlabelled_images, model unlabelled_outputs = model(unlabelled_images)
) # guessed unlabelled outputs
loss, ll, ul = criterion( unlabelled_ground_truths = guess_labels(
outputs, unlabelled_images, model
ground_truths, )
unlabelled_outputs, loss, ll, ul = criterion(
unlabelled_ground_truths, outputs,
ramp_up_factor, ground_truths,
) unlabelled_outputs,
unlabelled_ground_truths,
valid_losses.update(loss) ramp_up_factor,
)
valid_losses.update(loss)
if checkpoint_period and (epoch % checkpoint_period == 0): if checkpoint_period and (epoch % checkpoint_period == 0):
checkpointer.save(f"model_{epoch:03d}", **arguments) checkpointer.save(f"model_{epoch:03d}", **arguments)
......
...@@ -7,6 +7,7 @@ import csv ...@@ -7,6 +7,7 @@ import csv
import time import time
import shutil import shutil
import datetime import datetime
import contextlib
import distutils.version import distutils.version
import torch import torch
...@@ -23,6 +24,34 @@ logger = logging.getLogger(__name__) ...@@ -23,6 +24,34 @@ logger = logging.getLogger(__name__)
PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0" PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
@contextlib.contextmanager
def torch_evaluation(model):
"""Context manager to turn ON/OFF model evaluation
This context manager will turn evaluation mode ON on entry and turn it OFF
when exiting the ``with`` statement block.
Parameters
----------
model : :py:class:`torch.nn.Module`
Network (e.g. driu, hed, unet)
Yields
------
model : :py:class:`torch.nn.Module`
Network (e.g. driu, hed, unet)
"""
model.eval()
yield model
model.train()
def run( def run(
model, model,
data_loader, data_loader,
...@@ -203,21 +232,24 @@ def run( ...@@ -203,21 +232,24 @@ def run(
# calculates the validation loss if necessary # calculates the validation loss if necessary
valid_losses = None valid_losses = None
if valid_loader is not None: if valid_loader is not None:
valid_losses = SmoothedValue(len(valid_loader))
for samples in tqdm( with torch.no_grad(), torch_evaluation(model):
valid_loader, desc="valid", leave=False, disable=None
): valid_losses = SmoothedValue(len(valid_loader))
# data forwarding on the existing network for samples in tqdm(
images = samples[1].to(device) valid_loader, desc="valid", leave=False, disable=None
ground_truths = samples[2].to(device) ):
masks = None # data forwarding on the existing network
if len(samples) == 4: images = samples[1].to(device)
masks = samples[-1].to(device) ground_truths = samples[2].to(device)
masks = None
outputs = model(images) if len(samples) == 4:
masks = samples[-1].to(device)
loss = criterion(outputs, ground_truths, masks)
valid_losses.update(loss) outputs = model(images)
loss = criterion(outputs, ground_truths, masks)
valid_losses.update(loss)
if checkpoint_period and (epoch % checkpoint_period == 0): if checkpoint_period and (epoch % checkpoint_period == 0):
checkpointer.save(f"model_{epoch:03d}", **arguments) checkpointer.save(f"model_{epoch:03d}", **arguments)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment