From f4ceb41bc4d694ea08ae07d066f5a96eb4c4bfcd Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 12 May 2020 22:25:29 +0200
Subject: [PATCH] [engine.trainer] Implement optional per-epoch validation
 with checkpointing

---
 bob/ip/binseg/configs/datasets/__init__.py | 15 ++++-
 bob/ip/binseg/engine/ssltrainer.py         | 64 ++++++++++++++++++++-
 bob/ip/binseg/engine/trainer.py            | 51 +++++++++++++++-
 bob/ip/binseg/script/experiment.py         |  4 ++
 bob/ip/binseg/script/train.py              | 22 ++++++-
 bob/ip/binseg/test/test_cli.py             | 56 ++++++++++++------
 6 files changed, 188 insertions(+), 24 deletions(-)

diff --git a/bob/ip/binseg/configs/datasets/__init__.py b/bob/ip/binseg/configs/datasets/__init__.py
index f2621a72..4cdf1973 100644
--- a/bob/ip/binseg/configs/datasets/__init__.py
+++ b/bob/ip/binseg/configs/datasets/__init__.py
@@ -141,8 +141,12 @@ def make_dataset(subsets, transforms):
         A dictionary that contains the delayed sample lists for a number of
         named lists.  If one of the keys is ``train``, our standard dataset
         augmentation transforms are appended to the definition of that subset.
-        All other subsets remain un-augmented.
+        All other subsets remain un-augmented.  If one of the keys is
+        ``valid``, this dataset will also be copied to the ``__valid__``
+        hidden dataset and used for validation during training.  Otherwise,
+        if no ``valid`` subset is available, we set ``__valid__`` to the
+        unaugmented ``train`` subset, if one is available.
 
     transforms : list
         A list of transforms that needs to be applied to all samples in the
         set
@@ -166,5 +170,14 @@ def make_dataset(subsets, transforms):
             transforms=transforms,
             suffixes=(RANDOM_ROTATION + RANDOM_FLIP_JITTER),
         )
+        if key == "valid":
+            # also use it for validation during training
+            retval["__valid__"] = retval[key]
+
+    if ("__train__" in retval) and ("train" in retval) \
+            and ("__valid__" not in retval):
+        # if the dataset does not have a validation set, we use the
+        # unaugmented training set as validation set
+        retval["__valid__"] = retval["train"]
 
     return retval
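
[Note: illustration only, not part of the patch.  The ``__valid__`` fallback
above can be exercised in isolation.  A minimal sketch, where
`attach_validation` is a hypothetical helper name and plain lists stand in
for the delayed sample lists built by `make_dataset`:

    def attach_validation(retval):
        if "valid" in retval:
            # a dedicated "valid" subset doubles as the validation set
            retval["__valid__"] = retval["valid"]
        elif ("__train__" in retval) and ("train" in retval):
            # otherwise, fall back to the unaugmented training subset
            retval["__valid__"] = retval["train"]
        return retval

    subsets = {"__train__": ["augmented"], "train": ["plain"], "test": ["x"]}
    assert attach_validation(subsets)["__valid__"] == ["plain"]
]
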
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 9db427a7..6905a645 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import os
+import sys
 import csv
 import time
 import shutil
@@ -167,6 +168,7 @@ def guess_labels(unlabelled_images, model):
 def run(
     model,
     data_loader,
+    valid_loader,
     optimizer,
     criterion,
     scheduler,
@@ -192,6 +194,11 @@ def run(
         Network (e.g. driu, hed, unet)
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to train the model
+
+    valid_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to validate the model and enable automatic checkpointing.
+        If set to ``None``, validation is skipped.
 
     optimizer : :py:mod:`torch.optim`
 
@@ -277,10 +284,16 @@ def run(
         "median_unlabelled_loss",
         "learning_rate",
     )
+    if valid_loader is not None:
+        logfile_fields += ("validation_average_loss", "validation_median_loss")
    logfile_fields += tuple([k[0] for k in cpu_log()])
    if device != "cpu":
        logfile_fields += tuple([k[0] for k in gpu_log()])
 
+    # the lowest validation loss obtained so far - this value is updated only
+    # if a validation set is available
+    lowest_validation_loss = sys.float_info.max
+
    with open(logfile_name, "a+", newline="") as logfile:
 
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
@@ -313,6 +326,7 @@ def run(
             # Epoch time
             start_epoch_time = time.time()
 
+            # progress bar only on interactive jobs
             for samples in tqdm(
                 data_loader, desc="batch", leave=False, disable=None
             ):
@@ -330,8 +344,6 @@ def run(
                 unlabelled_ground_truths = guess_labels(
                     unlabelled_images, model
                 )
-                # unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
-                # images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)
 
                 # loss evaluation and learning (backward step)
                 ramp_up_factor = square_rampup(
@@ -356,9 +368,49 @@ def run(
             if PYTORCH_GE_110:
                 scheduler.step()
 
+            # calculates the validation loss if necessary
+            valid_losses = None
+            if valid_loader is not None:
+                valid_losses = SmoothedValue(len(valid_loader))
+                for samples in tqdm(
+                    valid_loader, desc="valid", leave=False, disable=None
+                ):
+
+                    # labelled
+                    images = samples[1].to(device)
+                    ground_truths = samples[2].to(device)
+                    unlabelled_images = samples[4].to(device)
+                    # labelled outputs
+                    outputs = model(images)
+                    unlabelled_outputs = model(unlabelled_images)
+                    # guessed unlabelled outputs
+                    unlabelled_ground_truths = guess_labels(
+                        unlabelled_images, model
+                    )
+                    loss, ll, ul = criterion(
+                        outputs,
+                        ground_truths,
+                        unlabelled_outputs,
+                        unlabelled_ground_truths,
+                        ramp_up_factor,
+                    )
+
+                    valid_losses.update(loss)
+
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
 
+            if (
+                valid_losses is not None
+                and valid_losses.avg < lowest_validation_loss
+            ):
+                lowest_validation_loss = valid_losses.avg
+                logger.info(
+                    f"Found new low on validation set:"
+                    f" {lowest_validation_loss:.6f}"
+                )
+                checkpointer.save("model_lowest_valid_loss", **arguments)
+
             if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)
 
@@ -380,7 +432,13 @@ def run(
                 ("median_labelled_loss", f"{labelled_loss.median:.6f}"),
                 ("median_unlabelled_loss", f"{unlabelled_loss.median:.6f}"),
                 ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
-            ) + cpu_log()
+            )
+            if valid_losses is not None:
+                logdata += (
+                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
+                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
+                )
+            logdata += cpu_log()
 
             if device != "cpu":
                 logdata += gpu_log()
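
[Note: illustration only, not part of the patch.  Both trainers apply the
same "best so far" rule: the "model_lowest_valid_loss" checkpoint is
rewritten whenever an epoch's average validation loss is strictly lower
than anything seen before.  A minimal sketch with made-up loss values:

    import sys

    lowest = sys.float_info.max  # same initial value as in the patch
    for epoch, valid_loss in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
        if valid_loss < lowest:  # strictly better than all previous epochs
            lowest = valid_loss
            print(f"epoch {epoch}: new low {lowest:.6f} -> save checkpoint")
    # prints at epochs 1, 2 and 4; epoch 3 does not improve and is skipped
]
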
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 81b191dd..7d1f841b 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import os
+import sys
 import csv
 import time
 import shutil
@@ -25,6 +26,7 @@ PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
 def run(
     model,
     data_loader,
+    valid_loader,
     optimizer,
     criterion,
     scheduler,
@@ -48,6 +50,11 @@ def run(
         Network (e.g. driu, hed, unet)
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to train the model
+
+    valid_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to validate the model and enable automatic checkpointing.
+        If set to ``None``, validation is skipped.
 
     optimizer : :py:mod:`torch.optim`
 
@@ -127,10 +134,16 @@ def run(
         "median_loss",
         "learning_rate",
     )
+    if valid_loader is not None:
+        logfile_fields += ("validation_average_loss", "validation_median_loss")
    logfile_fields += tuple([k[0] for k in cpu_log()])
    if device != "cpu":
        logfile_fields += tuple([k[0] for k in gpu_log()])
 
+    # the lowest validation loss obtained so far - this value is updated only
+    # if a validation set is available
+    lowest_validation_loss = sys.float_info.max
+
    with open(logfile_name, "a+", newline="") as logfile:
 
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
@@ -187,9 +200,39 @@ def run(
             if PYTORCH_GE_110:
                 scheduler.step()
 
+            # calculates the validation loss if necessary
+            valid_losses = None
+            if valid_loader is not None:
+                valid_losses = SmoothedValue(len(valid_loader))
+                for samples in tqdm(
+                    valid_loader, desc="valid", leave=False, disable=None
+                ):
+                    # data forwarding on the existing network
+                    images = samples[1].to(device)
+                    ground_truths = samples[2].to(device)
+                    masks = None
+                    if len(samples) == 4:
+                        masks = samples[-1].to(device)
+
+                    outputs = model(images)
+
+                    loss = criterion(outputs, ground_truths, masks)
+                    valid_losses.update(loss)
+
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
 
+            if (
+                valid_losses is not None
+                and valid_losses.avg < lowest_validation_loss
+            ):
+                lowest_validation_loss = valid_losses.avg
+                logger.info(
+                    f"Found new low on validation set:"
+                    f" {lowest_validation_loss:.6f}"
+                )
+                checkpointer.save("model_lowest_valid_loss", **arguments)
+
             if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)
 
@@ -209,7 +252,13 @@ def run(
                 ("average_loss", f"{losses.avg:.6f}"),
                 ("median_loss", f"{losses.median:.6f}"),
                 ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
-            ) + cpu_log()
+            )
+            if valid_losses is not None:
+                logdata += (
+                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
+                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
+                )
+            logdata += cpu_log()
 
             if device != "cpu":
                 logdata += gpu_log()
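
[Note: illustration only, not part of the patch.  The validation pass above
reuses the training forward path as-is.  If one wanted to avoid building
gradient buffers during validation, the loop could be wrapped in
torch.no_grad() with the model switched to evaluation mode.  The helper
below is a hypothetical sketch that simplifies samples to
(name, image, ground-truth) tuples:

    import torch

    def validation_loss(model, valid_loader, criterion, device="cpu"):
        """Averages the criterion over one pass of the validation set."""
        was_training = model.training
        model.eval()  # freeze dropout and batch-norm statistics
        total = 0.0
        with torch.no_grad():  # no backward pass happens during validation
            for samples in valid_loader:
                images = samples[1].to(device)
                ground_truths = samples[2].to(device)
                outputs = model(images)
                total += criterion(outputs, ground_truths, None).item()
        if was_training:
            model.train()  # restore the caller's mode
        return total / max(len(valid_loader), 1)
]
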
diff --git a/bob/ip/binseg/script/experiment.py b/bob/ip/binseg/script/experiment.py
index 9afa6f5b..30d14e2c 100644
--- a/bob/ip/binseg/script/experiment.py
+++ b/bob/ip/binseg/script/experiment.py
@@ -274,6 +274,10 @@ def experiment(
 
     * ``__train__``: dataset used for training, prioritarily.  It is typically
       the dataset containing data augmentation pipelines.
+    * ``__valid__``: dataset used for validation.  It is typically disjoint
+      from the training and test sets.  In such a case, we also checkpoint
+      the model with the lowest validation loss observed during training,
+      besides the model saved at the end of training.
     * ``train`` (optional): a copy of the ``__train__`` dataset, without data
       augmentation, that will be evaluated alongside other sets available
     * ``*``: any other name, not starting with an underscore character (``_``),
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index 3076aae4..da56a3bd 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -71,7 +71,9 @@ logger = logging.getLogger(__name__)
     "training the network model.  The dataset description must include all "
     "required pre-processing, including eventual data augmentation.  If a "
     "dataset named ``__train__`` is available, it is used prioritarily for "
-    "training instead of ``train``.",
+    "training instead of ``train``.  If a dataset named ``__valid__`` is "
+    "available, it is used for model validation (and automatic checkpointing) "
+    "at each epoch.",
     required=True,
     cls=ResourceOption,
 )
@@ -227,6 +229,7 @@ def train(
     torch.manual_seed(seed)
 
     use_dataset = dataset
+    validation_dataset = None
     if isinstance(dataset, dict):
         if "__train__" in dataset:
             logger.info("Found (dedicated) '__train__' set for training")
@@ -234,6 +237,11 @@ def train(
         else:
             use_dataset = dataset["train"]
 
+        if "__valid__" in dataset:
+            logger.info("Found (dedicated) '__valid__' set for validation")
+            logger.info("Will checkpoint lowest loss model on validation set")
+            validation_dataset = dataset["__valid__"]
+
     # PyTorch dataloader
     data_loader = DataLoader(
         dataset=use_dataset,
@@ -243,6 +251,16 @@ def train(
         pin_memory=torch.cuda.is_available(),
     )
 
+    valid_loader = None
+    if validation_dataset is not None:
+        valid_loader = DataLoader(
+            dataset=validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=False,
+            pin_memory=torch.cuda.is_available(),
+        )
+
     # Checkpointer
     checkpointer = DetectronCheckpointer(
         model, optimizer, scheduler, save_dir=output_folder, save_to_disk=True
@@ -262,6 +280,7 @@ def train(
         run(
             model,
             data_loader,
+            valid_loader,
             optimizer,
             criterion,
             scheduler,
@@ -277,6 +296,7 @@ def train(
         run(
             model,
             data_loader,
+            valid_loader,
             optimizer,
             criterion,
             scheduler,
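
[Note: illustration only, not part of the patch.  The validation loader is
created with shuffle=False and drop_last=False, so every epoch sees all
validation samples exactly once and in a fixed order, making per-epoch
losses comparable.  An end-to-end sketch of the dictionary-based wiring,
with toy tensors standing in for the configured subsets:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    batch_size = 2
    dataset = {
        "__train__": TensorDataset(torch.zeros(8, 1)),  # augmented copy
        "train": TensorDataset(torch.zeros(8, 1)),      # unaugmented copy
        "__valid__": TensorDataset(torch.zeros(4, 1)),  # validation split
    }

    # prefer the augmented copy for training, as the script above does
    use_dataset = dataset.get("__train__", dataset["train"])
    validation_dataset = dataset.get("__valid__")

    valid_loader = None
    if validation_dataset is not None:
        valid_loader = DataLoader(
            validation_dataset,
            batch_size=batch_size,
            shuffle=False,    # deterministic pass over the validation set
            drop_last=False,  # keep the final, possibly smaller, batch
        )
    assert len(valid_loader) == 2  # 4 samples / batch size of 2
]
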
diff --git a/bob/ip/binseg/test/test_cli.py b/bob/ip/binseg/test/test_cli.py
index 3d3f728f..9a7a9847 100644
--- a/bob/ip/binseg/test/test_cli.py
+++ b/bob/ip/binseg/test/test_cli.py
@@ -88,14 +88,14 @@ def _check_experiment_stare(overlay):
         output_folder = "results"
         options = [
-                "m2unet",
-                config.name,
-                "-vv",
-                "--epochs=1",
-                "--batch-size=1",
-                "--steps=10",
-                f"--output-folder={output_folder}",
-                ]
+            "m2unet",
+            config.name,
+            "-vv",
+            "--epochs=1",
+            "--batch-size=1",
+            "--steps=10",
+            f"--output-folder={output_folder}",
+        ]
         if overlay:
             options += ["--overlayed"]
         result = runner.invoke(experiment, options)
         _assert_exit_0(result)
@@ -107,6 +107,9 @@ def _check_experiment_stare(overlay):
         # check model was saved
         train_folder = os.path.join(output_folder, "model")
         assert os.path.exists(os.path.join(train_folder, "model_final.pth"))
+        assert os.path.exists(
+            os.path.join(train_folder, "model_lowest_valid_loss.pth")
+        )
         assert os.path.exists(os.path.join(train_folder, "last_checkpoint"))
         assert os.path.exists(os.path.join(train_folder, "constants.csv"))
         assert os.path.exists(os.path.join(train_folder, "trainlog.csv"))
@@ -123,7 +126,9 @@ def _check_experiment_stare(overlay):
         if overlay:
             # check overlayed images are there (since we requested them)
             assert os.path.exists(basedir)
-            nose.tools.eq_(len(fnmatch.filter(os.listdir(basedir), "*.png")), 20)
+            nose.tools.eq_(
+                len(fnmatch.filter(os.listdir(basedir), "*.png")), 20
+            )
         else:
             assert not os.path.exists(basedir)
 
@@ -135,7 +140,7 @@ def _check_experiment_stare(overlay):
             os.path.join(eval_folder, "second-annotator", "train.csv")
         )
         assert os.path.exists(
-            os.path.join(eval_folder, "second-annotator" , "test.csv")
+            os.path.join(eval_folder, "second-annotator", "test.csv")
         )
 
         overlay_folder = os.path.join(output_folder, "overlayed", "analysis")
@@ -143,18 +148,23 @@ def _check_experiment_stare(overlay):
         basedir = os.path.join(overlay_folder, "stare-images")
         if overlay:
             # check overlayed images are there (since we requested them)
             assert os.path.exists(basedir)
-            nose.tools.eq_(len(fnmatch.filter(os.listdir(basedir), "*.png")), 20)
+            nose.tools.eq_(
+                len(fnmatch.filter(os.listdir(basedir), "*.png")), 20
+            )
         else:
             assert not os.path.exists(basedir)
 
         # check overlayed images from first-to-second annotator comparisons
         # are there (since we requested them)
-        overlay_folder = os.path.join(output_folder, "overlayed", "analysis",
-                "second-annotator")
+        overlay_folder = os.path.join(
+            output_folder, "overlayed", "analysis", "second-annotator"
+        )
         basedir = os.path.join(overlay_folder, "stare-images")
         if overlay:
             assert os.path.exists(basedir)
-            nose.tools.eq_(len(fnmatch.filter(os.listdir(basedir), "*.png")), 20)
+            nose.tools.eq_(
+                len(fnmatch.filter(os.listdir(basedir), "*.png")), 20
+            )
         else:
             assert not os.path.exists(basedir)
@@ -165,10 +175,13 @@ def _check_experiment_stare(overlay):
         keywords = {
             r"^Started training$": 1,
             r"^Found \(dedicated\) '__train__' set for training$": 1,
+            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
+            r"^Will checkpoint lowest loss model on validation set$": 1,
             r"^Continuing from epoch 0$": 1,
             r"^Saving model summary at.*$": 1,
             r"^Model has.*$": 1,
-            r"^Saving checkpoint": 1,
+            r"^Found new low on validation set.*$": 1,
+            r"^Saving checkpoint": 2,
             r"^Ended training$": 1,
             r"^Started prediction$": 1,
             r"^Loading checkpoint from": 2,
@@ -223,7 +236,7 @@ def _check_train(runner):
            config.write(
                "from bob.ip.binseg.configs.datasets.stare import _maker\n"
            )
-            config.write("dataset = _maker('ah', _raw)['train']\n")
+            config.write("dataset = _maker('ah', _raw)\n")
            config.flush()
 
            output_folder = "results"
@@ -241,15 +254,21 @@ def _check_train(runner):
            _assert_exit_0(result)
 
            assert os.path.exists(os.path.join(output_folder, "model_final.pth"))
+            assert os.path.exists(
+                os.path.join(output_folder, "model_lowest_valid_loss.pth")
+            )
            assert os.path.exists(os.path.join(output_folder, "last_checkpoint"))
            assert os.path.exists(os.path.join(output_folder, "constants.csv"))
            assert os.path.exists(os.path.join(output_folder, "trainlog.csv"))
            assert os.path.exists(os.path.join(output_folder, "model_summary.txt"))
 
            keywords = {
+                r"^Found \(dedicated\) '__train__' set for training$": 1,
+                r"^Found \(dedicated\) '__valid__' set for validation$": 1,
                r"^Continuing from epoch 0$": 1,
                r"^Saving model summary at.*$": 1,
                r"^Model has.*$": 1,
+                rf"^Saving checkpoint to {output_folder}/model_lowest_valid_loss.pth$": 1,
                rf"^Saving checkpoint to {output_folder}/model_final.pth$": 1,
                r"^Total training time:": 1,
            }
@@ -364,8 +383,9 @@ def _check_evaluate(runner):
            _assert_exit_0(result)
 
            assert os.path.exists(os.path.join(output_folder, "test.csv"))
-            assert os.path.exists(os.path.join(output_folder,
-                "second-annotator", "test.csv"))
+            assert os.path.exists(
+                os.path.join(output_folder, "second-annotator", "test.csv")
+            )
 
            # check overlayed images are there (since we requested them)
            basedir = os.path.join(overlay_folder, "stare-images")
--
GitLab
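
[Note: illustration only, not part of the patch.  With validation enabled,
trainlog.csv gains the validation_average_loss and validation_median_loss
columns asserted above.  A sketch for locating the best epoch offline,
assuming one row per epoch and an "epoch" column in the log (the path is
the one used by the experiment test above):

    import csv

    with open("results/model/trainlog.csv") as f:
        rows = list(csv.DictReader(f))

    best = min(rows, key=lambda row: float(row["validation_average_loss"]))
    print(best["epoch"], best["validation_average_loss"])
]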