From 915f4e79007698ab6721d8d64fcc5f1454849bf7 Mon Sep 17 00:00:00 2001 From: Tim Laibacher <tim.laibacher@idiap.ch> Date: Mon, 24 Jun 2019 16:53:08 +0200 Subject: [PATCH] Add valtrainer --- .../datasets/drivestarechasedb1hrf1024.py | 10 ++ bob/ip/binseg/configs/datasets/hrf1024.py | 24 +++ bob/ip/binseg/engine/valtrainer.py | 151 ++++++++++++++++++ bob/ip/binseg/script/binseg.py | 151 +++++++++++++++++- bob/ip/binseg/utils/plot.py | 9 +- setup.py | 2 + 6 files changed, 341 insertions(+), 6 deletions(-) create mode 100644 bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py create mode 100644 bob/ip/binseg/configs/datasets/hrf1024.py create mode 100644 bob/ip/binseg/engine/valtrainer.py diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py new file mode 100644 index 00000000..267fc45c --- /dev/null +++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py @@ -0,0 +1,10 @@ +from bob.ip.binseg.configs.datasets.drive1024 import dataset as drive +from bob.ip.binseg.configs.datasets.stare1024 import dataset as stare +from bob.ip.binseg.configs.datasets.hrf1024 import dataset as hrf +from bob.ip.binseg.configs.datasets.chasedb11024 import dataset as chase +import torch + +#### Config #### + +# PyTorch dataset +dataset = torch.utils.data.ConcatDataset([drive,stare,hrf,chase]) \ No newline at end of file diff --git a/bob/ip/binseg/configs/datasets/hrf1024.py b/bob/ip/binseg/configs/datasets/hrf1024.py new file mode 100644 index 00000000..2574901a --- /dev/null +++ b/bob/ip/binseg/configs/datasets/hrf1024.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from bob.db.hrf import Database as HRF +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.binsegdataset import BinSegDataset + +#### Config #### + +transforms = Compose([ + Pad((0,584,0,584)) + ,Resize((1024)) + ,RandomRotation() + ,RandomHFlip() + ,RandomVFlip() + ,ColorJitter() + 
def do_valtrain(
    model,
    data_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    val_loader=None,
):
    """
    Train model, run inference on a validation set at every checkpoint and
    save checkpoints, a CSV training log and a loss-curve plot to disk.

    Parameters
    ----------
    model : :py:class:`torch.nn.Module`
        Network (e.g. DRIU, HED, UNet)
    data_loader : :py:class:`torch.utils.data.DataLoader`
        loader yielding the training samples
    optimizer : :py:mod:`torch.optim`
    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function
    scheduler : :py:mod:`torch.optim`
        learning rate scheduler
    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
        checkpointer
    checkpoint_period : int
        save a checkpoint (and run validation) every n epochs
    device : str
        device to use ``'cpu'`` or ``'cuda'``
    arguments : dict
        start and end epochs (keys ``"epoch"`` and ``"max_epoch"``);
        ``arguments["epoch"]`` is updated in place after every epoch
    output_folder : str
        output path
    val_loader : :py:class:`torch.utils.data.DataLoader`, optional
        loader handed to ``do_inference`` at every checkpoint; validation is
        skipped when ``None``
    """
    torch.manual_seed(42)
    logger = logging.getLogger("bob.ip.binseg.engine.trainer")
    logger.info("Start training")
    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]
    log_csv_file = os.path.join(output_folder, "{}_trainlog.csv".format(model.name))

    # Log to file (append mode so a resumed run extends the existing log;
    # buffering=1 requests line buffering)
    with open(log_csv_file, "a+", 1) as outfile:

        model.train().to(device)
        # Total training timer
        start_training_time = time.time()

        for epoch in range(start_epoch, max_epoch):
            # NOTE(review): since PyTorch 1.1 the documented order is
            # optimizer.step() first, scheduler.step() afterwards. The call is
            # kept here so existing learning-rate schedules are unchanged.
            scheduler.step()
            losses = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            for samples in tqdm(data_loader):

                images = samples[1].to(device)
                ground_truths = samples[2].to(device)
                # a fourth element, when present, is used as the mask
                masks = None
                if len(samples) == 4:
                    masks = samples[-1].to(device)

                outputs = model(images)

                loss = criterion(outputs, ground_truths, masks)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.update(loss)
                logger.debug("batch loss: {}".format(loss.item()))

            if epoch % checkpoint_period == 0:
                checkpointer.save("model_{:03d}".format(epoch), **arguments)
                if val_loader is not None:
                    # one sub-folder per validated epoch; os.path.join needs
                    # separate string components, not a tuple with an int
                    val_folder = os.path.join(output_folder, str(epoch))
                    os.makedirs(val_folder, exist_ok=True)
                    do_inference(model, val_loader, device, val_folder)
                    # do_inference presumably switches the model to eval mode
                    # (TODO confirm); re-enabling train mode is a no-op
                    # otherwise
                    model.train()

            if epoch == max_epoch:
                checkpointer.save("model_final", **arguments)

            epoch_time = time.time() - start_epoch_time

            # remaining time estimated from the duration of the last epoch
            eta_seconds = epoch_time * (max_epoch - epoch)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            stats = dict(
                epoch=epoch,
                avg_loss=losses.avg,
                median_loss=losses.median,
                lr=optimizer.param_groups[0]["lr"],
                memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0,
            )

            outfile.write(("{epoch}, "
                           "{avg_loss:.6f}, "
                           "{median_loss:.6f}, "
                           "{lr:.6f}, "
                           "{memory:.0f}"
                           "\n"
                           ).format(**stats)
                          )
            logger.info(("eta: {eta}, "
                         "epoch: {epoch}, "
                         "avg. loss: {avg_loss:.6f}, "
                         "median loss: {median_loss:.6f}, "
                         "lr: {lr:.6f}, "
                         "max mem: {memory:.0f}"
                         ).format(eta=eta_string, **stats)
                        )

        total_training_time = time.time() - start_training_time
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        # average over the epochs run in this call (a resumed run starts at
        # start_epoch > 0); guard against division by zero when none ran
        epochs_run = max(max_epoch - start_epoch, 1)
        logger.info(
            "Total training time: {} ({:.4f} s / epoch)".format(
                total_time_str, total_training_time / epochs_run
            ))

    # The log file is closed at this point, so every row is on disk.
    if os.path.getsize(log_csv_file) == 0:
        logger.warning("{} is empty, skipping loss curve".format(log_csv_file))
        return

    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
    # Each row has five fields (epoch first) but only four names are given.
    # NOTE(review): pandas should then use the leading epoch column as the
    # index - confirm.
    logdf = pd.read_csv(log_csv_file, header=None, names=["avg. loss", "median loss", "lr", "max memory"])
    fig = loss_curve(logdf, output_folder)
    logger.info("saving {}".format(log_plot_file))
    fig.savefig(log_plot_file)
# Validation Train
# NOTE(review): the short flag '-d' is declared for both --dataset and
# --device below; left unchanged to keep the command line backward-compatible.
@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
@click.option(
    '--output-path',
    '-o',
    required=True,
    default="output",
    cls=ResourceOption
    )
@click.option(
    '--model',
    '-m',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--dataset',
    '-d',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--optimizer',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--criterion',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--scheduler',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--pretrained-backbone',
    '-t',
    required=True,
    cls=ResourceOption
    )
@click.option(
    '--batch-size',
    '-b',
    required=True,
    default=2,
    cls=ResourceOption)
@click.option(
    '--epochs',
    '-e',
    help='Number of epochs used for training',
    show_default=True,
    required=True,
    default=6,
    cls=ResourceOption)
@click.option(
    '--checkpoint-period',
    '-p',
    help='Number of epochs after which a checkpoint is saved',
    show_default=True,
    required=True,
    default=2,
    cls=ResourceOption)
@click.option(
    '--device',
    '-d',
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0"',
    show_default=True,
    required=True,
    default='cpu',
    cls=ResourceOption)
@click.option(
    '--valsize',
    '-a',
    help='Size of validation set',
    show_default=True,
    required=True,
    default=5,
    cls=ResourceOption)
@verbosity_option(cls=ResourceOption)
def valtrain(model
        ,optimizer
        ,scheduler
        ,output_path
        ,epochs
        ,pretrained_backbone
        ,batch_size
        ,criterion
        ,dataset
        ,checkpoint_period
        ,device
        ,valsize
        ,**kwargs):
    """ Train a model, holding out part of the dataset for validation

    A random permutation of the dataset indices is split into a training part
    and a validation part of ``valsize`` samples. With ``valsize`` equal to 0
    no validation loader is created and ``None`` is passed to the trainer.
    """

    os.makedirs(output_path, exist_ok=True)

    # Validation and training set size
    num_samples = len(dataset)
    if not 0 <= valsize < num_samples:
        # an out-of-range value would silently give an empty or wrong split
        raise click.BadParameter(
            "must be >= 0 and smaller than the dataset size ({}), got {}".format(num_samples, valsize),
            param_hint='--valsize')
    train_size = num_samples - valsize

    # Random, non-overlapping split; plain Python ints are handed to the
    # samplers instead of 0-dim tensors
    # NOTE(review): the permutation is drawn before do_valtrain seeds the RNG,
    # so the split presumably differs between runs (also when resuming from a
    # checkpoint) - confirm this is intended.
    indices = torch.randperm(num_samples).tolist()
    train_indices = indices[:train_size]
    valid_indices = indices[train_size:]

    # PyTorch dataloader
    pin_memory = torch.cuda.is_available()
    train_loader = torch.utils.data.DataLoader(dataset, pin_memory=pin_memory, batch_size=batch_size,
                                               sampler=SubsetRandomSampler(train_indices))

    valid_loader = None
    if valid_indices:
        valid_loader = torch.utils.data.DataLoader(dataset, pin_memory=pin_memory, batch_size=batch_size,
                                                   sampler=SubsetRandomSampler(valid_indices))

    # Checkpointer
    checkpointer = DetectronCheckpointer(model, optimizer, scheduler, save_dir=output_path, save_to_disk=True)
    arguments = {}
    arguments["epoch"] = 0
    # values returned by the checkpointer override the default start epoch
    extra_checkpoint_data = checkpointer.load(pretrained_backbone)
    arguments.update(extra_checkpoint_data)
    arguments["max_epoch"] = epochs

    # Train
    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
    do_valtrain(model
            , train_loader
            , optimizer
            , criterion
            , scheduler
            , checkpointer
            , checkpoint_period
            , device
            , arguments
            , output_path
            , valid_loader
            )
b/bob/ip/binseg/utils/plot.py @@ -158,7 +158,7 @@ def read_metricscsv(file): return np.array(precision), np.array(recall) -def plot_overview(outputfolders): +def plot_overview(outputfolders,title): """ Plots comparison chart of all trained models @@ -166,6 +166,8 @@ def plot_overview(outputfolders): ---------- outputfolder : list list containing output paths of all evaluated models (e.g. ``['DRIVE/model1', 'DRIVE/model2']``) + title : str + title of plot Returns ------- matplotlib.figure.Figure @@ -181,15 +183,16 @@ def plot_overview(outputfolders): precisions.append(pr) recalls.append(re) modelname = folder.split('/')[-1] + datasetname = folder.split('/')[-2] # parameters summary_path = os.path.join(folder,'results/ModelSummary.txt') with open (summary_path, "r") as outfile: rows = outfile.readlines() lastrow = rows[-1] parameter = int(lastrow.split()[1].replace(',','')) - name = '[P={:.2f}M] {}'.format(parameter/100**3, modelname) + name = '[P={:.2f}M] {} {}'.format(parameter/100**3, modelname, datasetname) names.append(name) - title = folder.split('/')[-2] + #title = folder.split('/')[-4] fig = precision_recall_f1iso(precisions,recalls,names,title) return fig diff --git a/setup.py b/setup.py index 20671d7f..a7b02f03 100644 --- a/setup.py +++ b/setup.py @@ -78,10 +78,12 @@ setup( 'DRIVE1024 = bob.ip.binseg.configs.datasets.drive1024', 'DRIVE1168 = bob.ip.binseg.configs.datasets.drive1168', 'DRIVETEST = bob.ip.binseg.configs.datasets.drivetest', + 'DRIVESTARECHASEDB1HRF1024 = bob.ip.binseg.configs.datasets.drivestarechasedb1hrf1024', 'DRIVESTARECHASEDB1IOSTAR1168 = bob.ip.binseg.configs.datasets.drivestarechasedb1iostar1168', 'DRIVESTAREIOSTARHRF960 = bob.ip.binseg.configs.datasets.drivestareiostarhrf960', 'HRF = bob.ip.binseg.configs.datasets.hrf', 'HRF960 = bob.ip.binseg.configs.datasets.hrf960', + 'HRF1024 = bob.ip.binseg.configs.datasets.hrf1024', 'HRF1168 = bob.ip.binseg.configs.datasets.hrf1168', 'HRFTEST = bob.ip.binseg.configs.datasets.hrftest', 'IOSTAROD 
= bob.ip.binseg.configs.datasets.iostarod', -- GitLab