diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index d8b66b69d5e729de87abafd4963a9fc71a4a87d9..9db427a7ebe4b3138948da9d8e7b35259049cbe1 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -4,21 +4,23 @@
 import os
 import csv
 import time
+import shutil
 import datetime
 import distutils.version

 import numpy
-import pandas
 import torch
 from tqdm import tqdm

 from ..utils.measure import SmoothedValue
-from ..utils.plot import loss_curve
+from ..utils.summary import summary
+from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log

 import logging
+
 logger = logging.getLogger(__name__)

-PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"


 def sharpen(x, T):
@@ -48,7 +50,7 @@ def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
     list

     """
-    # TODO:
+    with torch.no_grad():

     l = numpy.random.beta(alpha, alpha)  # Eq (8)
     l = max(l, 1 - l)  # Eq (9)
@@ -63,10 +65,12 @@ def mix_up(alpha, input, target, unlabelled_input, unlabled_target):

     # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
     unlabelled_input_mixedup = (
-        l * unlabelled_input + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
+        l * unlabelled_input
+        + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
     )
     unlabled_target_mixedup = (
-        l * unlabled_target + (1 - l) * w_targets[idx[: len(unlabled_target)]]
+        l * unlabled_target
+        + (1 - l) * w_targets[idx[: len(unlabled_target)]]
     )
     return (
         input_mixedup,
@@ -176,6 +180,11 @@ def run(
     """
     Fits an FCN model using semi-supervised learning and saves it to disk.

+
+    This method supports periodic checkpointing and the output of a
+    CSV-formatted log with the evolution of some figures during training.
+
+
     Parameters
     ----------

@@ -193,7 +202,7 @@ def run(
         learning rate scheduler

     checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
-        checkpointer
+        checkpointer implementation

     checkpoint_period : int
         save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
@@ -216,45 +225,85 @@ def run(
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]

-    if not os.path.exists(output_folder):
-        logger.debug(f"Creating output directory '{output_folder}'...")
-        os.makedirs(output_folder)
+    if device != "cpu":
+        # asserts we do have a GPU
+        assert bool(gpu_constants()), (
+            f"Device set to '{device}', but cannot "
+            f"find a GPU (maybe nvidia-smi is not installed?)"
+        )

-    # Log to file
+    os.makedirs(output_folder, exist_ok=True)
+
+    # Save model summary
+    summary_path = os.path.join(output_folder, "model_summary.txt")
+    logger.info(f"Saving model summary at {summary_path}...")
+    with open(summary_path, "wt") as f:
+        r, n = summary(model)
+        logger.info(f"Model has {n} parameters...")
+        f.write(r)
+
+    # write static information to a CSV file
+    static_logfile_name = os.path.join(output_folder, "constants.csv")
+    if os.path.exists(static_logfile_name):
+        backup = static_logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(static_logfile_name, backup)
+    with open(static_logfile_name, "w", newline="") as f:
+        logdata = cpu_constants()
+        if device != "cpu":
+            logdata += gpu_constants()
+        logdata += (("model_size", n),)
+        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
+        logwriter.writeheader()
+        logwriter.writerow(dict(k for k in logdata))
+
+    # Log continous information to (another) file
     logfile_name = os.path.join(output_folder, "trainlog.csv")

     if arguments["epoch"] == 0 and os.path.exists(logfile_name):
-        logger.info(f"Truncating {logfile_name} - training is restarting...")
-        os.unlink(logfile_name)
+        backup = logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(logfile_name, backup)

     logfile_fields = (
         "epoch",
-        "total-time",
+        "total_time",
         "eta",
-        "average-loss",
-        "median-loss",
-        "median-labelled-loss",
-        "median-unlabelled-loss",
-        "learning-rate",
-        "gpu-memory-megabytes",
+        "average_loss",
+        "median_loss",
+        "median_labelled_loss",
+        "median_unlabelled_loss",
+        "learning_rate",
     )
+    logfile_fields += tuple([k[0] for k in cpu_log()])
+    if device != "cpu":
+        logfile_fields += tuple([k[0] for k in gpu_log()])
+
     with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

         if arguments["epoch"] == 0:
             logwriter.writeheader()

+        model.train().to(device)
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
                     state[k] = v.to(device)
-        model.train().to(device)

         # Total training timer
         start_training_time = time.time()

-        for epoch in range(start_epoch, max_epoch):
-            if not PYTORCH_GE_110: scheduler.step()
+        for epoch in tqdm(
+            range(start_epoch, max_epoch),
+            desc="epoch",
+            leave=False,
+            disable=None,
+        ):
+            if not PYTORCH_GE_110:
+                scheduler.step()
             losses = SmoothedValue(len(data_loader))
             labelled_loss = SmoothedValue(len(data_loader))
             unlabelled_loss = SmoothedValue(len(data_loader))
@@ -264,8 +313,9 @@ def run(
             # Epoch time
             start_epoch_time = time.time()

-            for samples in tqdm(data_loader, desc="batches", leave=False,
-                    disable=None,):
+            for samples in tqdm(
+                data_loader, desc="batch", leave=False, disable=None
+            ):

                 # data forwarding on the existing network

@@ -277,12 +327,16 @@ def run(
                 outputs = model(images)
                 unlabelled_outputs = model(unlabelled_images)
                 # guessed unlabelled outputs
-                unlabelled_ground_truths = guess_labels(unlabelled_images, model)
+                unlabelled_ground_truths = guess_labels(
+                    unlabelled_images, model
+                )
                 # unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
                 # images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)

                 # loss evaluation and learning (backward step)
-                ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)
+                ramp_up_factor = square_rampup(
+                    epoch, rampup_length=rampup_length
+                )

                 loss, ll, ul = criterion(
                     outputs,
@@ -299,7 +353,8 @@ def run(
                 unlabelled_loss.update(ul)
                 logger.debug(f"batch loss: {loss.item()}")

-            if PYTORCH_GE_110: scheduler.step()
+            if PYTORCH_GE_110:
+                scheduler.step()

             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
@@ -316,33 +371,24 @@ def run(
             logdata = (
                 ("epoch", f"{epoch}"),
                 (
-                    "total-time",
+                    "total_time",
                     f"{datetime.timedelta(seconds=int(current_time))}",
                 ),
                 ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
-                ("average-loss", f"{losses.avg:.6f}"),
-                ("median-loss", f"{losses.median:.6f}"),
-                ("median-labelled-loss", f"{labelled_loss.median:.6f}"),
-                ("median-unlabelled-loss", f"{unlabelled_loss.median:.6f}"),
-                ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
-                (
-                    "gpu-memory-megabytes",
-                    f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
-                    if torch.cuda.is_available()
-                    else "0.0",
-                ),
-            )
+                ("average_loss", f"{losses.avg:.6f}"),
+                ("median_loss", f"{losses.median:.6f}"),
+                ("median_labelled_loss", f"{labelled_loss.median:.6f}"),
+                ("median_unlabelled_loss", f"{unlabelled_loss.median:.6f}"),
+                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+            ) + cpu_log()
+            if device != "cpu":
+                logdata += gpu_log()
+
             logwriter.writerow(dict(k for k in logdata))
-            logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))
+            logfile.flush()
+            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))

         total_training_time = time.time() - start_training_time
         logger.info(
             f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
         )
-
-    # plots a version of the CSV trainlog into a PDF
-    logdf = pandas.read_csv(logfile_name, header=0, names=logfile_fields)
-    fig = loss_curve(logdf)
-    figurefile_name = os.path.join(output_folder, "trainlog.pdf")
-    logger.info(f"Saving {figurefile_name}")
-    fig.savefig(figurefile_name)
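
Note on the CSV logging convention the new code relies on: ``cpu_constants()``, ``gpu_constants()``, ``cpu_log()`` and ``gpu_log()`` (from ``bob.ip.binseg.utils.resources``) are each treated as returning a tuple of ``(name, value)`` pairs, which is why field names are derived with ``[k[0] for k in logdata]`` and rows are written with ``dict(k for k in logdata)``. The standalone sketch below only illustrates that pattern; the ``fake_*`` helpers and their values are hypothetical stand-ins, not the package's actual resource monitors:

    import csv

    def fake_cpu_constants():
        # hypothetical stand-in for resources.cpu_constants(): static figures
        return (("cpu_count", 8), ("memory_total_GB", 15.6))

    def fake_cpu_log():
        # hypothetical stand-in for resources.cpu_log(): per-epoch figures
        return (("cpu_percent", 42.0), ("memory_used_GB", 3.1))

    # static information, written once per training run (cf. constants.csv)
    constants = fake_cpu_constants() + (("model_size", 123456),)
    with open("constants.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=[k[0] for k in constants])
        writer.writeheader()
        writer.writerow(dict(k for k in constants))

    # dynamic information, appended at every epoch (cf. trainlog.csv)
    fields = ("epoch", "average_loss") + tuple(k[0] for k in fake_cpu_log())
    with open("trainlog.csv", "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=fields)
        logwriter.writeheader()
        logdata = (("epoch", "0"), ("average_loss", "0.123456")) + fake_cpu_log()
        logwriter.writerow(dict(k for k in logdata))
        logfile.flush()

Because both files are plain CSV keyed by column name, extra columns (for example the GPU figures added only when ``device != "cpu"``) can be appended to the field list without changing the writer logic.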