diff --git a/bob/ip/binseg/configs/datasets/drive1024ssliostar.py b/bob/ip/binseg/configs/datasets/drive1024ssliostar.py
new file mode 100644
index 0000000000000000000000000000000000000000..0158c4a9c4595c9baf558afa6dd97d5ee1ca64b6
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/drive1024ssliostar.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from bob.db.drive import Database as DRIVE
+from bob.db.iostar import Database as IOSTAR
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset
+
+#### Config ####
+
+#### Unlabeled IOSTAR TRAIN ####
+unlabeled_transforms = Compose([
+                        RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+sslbobdb = IOSTAR(protocol = 'default_vessel')
+
+# PyTorch dataset
+unlabeled_dataset = UnLabeledBinSegDataset(sslbobdb, split='train', transform=unlabeled_transforms)
+
+
+#### Labeled ####
+labeled_transforms = Compose([
+                        CenterCrop((540,540))
+                        ,Resize(1024)
+                        ,RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+bobdb = DRIVE(protocol = 'default')
+
+# PyTorch dataset
+dataset = SSLBinSegDataset(bobdb, unlabeled_dataset, split='train', transform=labeled_transforms)
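For reference, a quick way to sanity-check this config (illustrative only, not part of the patch; it assumes the DRIVE and IOSTAR raw data are installed and discoverable by the bob.db packages):

    # One sample from the combined dataset; the layout follows SSLBinSegDataset:
    # [img_name, img, gt, unlabeled_img_name, unlabeled_img]
    from bob.ip.binseg.configs.datasets import drive1024ssliostar
    sample = drive1024ssliostar.dataset[0]
    print(sample[0], sample[1].shape, sample[2].shape, sample[3], sample[4].shape)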
diff --git a/bob/ip/binseg/configs/datasets/drive2336sslhrf.py b/bob/ip/binseg/configs/datasets/drive2336sslhrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab1e77daf8195c18d3853bcb2e66bc843c6ca35
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/drive2336sslhrf.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from bob.db.drive import Database as DRIVE
+from bob.db.hrf import Database as HRF
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset
+
+#### Config ####
+
+#### Unlabeled HRF TRAIN ####
+unlabeled_transforms = Compose([
+                        Crop(0,108,2336,3296)
+                        ,RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+sslbobdb = HRF(protocol = 'default')
+
+# PyTorch dataset
+unlabeled_dataset = UnLabeledBinSegDataset(sslbobdb, split='train', transform=unlabeled_transforms)
+
+
+#### Labeled ####
+labeled_transforms = Compose([
+                        Crop(75,10,416,544)
+                        ,Pad((21,0,22,0))
+                        ,Resize(2336)
+                        ,RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+bobdb = DRIVE(protocol = 'default')
+
+# PyTorch dataset
+dataset = SSLBinSegDataset(bobdb, unlabeled_dataset, split='train', transform=labeled_transforms)
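The crop/pad/resize chain is chosen so that both branches end up at the same resolution. A worked check of the geometry (assuming ``Crop(i, j, h, w)`` and ``Pad((left, top, right, bottom))`` follow the torchvision conventions used elsewhere in this package):

    # HRF unlabeled branch: Crop(0, 108, 2336, 3296) -> 2336 x 3296 (h x w), aspect ~1.411
    # DRIVE labeled branch: Crop(75, 10, 416, 544)   ->  416 x  544
    #                       Pad((21, 0, 22, 0))      ->  416 x  587, aspect ~1.411
    #                       Resize(2336)             -> 2336 x 3296
    # so labeled and unlabeled tensors can be batched together at 2336 x 3296.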
diff --git a/bob/ip/binseg/configs/models/m2unetssl.py b/bob/ip/binseg/configs/models/m2unetssl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0beafe6acef86a74e8955a7de7c2c6c04502037
--- /dev/null
+++ b/bob/ip/binseg/configs/models/m2unetssl.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.m2u import build_m2unet
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import MixJacLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+amsbound = False
+
+scheduler_milestones = [900]
+scheduler_gamma = 0.1
+
+# model
+model = build_m2unet()
+
+# pretrained backbone
+pretrained_backbone = modelurls['mobilenetv2']
+
+# optimizer
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                     eps=eps, weight_decay=weight_decay, amsbound=amsbound)
+
+# criterion
+criterion = MixJacLoss(lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7)
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
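With this config, the scheduler keeps the AdaBound learning rate at 0.001 until epoch 900 and then multiplies it by ``scheduler_gamma``. A minimal, self-contained illustration with a stand-in optimizer (standard PyTorch APIs only):

    import torch
    from torch.optim.lr_scheduler import MultiStepLR

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.001)
    sched = MultiStepLR(opt, milestones=[900], gamma=0.1)
    for epoch in range(1000):
        opt.step()   # no-op here; keeps newer PyTorch versions from warning
        sched.step()
    print(opt.param_groups[0]["lr"])  # 0.0001 after the 900th epoch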
diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py
index 86fc97df482707a9715b38790741aff65bfe2b52..0f3ca24730c7f3b83880ee42137555057e9218eb 100644
--- a/bob/ip/binseg/data/binsegdataset.py
+++ b/bob/ip/binseg/data/binsegdataset.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 from torch.utils.data import Dataset
+import random
 
 class BinSegDataset(Dataset):
     """PyTorch dataset wrapper around bob.db binary segmentation datasets.
@@ -62,3 +63,115 @@ class BinSegDataset(Dataset):
         sample.insert(0,img_name)
 
         return sample
+
+
+class SSLBinSegDataset(Dataset):
+    """PyTorch dataset wrapper around bob.db binary segmentation datasets.
+    A transform object can be passed that will be applied to the image, ground truth and mask (if present).
+    It supports indexing such that dataset[i] can be used to get the i-th sample.
+
+    Parameters
+    ----------
+    bobdb : :py:mod:`bob.db.base`
+        Binary segmentation bob database (e.g. bob.db.drive)
+    unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
+        dataset with unlabeled data
+    split : str
+        ``'train'`` or ``'test'``. Defaults to ``'train'``
+    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
+        A transform or composition of transforms. Defaults to ``None``.
+    """
+    def __init__(self, bobdb, unlabeled_dataset, split = 'train', transform = None):
+        self.database = bobdb.samples(split)
+        self.unlabeled_dataset = unlabeled_dataset
+        self.transform = transform
+        self.split = split
+
+
+    def __len__(self):
+        """
+        Returns
+        -------
+        int
+            size of the dataset
+        """
+        return len(self.database)
+
+    def __getitem__(self,index):
+        """
+        Parameters
+        ----------
+        index : int
+
+        Returns
+        -------
+        list
+            dataitem [img_name, img, gt, unlabeled_img_name, unlabeled_img]
+        """
+        img = self.database[index].img.pil_image()
+        gt = self.database[index].gt.pil_image()
+        img_name = self.database[index].img.basename
+        sample = [img, gt]
+
+
+        if self.transform :
+            sample = self.transform(*sample)
+
+        sample.insert(0,img_name)
+        unlabeled_img_name, unlabeled_img = self.unlabeled_dataset[index]
+        sample.extend([unlabeled_img_name, unlabeled_img])
+        return sample
+
+
+class UnLabeledBinSegDataset(Dataset):
+    # TODO: add support for a plain directory of images, not only bob.db objects
+    """PyTorch dataset wrapper around bob.db binary segmentation datasets.
+    A transform object can be passed that will be applied to the image.
+    It supports indexing such that dataset[i] can be used to get the i-th sample.
+
+    Parameters
+    ----------
+    db : :py:mod:`bob.db.base` or str
+        Binary segmentation bob database (e.g. bob.db.drive) or path to folder containing unlabeled images
+    split : str
+        ``'train'`` or ``'test'``. Defaults to ``'train'``
+    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
+        A transform or composition of transforms. Defaults to ``None``.
+    """
+    def __init__(self, db, split = 'train', transform = None):
+        self.database = db.samples(split)
+        self.transform = transform
+        self.split = split
+
+    def __len__(self):
+        """
+        Returns
+        -------
+        int
+            size of the dataset
+        """
+        return len(self.database)
+
+    def __getitem__(self,index):
+        """
+        Parameters
+        ----------
+        index : int
+
+        Returns
+        -------
+        list
+            dataitem [img_name, img]
+        """
+        # Draw a random sample so that labeled/unlabeled pairings change
+        # between epochs and the index can never exceed the dataset size.
+        index = random.randrange(len(self.database))
+        img = self.database[index].img.pil_image()
+        img_name = self.database[index].img.basename
+        sample = [img]
+        if self.transform :
+            sample = self.transform(*sample)
+
+        sample.insert(0,img_name)
+
+        return sample
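Downstream code indexes SSL samples positionally; after PyTorch's default collation a batch looks like this (sketch; assumes ``dataset`` comes from one of the SSL configs above):

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    for samples in loader:
        # [img_names, imgs, gts, unlabeled_img_names, unlabeled_imgs]
        images, ground_truths = samples[1], samples[2]
        unlabeled_images = samples[4]
        break

This is exactly the ``samples[1]``/``samples[2]``/``samples[4]`` indexing used by ``do_ssltrain`` below.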
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0489558e075b7fb873d21a4b3cf50db954eacbb
--- /dev/null
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -0,0 +1,239 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import logging
+import time
+import datetime
+import torch
+import pandas as pd
+from tqdm import tqdm
+import numpy as np
+
+from bob.ip.binseg.utils.metric import SmoothedValue
+from bob.ip.binseg.utils.plot import loss_curve
+
+def mix_up(alpha, input, target, unlabeled_input, unlabeled_target):
+    """Applies mix up as described in [MIXMATCH_19].
+
+    Parameters
+    ----------
+    alpha : float
+    input : :py:class:`torch.Tensor`
+    target : :py:class:`torch.Tensor`
+    unlabeled_input : :py:class:`torch.Tensor`
+    unlabeled_target : :py:class:`torch.Tensor`
+
+    Returns
+    -------
+    tuple
+        (input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabeled_target_mixedup)
+    """
+    l = np.random.beta(alpha, alpha) # Eq (8)
+    l = max(l, 1 - l) # Eq (9)
+    # Shuffle and concat. Alg. 1 Line: 12
+    w_inputs = torch.cat([input,unlabeled_input],0)
+    w_targets = torch.cat([target,unlabeled_target],0)
+    idx = torch.randperm(w_inputs.size(0)) # get random index
+
+    # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
+    input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]]
+    target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]
+
+    # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
+    unlabeled_input_mixedup = l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
+    unlabeled_target_mixedup = l * unlabeled_target + (1 - l) * w_targets[idx[:len(unlabeled_target)]]
+    return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabeled_target_mixedup
+
+
+def linear_rampup(current, rampup_length=16):
+    """Slowly ramp up ``lambda_u``
+
+    Parameters
+    ----------
+    current : int
+        current epoch
+    rampup_length : int, optional
+        how long to ramp up, by default 16
+
+    Returns
+    -------
+    float
+        ramp-up factor
+    """
+    if rampup_length == 0:
+        return 1.0
+    else:
+        current = np.clip(current / rampup_length, 0.0, 1.0)
+        return float(current)
+
+def guess_labels(unlabeled_images, model):
+    """
+    Calculate the average prediction over the original and the height- and width-flipped images
+
+    Parameters
+    ----------
+    unlabeled_images : :py:class:`torch.Tensor`
+        shape: ``[n,c,h,w]``
+    model : :py:class:`torch.nn.Module`
+
+    Returns
+    -------
+    :py:class:`torch.Tensor`
+        shape: ``[n,c,h,w]``.
+    """
+    with torch.no_grad():
+        guess1 = torch.sigmoid(model(unlabeled_images)).unsqueeze(0)
+        # Flip along height (dim 2); flip the prediction back along dim 3 (height after unsqueeze)
+        vflip = torch.sigmoid(model(unlabeled_images.flip(2))).unsqueeze(0)
+        guess2 = vflip.flip(3)
+        # Flip along width (dim 3); flip the prediction back along dim 4 (width after unsqueeze)
+        hflip = torch.sigmoid(model(unlabeled_images.flip(3))).unsqueeze(0)
+        guess3 = hflip.flip(4)
+        # Concat and average
+        concat = torch.cat([guess1,guess2,guess3],0)
+        avg_guess = torch.mean(concat,0)
+        return avg_guess
+
+def do_ssltrain(
+    model,
+    data_loader,
+    optimizer,
+    criterion,
+    scheduler,
+    checkpointer,
+    checkpoint_period,
+    device,
+    arguments,
+    output_folder
+):
+    """
+    Train model and save to disk.
+
+    Parameters
+    ----------
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. DRIU, HED, UNet)
+    data_loader : :py:class:`torch.utils.data.DataLoader`
+    optimizer : :py:mod:`torch.optim`
+    criterion : :py:class:`torch.nn.modules.loss._Loss`
+        loss function
+    scheduler : :py:mod:`torch.optim`
+        learning rate scheduler
+    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
+        checkpointer
+    checkpoint_period : int
+        save a checkpoint every n epochs
+    device : str
+        device to use ``'cpu'`` or ``'cuda'``
+    arguments : dict
+        start and end epochs
+    output_folder : str
+        output path
+    """
+    logger = logging.getLogger("bob.ip.binseg.engine.ssltrainer")
+    logger.info("Start training")
+    start_epoch = arguments["epoch"]
+    max_epoch = arguments["max_epoch"]
+
+    # Log to file
+    with open(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile:
+
+        model.train().to(device)
+        # Total training timer
+        start_training_time = time.time()
+        for epoch in range(start_epoch, max_epoch):
+            scheduler.step()
+            losses = SmoothedValue(len(data_loader))
+            labeled_loss = SmoothedValue(len(data_loader))
+            unlabeled_loss = SmoothedValue(len(data_loader))
+            epoch = epoch + 1
+            arguments["epoch"] = epoch
+
+            # Epoch time
+            start_epoch_time = time.time()
+
+            for samples in tqdm(data_loader):
+                # labeled
+                images = samples[1].to(device)
+                ground_truths = samples[2].to(device)
+                unlabeled_images = samples[4].to(device)
+                # labeled outputs
+                outputs = model(images)
+                unlabeled_outputs = model(unlabeled_images)
+                # guessed unlabeled outputs
+                unlabeled_ground_truths = guess_labels(unlabeled_images, model)
+                ramp_up_factor = linear_rampup(epoch,rampup_length=16)
+
+
+                loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                losses.update(loss)
+                labeled_loss.update(ll)
+                unlabeled_loss.update(ul)
+                logger.debug("batch loss: {}".format(loss.item()))
+
+            if epoch % checkpoint_period == 0:
+                checkpointer.save("model_{:03d}".format(epoch), **arguments)
+
+            if epoch == max_epoch:
+                checkpointer.save("model_final", **arguments)
+
+            epoch_time = time.time() - start_epoch_time
+
+
+            eta_seconds = epoch_time * (max_epoch - epoch)
+            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+            outfile.write(("{epoch}, "
+                           "{avg_loss:.6f}, "
+                           "{median_loss:.6f}, "
+                           "{median_labeled_loss}, "
+                           "{median_unlabeled_loss}, "
+                           "{lr:.6f}, "
+                           "{memory:.0f}"
+                           "\n"
+                           ).format(
+                               epoch=epoch,
+                               avg_loss=losses.avg,
+                               median_loss=losses.median,
+                               median_labeled_loss = labeled_loss.median,
+                               median_unlabeled_loss = unlabeled_loss.median,
+                               lr=optimizer.param_groups[0]["lr"],
+                               memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else 0.0,
+                           )
+                       )
+            logger.info(("eta: {eta}, "
+                         "epoch: {epoch}, "
+                         "avg. loss: {avg_loss:.6f}, "
+                         "median loss: {median_loss:.6f}, "
+                         "labeled loss: {median_labeled_loss}, "
+                         "unlabeled loss: {median_unlabeled_loss}, "
+                         "lr: {lr:.6f}, "
+                         "max mem: {memory:.0f}"
+                         ).format(
+                             eta=eta_string,
+                             epoch=epoch,
+                             avg_loss=losses.avg,
+                             median_loss=losses.median,
+                             median_labeled_loss = labeled_loss.median,
+                             median_unlabeled_loss = unlabeled_loss.median,
+                             lr=optimizer.param_groups[0]["lr"],
+                             memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else 0.0
+                         )
+                     )
+
+
+        total_training_time = time.time() - start_training_time
+        total_time_str = str(datetime.timedelta(seconds=total_training_time))
+        logger.info(
+            "Total training time: {} ({:.4f} s / epoch)".format(
+                total_time_str, total_training_time / max_epoch
+            ))
+
+        log_plot_file = os.path.join(output_folder,"{}_trainlog.pdf".format(model.name))
+        logdf = pd.read_csv(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)),header=None, names=["epoch", "avg. loss", "median loss", "median labeled loss", "median unlabeled loss", "lr", "max memory"]) # names match the columns written above
+        fig = loss_curve(logdf,output_folder)
+        logger.info("saving {}".format(log_plot_file))
+        fig.savefig(log_plot_file)
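``mix_up`` and ``linear_rampup`` are pure tensor/NumPy helpers and can be smoke-tested without any dataset (assuming the package and its dependencies import cleanly):

    import torch
    from bob.ip.binseg.engine.ssltrainer import mix_up, linear_rampup

    x, y = torch.rand(4, 1, 8, 8), torch.rand(4, 1, 8, 8)
    ux, uy = torch.rand(4, 1, 8, 8), torch.rand(4, 1, 8, 8)
    xm, ym, uxm, uym = mix_up(0.75, x, y, ux, uy)
    assert xm.shape == x.shape and uxm.shape == ux.shape  # shapes preserved
    # the ramp rises linearly from 0 to 1 over `rampup_length` epochs
    assert linear_rampup(0) == 0.0 and linear_rampup(8) == 0.5 and linear_rampup(16) == 1.0

Note that ``mix_up`` is not yet called from ``do_ssltrain``; only ``guess_labels`` and ``linear_rampup`` are wired into the training loop.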
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py
index 5eeb7950ebe344dadc08fcc8b21c057a081c08ed..4ab9175802de7e1cbcbe676d7e22693b1fba868d 100644
--- a/bob/ip/binseg/modeling/losses.py
+++ b/bob/ip/binseg/modeling/losses.py
@@ -49,9 +49,9 @@ class SoftJaccardBCELogitsLoss(_Loss):
     Attributes
     ----------
     alpha : float
-        determines the weighting of SoftJaccard and BCE. Default: ``0.3``
+        determines the weighting of SoftJaccard and BCE. Default: ``0.7``
     """
-    def __init__(self, alpha=0.3, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+    def __init__(self, alpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None):
         super(SoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction)
         self.alpha = alpha
 
@@ -160,4 +160,42 @@ class HEDSoftJaccardBCELogitsLoss(_Loss):
             loss = self.alpha * bceloss + (1 - self.alpha) * (1-softjaccard)
             loss_over_all_inputs.append(loss.unsqueeze(0))
         final_loss = torch.cat(loss_over_all_inputs).mean()
-        return loss
\ No newline at end of file
+        return final_loss
+
+
+
+class MixJacLoss(_Loss):
+    """
+    Attributes
+    ----------
+    lambda_u : float
+        determines the weighting of the unlabeled loss relative to the labeled loss.
+    """
+    def __init__(self, lambda_u=0.3, jacalpha=0.7, unlabeledjacalpha=0.7, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+        super(MixJacLoss, self).__init__(size_average, reduce, reduction)
+        self.lambda_u = lambda_u
+        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
+        self.unlabeled_loss = SoftJaccardBCELogitsLoss(alpha=unlabeledjacalpha)
+
+
+    @weak_script_method
+    def forward(self, input, target, unlabeled_input, unlabeled_target, ramp_up_factor):
+        """
+        Parameters
+        ----------
+        input : :py:class:`torch.Tensor`
+        target : :py:class:`torch.Tensor`
+        unlabeled_input : :py:class:`torch.Tensor`
+        unlabeled_target : :py:class:`torch.Tensor`
+        ramp_up_factor : float
+
+        Returns
+        -------
+        tuple
+            (total loss, labeled loss, unlabeled loss)
+        """
+        ll = self.labeled_loss(input,target)
+        ul = self.unlabeled_loss(unlabeled_input, unlabeled_target)
+
+        loss = ll + self.lambda_u * ramp_up_factor * ul
+        return loss, ll, ul
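Usage sketch for the new loss (toy tensors; the guessed unlabeled targets are soft values in [0, 1], as produced by ``guess_labels``):

    import torch
    from bob.ip.binseg.modeling.losses import MixJacLoss

    criterion = MixJacLoss(lambda_u=0.3)
    logits = torch.randn(2, 1, 16, 16)
    target = torch.randint(0, 2, (2, 1, 16, 16)).float()
    ulogits = torch.randn(2, 1, 16, 16)
    utarget = torch.rand(2, 1, 16, 16)               # soft guessed labels
    loss, ll, ul = criterion(logits, target, ulogits, utarget, 0.5)
    assert torch.isclose(loss, ll + 0.3 * 0.5 * ul)  # loss = ll + lambda_u * ramp * ul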
"cpu" or "cuda:0"', + show_default=True, + required=True, + default='cpu', + cls=ResourceOption) + +@verbosity_option(cls=ResourceOption) +def ssltrain(model + ,optimizer + ,scheduler + ,output_path + ,epochs + ,pretrained_backbone + ,batch_size + ,criterion + ,dataset + ,checkpoint_period + ,device + ,**kwargs): + """ Train a model """ + + if not os.path.exists(output_path): os.makedirs(output_path) + + # PyTorch dataloader + data_loader = DataLoader( + dataset = dataset + ,batch_size = batch_size + ,shuffle= True + ,pin_memory = torch.cuda.is_available() + ) + + # Checkpointer + checkpointer = DetectronCheckpointer(model, optimizer, scheduler,save_dir = output_path, save_to_disk=True) + arguments = {} + arguments["epoch"] = 0 + extra_checkpoint_data = checkpointer.load(pretrained_backbone) + arguments.update(extra_checkpoint_data) + arguments["max_epoch"] = epochs + + # Train + logger.info("Training for {} epochs".format(arguments["max_epoch"])) + logger.info("Continuing from epoch {}".format(arguments["epoch"])) + do_ssltrain(model + , data_loader + , optimizer + , criterion + , scheduler + , checkpointer + , checkpoint_period + , device + , arguments + , output_path + ) \ No newline at end of file