#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import csv
import time
import shutil
import datetime
import distutils.version

import torch
from tqdm import tqdm

from ..utils.metric import SmoothedValue
from ..utils.summary import summary
from ..utils.resources import cpu_info, gpu_info, cpu_log, gpu_log

import logging

logger = logging.getLogger(__name__)

PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"


def run(
    model,
    data_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
):
    """
    Fits an FCN model using supervised learning and saves it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.  A
    minimal usage sketch is provided as a comment at the end of this module.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. DRIU, HED, UNet)

    data_loader : :py:class:`torch.utils.data.DataLoader`

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then
        do not save intermediate checkpoints

    device : str
        device to use, e.g. ``'cpu'`` or ``'cuda:0'``

    arguments : dict
        start and end epochs

    output_folder : str
        output path
    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    if device != "cpu":
        # asserts we do have a GPU
        assert bool(gpu_info()), (
            f"Device set to '{device}', but cannot "
            f"find a GPU (maybe nvidia-smi is not installed?)"
        )

    if not os.path.exists(output_folder):
        logger.debug(f"Creating output directory '{output_folder}'...")
        os.makedirs(output_folder)

    # Save model summary
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)

    # write static information (CPU/GPU specs and model size) to a CSV file,
    # backing up any previous version of that file
    static_logfile_name = os.path.join(output_folder, "constants.csv")
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_info()
        if device != "cpu":
            logdata += gpu_info()
        logdata = (("model_size", n),) + logdata
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(k for k in logdata))

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "average_loss",
        "median_loss",
        "learning_rate",
    )
    logfile_fields += tuple([k[0] for k in cpu_log()])
    if device != "cpu":
        logfile_fields += tuple([k[0] for k in gpu_log()])

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        # set training mode and move the model (and any optimizer state
        # tensors, e.g. Adam moments loaded from a checkpoint) to the device
        model.train().to(device)
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):

            # PyTorch < 1.1.0 expects the scheduler step *before* the
            # optimizer updates; >= 1.1.0 expects it afterwards (see below)
            if not PYTORCH_GE_110:
                scheduler.step()

            losses = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            # progress bar only on interactive jobs
            for samples in tqdm(
                data_loader, desc="batch", leave=False, disable=None
            ):

                # data forwarding on the existing network
                images = samples[1].to(device)
                ground_truths = samples[2].to(device)
                masks = None
                if len(samples) == 4:
                    masks = samples[-1].to(device)

                outputs = model(images)

                # loss evaluation and learning (backward step)
                loss = criterion(outputs, ground_truths, masks)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.update(loss)
                logger.debug(f"batch loss: {loss.item()}")

            if PYTORCH_GE_110:
                scheduler.step()

            if checkpoint_period and (epoch % checkpoint_period == 0):
                checkpointer.save(f"model_{epoch:03d}", **arguments)

            if epoch >= max_epoch:
                checkpointer.save("model_final", **arguments)

            # computes ETA (estimated time-of-arrival; end of training),
            # taking the previous epoch's performance into consideration
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            logdata = (
                ("epoch", f"{epoch}"),
                (
                    "total_time",
                    f"{datetime.timedelta(seconds=int(current_time))}",
                ),
                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
                ("average_loss", f"{losses.avg:.6f}"),
                ("median_loss", f"{losses.median:.6f}"),
                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
            ) + cpu_log()
            if device != "cpu":
                logdata += gpu_log()

            logwriter.writerow(dict(k for k in logdata))
            logfile.flush()
            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))

        total_training_time = time.time() - start_training_time
        logger.info(
            f"Total training time: {datetime.timedelta(seconds=total_training_time)} "
            f"({(total_training_time/max_epoch):.4f}s on average per epoch)"
        )
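

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, kept as a comment so nothing runs
# on import).  The factories ``make_model()``, ``make_dataset()``,
# ``make_criterion()`` and ``make_checkpointer()`` are hypothetical
# placeholders; in practice the package's command-line interface builds these
# objects from configuration resources.  What ``run()`` actually requires is:
# a ``torch.nn.Module``; a ``DataLoader`` whose samples are
# ``(name, image, ground_truth[, mask])`` tuples; a criterion callable as
# ``criterion(output, target, mask)``; a learning-rate scheduler; and a
# checkpointer exposing ``save(name, **arguments)``.
#
#     import torch
#     from torch.utils.data import DataLoader
#
#     model = make_model()            # hypothetical: e.g. a DRIU network
#     data_loader = DataLoader(make_dataset(), batch_size=2, shuffle=True)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
#     criterion = make_criterion()    # hypothetical: must accept a mask
#     checkpointer = make_checkpointer(model, optimizer, scheduler)
#
#     run(
#         model,
#         data_loader,
#         optimizer,
#         criterion,
#         scheduler,
#         checkpointer,
#         checkpoint_period=10,
#         device="cuda:0",
#         arguments={"epoch": 0, "max_epoch": 200},
#         output_folder="results",
#     )
# ---------------------------------------------------------------------------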