Skip to content
Snippets Groups Projects
Commit f4ceb41b authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.trainer] Implement per-epoch possible validation with checkpointing

parent 31ffe6f1
Branches
Tags
No related merge requests found
Pipeline #39827 failed
......@@ -141,7 +141,11 @@ def make_dataset(subsets, transforms):
A dictionary that contains the delayed sample lists for a number of
named lists. If one of the keys is ``train``, our standard dataset
augmentation transforms are appended to the definition of that subset.
All other subsets remain un-augmented.
All other subsets remain un-augmented. If one of the keys is
``valid``, then this dataset will be also copied to the ``__valid__``
hidden dataset and will be used for validation during training.
Otherwise, if no ``valid`` subset is available, we set ``__valid__`` to
be the same as the unaugmented ``train`` subset, if one is available.
transforms : list
A list of transforms that needs to be applied to all samples in the set
......@@ -166,5 +170,14 @@ def make_dataset(subsets, transforms):
transforms=transforms,
suffixes=(RANDOM_ROTATION + RANDOM_FLIP_JITTER),
)
if key == "valid":
# also use it for validation during training
retval["__valid__"] = retval[key]
if ("__train__" in retval) and ("train" in retval) \
and ("__valid__" not in retval):
# if the dataset does not have a validation set, we use the unaugmented
# training set as validation set
retval["__valid__"] = retval["train"]
return retval
......@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
import os
import sys
import csv
import time
import shutil
......@@ -167,6 +168,7 @@ def guess_labels(unlabelled_images, model):
def run(
model,
data_loader,
valid_loader,
optimizer,
criterion,
scheduler,
......@@ -192,6 +194,11 @@ def run(
Network (e.g. driu, hed, unet)
data_loader : :py:class:`torch.utils.data.DataLoader`
To be used to train the model
valid_loader : :py:class:`torch.utils.data.DataLoader`
To be used to validate the model and enable automatic checkpointing.
If set to ``None``, then do not validate it.
optimizer : :py:mod:`torch.optim`
......@@ -277,10 +284,16 @@ def run(
"median_unlabelled_loss",
"learning_rate",
)
if valid_loader is not None:
logfile_fields += ("validation_average_loss", "validation_median_loss")
logfile_fields += tuple([k[0] for k in cpu_log()])
if device != "cpu":
logfile_fields += tuple([k[0] for k in gpu_log()])
# the lowest validation loss obtained so far - this value is updated only
# if a validation set is available
lowest_validation_loss = sys.float_info.max
with open(logfile_name, "a+", newline="") as logfile:
logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
......@@ -313,6 +326,7 @@ def run(
# Epoch time
start_epoch_time = time.time()
# progress bar only on interactive jobs
for samples in tqdm(
data_loader, desc="batch", leave=False, disable=None
):
......@@ -330,8 +344,6 @@ def run(
unlabelled_ground_truths = guess_labels(
unlabelled_images, model
)
# unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
# images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)
# loss evaluation and learning (backward step)
ramp_up_factor = square_rampup(
......@@ -356,9 +368,49 @@ def run(
if PYTORCH_GE_110:
scheduler.step()
# calculates the validation loss if necessary
valid_losses = None
if valid_loader is not None:
valid_losses = SmoothedValue(len(valid_loader))
for samples in tqdm(
valid_loader, desc="valid", leave=False, disable=None
):
# labelled
images = samples[1].to(device)
ground_truths = samples[2].to(device)
unlabelled_images = samples[4].to(device)
# labelled outputs
outputs = model(images)
unlabelled_outputs = model(unlabelled_images)
# guessed unlabelled outputs
unlabelled_ground_truths = guess_labels(
unlabelled_images, model
)
loss, ll, ul = criterion(
outputs,
ground_truths,
unlabelled_outputs,
unlabelled_ground_truths,
ramp_up_factor,
)
valid_losses.update(loss)
if checkpoint_period and (epoch % checkpoint_period == 0):
checkpointer.save(f"model_{epoch:03d}", **arguments)
if (
valid_losses is not None
and valid_losses.avg < lowest_validation_loss
):
lowest_validation_loss = valid_losses.avg
logger.info(
f"Found new low on validation set:"
f" {lowest_validation_loss:.6f}"
)
checkpointer.save(f"model_lowest_valid_loss", **arguments)
if epoch >= max_epoch:
checkpointer.save("model_final", **arguments)
......@@ -380,7 +432,16 @@ def run(
("median_labelled_loss", f"{labelled_loss.median:.6f}"),
("median_unlabelled_loss", f"{unlabelled_loss.median:.6f}"),
("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
) + cpu_log()
)
if valid_losses is not None:
logdata += (
("validation_average_loss", f"{valid_losses.avg:.6f}"),
("validation_median_loss", f"{valid_losses.median:.6f}"),
)
logdata += cpu_log()
if device != "cpu":
logdata += gpu_log()
if device != "cpu":
logdata += gpu_log()
......
......@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
import os
import sys
import csv
import time
import shutil
......@@ -25,6 +26,7 @@ PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
def run(
model,
data_loader,
valid_loader,
optimizer,
criterion,
scheduler,
......@@ -48,6 +50,11 @@ def run(
Network (e.g. driu, hed, unet)
data_loader : :py:class:`torch.utils.data.DataLoader`
To be used to train the model
valid_loader : :py:class:`torch.utils.data.DataLoader`
To be used to validate the model and enable automatic checkpointing.
If set to ``None``, then do not validate it.
optimizer : :py:mod:`torch.optim`
......@@ -127,10 +134,16 @@ def run(
"median_loss",
"learning_rate",
)
if valid_loader is not None:
logfile_fields += ("validation_average_loss", "validation_median_loss")
logfile_fields += tuple([k[0] for k in cpu_log()])
if device != "cpu":
logfile_fields += tuple([k[0] for k in gpu_log()])
# the lowest validation loss obtained so far - this value is updated only
# if a validation set is available
lowest_validation_loss = sys.float_info.max
with open(logfile_name, "a+", newline="") as logfile:
logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
......@@ -187,9 +200,39 @@ def run(
if PYTORCH_GE_110:
scheduler.step()
# calculates the validation loss if necessary
valid_losses = None
if valid_loader is not None:
valid_losses = SmoothedValue(len(valid_loader))
for samples in tqdm(
valid_loader, desc="valid", leave=False, disable=None
):
# data forwarding on the existing network
images = samples[1].to(device)
ground_truths = samples[2].to(device)
masks = None
if len(samples) == 4:
masks = samples[-1].to(device)
outputs = model(images)
loss = criterion(outputs, ground_truths, masks)
valid_losses.update(loss)
if checkpoint_period and (epoch % checkpoint_period == 0):
checkpointer.save(f"model_{epoch:03d}", **arguments)
if (
valid_losses is not None
and valid_losses.avg < lowest_validation_loss
):
lowest_validation_loss = valid_losses.avg
logger.info(
f"Found new low on validation set:"
f" {lowest_validation_loss:.6f}"
)
checkpointer.save(f"model_lowest_valid_loss", **arguments)
if epoch >= max_epoch:
checkpointer.save("model_final", **arguments)
......@@ -209,7 +252,13 @@ def run(
("average_loss", f"{losses.avg:.6f}"),
("median_loss", f"{losses.median:.6f}"),
("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
) + cpu_log()
)
if valid_losses is not None:
logdata += (
("validation_average_loss", f"{valid_losses.avg:.6f}"),
("validation_median_loss", f"{valid_losses.median:.6f}"),
)
logdata += cpu_log()
if device != "cpu":
logdata += gpu_log()
......
......@@ -274,6 +274,10 @@ def experiment(
* ``__train__``: dataset used for training, prioritarily. It is typically
the dataset containing data augmentation pipelines.
* ``__valid__``: dataset used for validation. It is typically disjoint
from the training and test sets. In such a case, we checkpoint the model
with the lowest loss on the validation set as well, throughout all the
training, besides the model at the end of training.
* ``train`` (optional): a copy of the ``__train__`` dataset, without data
augmentation, that will be evaluated alongside other sets available
* ``*``: any other name, not starting with an underscore character (``_``),
......
......@@ -71,7 +71,9 @@ logger = logging.getLogger(__name__)
"training the network model. The dataset description must include all "
"required pre-processing, including eventual data augmentation. If a "
"dataset named ``__train__`` is available, it is used prioritarily for "
"training instead of ``train``.",
"training instead of ``train``. If a dataset named ``__valid__`` is "
"available, it is used for model validation (and automatic check-pointing) "
"at each epoch.",
required=True,
cls=ResourceOption,
)
......@@ -227,6 +229,7 @@ def train(
torch.manual_seed(seed)
use_dataset = dataset
validation_dataset = None
if isinstance(dataset, dict):
if "__train__" in dataset:
logger.info("Found (dedicated) '__train__' set for training")
......@@ -234,6 +237,11 @@ def train(
else:
use_dataset = dataset["train"]
if "__valid__" in dataset:
logger.info("Found (dedicated) '__valid__' set for validation")
logger.info("Will checkpoint lowest loss model on validation set")
validation_dataset = dataset["__valid__"]
# PyTorch dataloader
data_loader = DataLoader(
dataset=use_dataset,
......@@ -243,6 +251,16 @@ def train(
pin_memory=torch.cuda.is_available(),
)
valid_loader = None
if validation_dataset is not None:
valid_loader = DataLoader(
dataset=validation_dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
pin_memory=torch.cuda.is_available(),
)
# Checkpointer
checkpointer = DetectronCheckpointer(
model, optimizer, scheduler, save_dir=output_folder, save_to_disk=True
......@@ -262,6 +280,7 @@ def train(
run(
model,
data_loader,
valid_loader,
optimizer,
criterion,
scheduler,
......@@ -277,6 +296,7 @@ def train(
run(
model,
data_loader,
valid_loader,
optimizer,
criterion,
scheduler,
......
......@@ -88,14 +88,14 @@ def _check_experiment_stare(overlay):
output_folder = "results"
options = [
"m2unet",
config.name,
"-vv",
"--epochs=1",
"--batch-size=1",
"--steps=10",
f"--output-folder={output_folder}",
]
"m2unet",
config.name,
"-vv",
"--epochs=1",
"--batch-size=1",
"--steps=10",
f"--output-folder={output_folder}",
]
if overlay:
options += ["--overlayed"]
result = runner.invoke(experiment, options)
......@@ -107,6 +107,9 @@ def _check_experiment_stare(overlay):
# check model was saved
train_folder = os.path.join(output_folder, "model")
assert os.path.exists(os.path.join(train_folder, "model_final.pth"))
assert os.path.exists(
os.path.join(train_folder, "model_lowest_valid_loss.pth")
)
assert os.path.exists(os.path.join(train_folder, "last_checkpoint"))
assert os.path.exists(os.path.join(train_folder, "constants.csv"))
assert os.path.exists(os.path.join(train_folder, "trainlog.csv"))
......@@ -123,7 +126,9 @@ def _check_experiment_stare(overlay):
if overlay:
# check overlayed images are there (since we requested them)
assert os.path.exists(basedir)
nose.tools.eq_(len(fnmatch.filter(os.listdir(basedir), "*.png")), 20)
nose.tools.eq_(
len(fnmatch.filter(os.listdir(basedir), "*.png")), 20
)
else:
assert not os.path.exists(basedir)
......@@ -135,7 +140,7 @@ def _check_experiment_stare(overlay):
os.path.join(eval_folder, "second-annotator", "train.csv")
)
assert os.path.exists(
os.path.join(eval_folder, "second-annotator" , "test.csv")
os.path.join(eval_folder, "second-annotator", "test.csv")
)
overlay_folder = os.path.join(output_folder, "overlayed", "analysis")
......@@ -143,18 +148,23 @@ def _check_experiment_stare(overlay):
if overlay:
# check overlayed images are there (since we requested them)
assert os.path.exists(basedir)
nose.tools.eq_(len(fnmatch.filter(os.listdir(basedir), "*.png")), 20)
nose.tools.eq_(
len(fnmatch.filter(os.listdir(basedir), "*.png")), 20
)
else:
assert not os.path.exists(basedir)
# check overlayed images from first-to-second annotator comparisons
# are there (since we requested them)
overlay_folder = os.path.join(output_folder, "overlayed", "analysis",
"second-annotator")
overlay_folder = os.path.join(
output_folder, "overlayed", "analysis", "second-annotator"
)
basedir = os.path.join(overlay_folder, "stare-images")
if overlay:
assert os.path.exists(basedir)
nose.tools.eq_(len(fnmatch.filter(os.listdir(basedir), "*.png")), 20)
nose.tools.eq_(
len(fnmatch.filter(os.listdir(basedir), "*.png")), 20
)
else:
assert not os.path.exists(basedir)
......@@ -165,10 +175,13 @@ def _check_experiment_stare(overlay):
keywords = {
r"^Started training$": 1,
r"^Found \(dedicated\) '__train__' set for training$": 1,
r"^Found \(dedicated\) '__valid__' set for validation$": 1,
r"^Will checkpoint lowest loss model on validation set$": 1,
r"^Continuing from epoch 0$": 1,
r"^Saving model summary at.*$": 1,
r"^Model has.*$": 1,
r"^Saving checkpoint": 1,
r"^Found new low on validation set.*$": 1,
r"^Saving checkpoint": 2,
r"^Ended training$": 1,
r"^Started prediction$": 1,
r"^Loading checkpoint from": 2,
......@@ -223,7 +236,7 @@ def _check_train(runner):
config.write(
"from bob.ip.binseg.configs.datasets.stare import _maker\n"
)
config.write("dataset = _maker('ah', _raw)['train']\n")
config.write("dataset = _maker('ah', _raw)\n")
config.flush()
output_folder = "results"
......@@ -241,15 +254,21 @@ def _check_train(runner):
_assert_exit_0(result)
assert os.path.exists(os.path.join(output_folder, "model_final.pth"))
assert os.path.exists(
os.path.join(output_folder, "model_lowest_valid_loss.pth")
)
assert os.path.exists(os.path.join(output_folder, "last_checkpoint"))
assert os.path.exists(os.path.join(output_folder, "constants.csv"))
assert os.path.exists(os.path.join(output_folder, "trainlog.csv"))
assert os.path.exists(os.path.join(output_folder, "model_summary.txt"))
keywords = {
r"^Found \(dedicated\) '__train__' set for training$": 1,
r"^Found \(dedicated\) '__valid__' set for validation$": 1,
r"^Continuing from epoch 0$": 1,
r"^Saving model summary at.*$": 1,
r"^Model has.*$": 1,
rf"^Saving checkpoint to {output_folder}/model_lowest_valid_loss.pth$": 1,
rf"^Saving checkpoint to {output_folder}/model_final.pth$": 1,
r"^Total training time:": 1,
}
......@@ -364,8 +383,9 @@ def _check_evaluate(runner):
_assert_exit_0(result)
assert os.path.exists(os.path.join(output_folder, "test.csv"))
assert os.path.exists(os.path.join(output_folder,
"second-annotator", "test.csv"))
assert os.path.exists(
os.path.join(output_folder, "second-annotator", "test.csv")
)
# check overlayed images are there (since we requested them)
basedir = os.path.join(overlay_folder, "stare-images")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment