From 7eb7def379d98d1b523acf6b2758a266db0e1d04 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Sun, 5 Apr 2020 11:45:10 +0200
Subject: [PATCH] [engine.ssltrainer] Re-sync with engine.trainer

---
 bob/ip/binseg/engine/ssltrainer.py | 205 ++++++++++++++---
 1 file changed, 99 insertions(+), 106 deletions(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index f03e01e4..542d3162 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-

 import os
+import csv
 import time
 import datetime
 import torch
@@ -21,7 +22,7 @@ def sharpen(x, T):
     return temp / temp.sum(dim=1, keepdim=True)


-def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
+def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
     """Applies mix up as described in [MIXMATCH_19].

     Parameters
@@ -32,7 +33,7 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):

     target : :py:class:`torch.Tensor`

-    unlabeled_input : :py:class:`torch.Tensor`
+    unlabelled_input : :py:class:`torch.Tensor`

     unlabled_target : :py:class:`torch.Tensor`

@@ -48,17 +49,17 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     l = np.random.beta(alpha, alpha)  # Eq (8)
     l = max(l, 1 - l)  # Eq (9)
     # Shuffle and concat. Alg. 1 Line: 12
-    w_inputs = torch.cat([input, unlabeled_input], 0)
+    w_inputs = torch.cat([input, unlabelled_input], 0)
     w_targets = torch.cat([target, unlabled_target], 0)
     idx = torch.randperm(w_inputs.size(0))  # get random index

-    # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
+    # Apply MixUp to labelled data and entries from W. Alg. 1 Line: 13
     input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input) :]]
     target_mixedup = l * target + (1 - l) * w_targets[idx[len(target) :]]

-    # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
-    unlabeled_input_mixedup = (
-        l * unlabeled_input + (1 - l) * w_inputs[idx[: len(unlabeled_input)]]
+    # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
+    unlabelled_input_mixedup = (
+        l * unlabelled_input + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
     )
     unlabled_target_mixedup = (
         l * unlabled_target + (1 - l) * w_targets[idx[: len(unlabled_target)]]
@@ -66,7 +67,7 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     return (
         input_mixedup,
         target_mixedup,
-        unlabeled_input_mixedup,
+        unlabelled_input_mixedup,
         unlabled_target_mixedup,
     )

@@ -122,14 +123,14 @@ def linear_rampup(current, rampup_length=16):
     return float(current)


-def guess_labels(unlabeled_images, model):
+def guess_labels(unlabelled_images, model):
     """
     Calculate the average predictions by 2 augmentations: horizontal and vertical flips

     Parameters
     ----------

-    unlabeled_images : :py:class:`torch.Tensor`
+    unlabelled_images : :py:class:`torch.Tensor`
         ``[n,c,h,w]``

     target : :py:class:`torch.Tensor`
@@ -142,12 +143,12 @@ ...
     """
     with torch.no_grad():
-        guess1 = torch.sigmoid(model(unlabeled_images)).unsqueeze(0)
+        guess1 = torch.sigmoid(model(unlabelled_images)).unsqueeze(0)
         # Horizontal flip and unsqueeze to work with batches (increase flip dimension by 1)
-        hflip = torch.sigmoid(model(unlabeled_images.flip(2))).unsqueeze(0)
+        hflip = torch.sigmoid(model(unlabelled_images.flip(2))).unsqueeze(0)
         guess2 = hflip.flip(3)
         # Vertical flip and unsqueeze to work with batches (increase flip dimension by 1)
-        vflip = torch.sigmoid(model(unlabeled_images.flip(3))).unsqueeze(0)
+        vflip = torch.sigmoid(model(unlabelled_images.flip(3))).unsqueeze(0)
         guess3 = vflip.flip(4)
         # Concat
         concat = torch.cat([guess1, guess2, guess3], 0)
@@ -169,13 +170,13 @@ def do_ssltrain(
     rampup_length,
 ):
     """
-    Train model and save to disk.
+    Trains model using semi-supervised learning and saves it to disk.

     Parameters
     ----------

     model : :py:class:`torch.nn.Module`
-        Network (e.g. DRIU, HED, UNet)
+        Network (e.g. driu, hed, unet)

     data_loader : :py:class:`torch.utils.data.DataLoader`

@@ -191,13 +192,14 @@ def do_ssltrain(
     checkpointer

     checkpoint_period : int
-        save a checkpoint every n epochs
+        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
+        not save intermediary checkpoints

     device : str
-        device to use ``'cpu'`` or ``'cuda'``
+        device to use ``'cpu'`` or ``cuda:0``

     arguments : dict
-        start end end epochs
+        start and end epochs

     output_folder : str
         output path
@@ -206,15 +208,35 @@ def do_ssltrain(
         rampup epochs
     """
-    logger.info("Start SSL training")
+    logger.info("Start SSL training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]

-    # Logg to file
-    with open(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+", 1
-    ) as outfile:
+    # Log to file
+    logfile_name = os.path.join(output_folder, "trainlog.csv")
+
+    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
+        logger.info(f"Truncating {logfile_name} - training is restarting...")
+        os.unlink(logfile_name)
+
+    logfile_fields = (
+        "epoch",
+        "total-time",
+        "eta",
+        "average-loss",
+        "median-loss",
+        "median-labelled-loss",
+        "median-unlabelled-loss",
+        "learning-rate",
+        "gpu-memory-megabytes",
+    )
+    with open(logfile_name, "a+", newline="") as logfile:
+        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
+
+        if arguments["epoch"] == 0:
+            logwriter.writeheader()
+
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
@@ -223,125 +245,96 @@ def do_ssltrain(
         model.train().to(device)
         # Total training timer
         start_training_time = time.time()
+
         for epoch in range(start_epoch, max_epoch):
             scheduler.step()
             losses = SmoothedValue(len(data_loader))
-            labeled_loss = SmoothedValue(len(data_loader))
-            unlabeled_loss = SmoothedValue(len(data_loader))
+            labelled_loss = SmoothedValue(len(data_loader))
+            unlabelled_loss = SmoothedValue(len(data_loader))
             epoch = epoch + 1
             arguments["epoch"] = epoch
             # Epoch time
             start_epoch_time = time.time()

-            for samples in tqdm(data_loader):
-                # labeled
+            for samples in tqdm(data_loader, desc="batches", leave=False,
+                    disable=None,):
+
+                # data forwarding on the existing network
+
+                # labelled
                 images = samples[1].to(device)
                 ground_truths = samples[2].to(device)
-                unlabeled_images = samples[4].to(device)
-                # labeled outputs
+                unlabelled_images = samples[4].to(device)
+                # labelled outputs
                 outputs = model(images)
-                unlabeled_outputs = model(unlabeled_images)
-                # guessed unlabeled outputs
-                unlabeled_ground_truths = guess_labels(unlabeled_images, model)
-                # unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
-                # images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
+                unlabelled_outputs = model(unlabelled_images)
+                # guessed unlabelled outputs
+                unlabelled_ground_truths = guess_labels(unlabelled_images, model)
+                # unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
+                # images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)
+
+                # loss evaluation and learning (backward step)
                 ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)

                 loss, ll, ul = criterion(
                     outputs,
                     ground_truths,
-                    unlabeled_outputs,
-                    unlabeled_ground_truths,
+                    unlabelled_outputs,
+                    unlabelled_ground_truths,
                     ramp_up_factor,
                 )
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
                 losses.update(loss)
-                labeled_loss.update(ll)
-                unlabeled_loss.update(ul)
-                logger.debug("batch loss: {}".format(loss.item()))
+                labelled_loss.update(ll)
+                unlabelled_loss.update(ul)
+                logger.debug(f"batch loss: {loss.item()}")

-            if epoch % checkpoint_period == 0:
-                checkpointer.save("model_{:03d}".format(epoch), **arguments)
+            if checkpoint_period and (epoch % checkpoint_period == 0):
+                checkpointer.save(f"model_{epoch:03d}", **arguments)

-            if epoch == max_epoch:
+            if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)

+            # computes ETA (estimated time-of-arrival; end of training) taking
+            # into consideration previous epoch performance
             epoch_time = time.time() - start_epoch_time
-
             eta_seconds = epoch_time * (max_epoch - epoch)
-            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+            current_time = time.time() - start_training_time

-            outfile.write(
+            logdata = (
+                ("epoch", f"{epoch}"),
                 (
-                    "{epoch}, "
-                    "{avg_loss:.6f}, "
-                    "{median_loss:.6f}, "
-                    "{median_labeled_loss},"
-                    "{median_unlabeled_loss},"
-                    "{lr:.6f}, "
-                    "{memory:.0f}"
-                    "\n"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    median_labeled_loss=labeled_loss.median,
-                    median_unlabeled_loss=unlabeled_loss.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
-                    if torch.cuda.is_available()
-                    else 0.0,
-                )
-            )
-            logger.info(
+                    "total-time",
+                    f"{datetime.timedelta(seconds=int(current_time))}",
+                ),
+                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
+                ("average-loss", f"{losses.avg:.6f}"),
+                ("median-loss", f"{losses.median:.6f}"),
+                ("median-labelled-loss", f"{labelled_loss.median:.6f}"),
+                ("median-unlabelled-loss", f"{unlabelled_loss.median:.6f}"),
+                ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
                 (
-                    "eta: {eta}, "
-                    "epoch: {epoch}, "
-                    "avg. loss: {avg_loss:.6f}, "
-                    "median loss: {median_loss:.6f}, "
-                    "labeled loss: {median_labeled_loss}, "
-                    "unlabeled loss: {median_unlabeled_loss}, "
-                    "lr: {lr:.6f}, "
-                    "max mem: {memory:.0f}"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    median_labeled_loss=labeled_loss.median,
-                    median_unlabeled_loss=unlabeled_loss.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+                    "gpu-memory-megabytes",
+                    f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
                     if torch.cuda.is_available()
-                    else 0.0,
-                )
+                    else "0.0",
+                ),
             )
+            logwriter.writerow(dict(k for k in logdata))
+            logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))
+
         logger.info("End of training")
         total_training_time = time.time() - start_training_time
-        total_time_str = str(datetime.timedelta(seconds=total_training_time))
         logger.info(
-            "Total training time: {} ({:.4f} s / epoch)".format(
-                total_time_str, total_training_time / (max_epoch)
-            )
+            f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
        )

-    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
-    logdf = pd.read_csv(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
-        header=None,
-        names=[
-            "avg. loss",
-            "median loss",
-            "labeled loss",
-            "unlabeled loss",
-            "lr",
-            "max memory",
-        ],
-    )
-    fig = loss_curve(logdf, output_folder)
-    logger.info("saving {}".format(log_plot_file))
-    fig.savefig(log_plot_file)
--
GitLab
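
For readers unfamiliar with the MixUp step the hunks above touch: Eqs. (8)-(9) of [MIXMATCH_19] draw an interpolation factor from a Beta distribution and fold it onto [0.5, 1], so every mixed sample stays dominated by its primary input. A minimal sketch of just that step; the batch shapes are illustrative, and alpha=0.75 only mirrors the commented-out mix_up(0.75, ...) call in the patch:

    import numpy as np
    import torch

    alpha = 0.75  # illustrative value, taken from the commented-out call above
    l = np.random.beta(alpha, alpha)  # Eq (8): l ~ Beta(alpha, alpha)
    l = max(l, 1 - l)                 # Eq (9): fold onto [0.5, 1]

    x1 = torch.rand(4, 3, 32, 32)  # primary batch (e.g. labelled inputs)
    x2 = torch.rand(4, 3, 32, 32)  # shuffled entries from the joint batch W
    mixed = l * x1 + (1 - l) * x2  # stays closer to x1, since l >= 0.5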
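The guess_labels() hunk keeps the same test-time-augmentation logic under the new names: predict on the image, on its horizontal flip, and on its vertical flip, un-flip each prediction so pixels stay aligned, then average. A self-contained sketch of that logic with a stand-in one-layer network (not one of the project's models); note how unsqueeze(0) shifts the flip dimensions by one:

    import torch
    import torch.nn as nn

    model = nn.Conv2d(3, 1, kernel_size=1)  # stand-in for DRIU/HED/U-Net
    x = torch.rand(2, 3, 8, 8)  # hypothetical unlabelled batch, [n,c,h,w]

    with torch.no_grad():
        guess1 = torch.sigmoid(model(x)).unsqueeze(0)  # [1,n,1,h,w]
        # flip the image, predict, then un-flip the prediction; dims 3/4 of
        # the 5D tensor correspond to dims 2/3 (h/w) of the 4D input
        guess2 = torch.sigmoid(model(x.flip(2))).unsqueeze(0).flip(3)
        guess3 = torch.sigmoid(model(x.flip(3))).unsqueeze(0).flip(4)
        averaged = torch.cat([guess1, guess2, guess3], 0).mean(0)  # [n,1,h,w]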
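Finally, because the trainlog now carries a proper CSV header row (written once, at epoch 0), downstream tools no longer need to hard-code column names or positions. A minimal sketch of consuming the new file; the path is hypothetical and stands for <output_folder>/trainlog.csv as written by this patch:

    import pandas as pd

    # one header row plus one row per completed epoch
    logdf = pd.read_csv("output_folder/trainlog.csv")  # hypothetical location
    print(logdf[["epoch", "average-loss", "median-labelled-loss",
                 "median-unlabelled-loss", "eta"]].tail())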