diff --git a/bob/ip/binseg/configs/models/driubn.py b/bob/ip/binseg/configs/models/driubn.py new file mode 100644 index 0000000000000000000000000000000000000000..0b95501d0a61053fd74b64976f6a761255944ece --- /dev/null +++ b/bob/ip/binseg/configs/models/driubn.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from torch.optim.lr_scheduler import MultiStepLR +from bob.ip.binseg.modeling.driubn import build_driu +import torch.optim as optim +from torch.nn import BCEWithLogitsLoss +from bob.ip.binseg.utils.model_zoo import modelurls +from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss +from bob.ip.binseg.engine.adabound import AdaBound + +##### Config ##### +lr = 0.001 +betas = (0.9, 0.999) +eps = 1e-08 +weight_decay = 0 +final_lr = 0.1 +gamma = 1e-3 +eps = 1e-8 +amsbound = False + +scheduler_milestones = [900] +scheduler_gamma = 0.1 + +# model +model = build_driu() + +# pretrained backbone +pretrained_backbone = modelurls['vgg16_bn'] + +# optimizer +optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, + eps=eps, weight_decay=weight_decay, amsbound=amsbound) +# criterion +criterion = SoftJaccardBCELogitsLoss(alpha=0.7) + +# scheduler +scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) diff --git a/bob/ip/binseg/configs/models/driubnssl.py b/bob/ip/binseg/configs/models/driubnssl.py new file mode 100644 index 0000000000000000000000000000000000000000..52b3a2b35272b99d5f47bae8f23d47da15990135 --- /dev/null +++ b/bob/ip/binseg/configs/models/driubnssl.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from torch.optim.lr_scheduler import MultiStepLR +from bob.ip.binseg.modeling.driubn import build_driu +import torch.optim as optim +from torch.nn import BCEWithLogitsLoss +from bob.ip.binseg.utils.model_zoo import modelurls +from bob.ip.binseg.modeling.losses import MixJacLoss +from bob.ip.binseg.engine.adabound import AdaBound + +##### Config ##### +lr = 0.001 +betas = (0.9, 0.999) +eps = 1e-08 +weight_decay = 0 +final_lr = 0.1 +gamma = 1e-3 +eps = 1e-8 +amsbound = False + +scheduler_milestones = [900] +scheduler_gamma = 0.1 + +# model +model = build_driu() + +# pretrained backbone +pretrained_backbone = modelurls['vgg16_bn'] + +# optimizer +optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, + eps=eps, weight_decay=weight_decay, amsbound=amsbound) + +# criterion +criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7) + +# scheduler +scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma) diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py index a1017ed5d192980b431a846c18567a40a4159906..ccff70d4019c6856334e982fb861d315c2d60196 100644 --- a/bob/ip/binseg/engine/inferencer.py +++ b/bob/ip/binseg/engine/inferencer.py @@ -192,10 +192,13 @@ def do_inference( logger.info("Saving average over all input images: {}".format(metrics_file)) avg_metrics = df_metrics.groupby('threshold').mean() + std_metrics = df_metrics.groupby('threshold').std() avg_metrics["f1_score"] = (2* avg_metrics["precision"]*avg_metrics["recall"])/ \ (avg_metrics["precision"]+avg_metrics["recall"]) + avg_metrics["std_f1"] = std_metrics["f1_score"] + avg_metrics.to_csv(metrics_path) maxf1 = avg_metrics['f1_score'].max() optimal_f1_threshold = avg_metrics['f1_score'].idxmax() diff --git a/bob/ip/binseg/modeling/driubn.py b/bob/ip/binseg/modeling/driubn.py new file mode 100644 index 0000000000000000000000000000000000000000..4c70dc4bc13c1a1f4101fa16fba64a4ff9d66e70 --- /dev/null +++ b/bob/ip/binseg/modeling/driubn.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import torch +from torch import nn +from collections import OrderedDict +from bob.ip.binseg.modeling.backbones.vgg import vgg16_bn +from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform,convtrans_with_kaiming_uniform, UpsampleCropBlock + +class ConcatFuseBlock(nn.Module): + """ + Takes in four feature maps with 16 channels each, concatenates them + and applies a 1x1 convolution with 1 output channel. + """ + def __init__(self): + super().__init__() + self.conv = nn.Sequential( + conv_with_kaiming_uniform(4*16,1,1,1,0) + ,nn.BatchNorm2d(1) + ) + def forward(self,x1,x2,x3,x4): + x_cat = torch.cat([x1,x2,x3,x4],dim=1) + x = self.conv(x_cat) + return x + +class DRIU(nn.Module): + """ + DRIU head module + Based on paper by `Maninis et al. (2016)`_ + Parameters + ---------- + in_channels_list : list + number of channels for each feature map that is returned from backbone + """ + def __init__(self, in_channels_list=None): + super(DRIU, self).__init__() + in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8 = in_channels_list + + self.conv1_2_16 = nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1) + # Upsample layers + self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0) + self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0) + self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0) + + # Concat and Fuse + self.concatfuse = ConcatFuseBlock() + + def forward(self,x): + """ + Parameters + ---------- + x : list + list of tensors as returned from the backbone network. + First element: height and width of input image. + Remaining elements: feature maps for each feature level. + + Returns + ------- + :py:class:`torch.Tensor` + """ + hw = x[0] + conv1_2_16 = self.conv1_2_16(x[1]) # conv1_2_16 + upsample2 = self.upsample2(x[2], hw) # side-multi2-up + upsample4 = self.upsample4(x[3], hw) # side-multi3-up + upsample8 = self.upsample8(x[4], hw) # side-multi4-up + out = self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8) + return out + +def build_driu(): + """ + Adds backbone and head together + + Returns + ------- + :py:class:torch.nn.Module + """ + backbone = vgg16_bn(pretrained=False, return_features = [3, 8, 14, 22]) + driu_head = DRIU([64, 128, 256, 512]) + + model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", driu_head)])) + model.name = "DRIUBN" + return model \ No newline at end of file diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py index 7ddfe2afb6ed569cb03b770b5408a44b6b66d14d..ee8118afaa3dab9385e8e440a8ad37519b29c29a 100644 --- a/bob/ip/binseg/script/binseg.py +++ b/bob/ip/binseg/script/binseg.py @@ -113,6 +113,14 @@ def binseg(): required=True, default='cpu', cls=ResourceOption) +@click.option( + '--seed', + '-s', + help='torch random seed', + show_default=True, + required=False, + default=42, + cls=ResourceOption) @verbosity_option(cls=ResourceOption) def train(model @@ -126,11 +134,12 @@ def train(model ,dataset ,checkpoint_period ,device + ,seed ,**kwargs): """ Train a model """ if not os.path.exists(output_path): os.makedirs(output_path) - + torch.manual_seed(seed) # PyTorch dataloader data_loader = DataLoader( dataset = dataset @@ -481,6 +490,14 @@ def visualize(dataset, output_path, **kwargs): required=True, default='900', cls=ResourceOption) +@click.option( + '--seed', + '-s', + help='torch random seed', + show_default=True, + required=False, + default=42, + cls=ResourceOption) @verbosity_option(cls=ResourceOption) def ssltrain(model @@ -495,11 +512,12 @@ def ssltrain(model ,checkpoint_period ,device ,rampup + ,seed ,**kwargs): """ Train a model """ if not os.path.exists(output_path): os.makedirs(output_path) - + torch.manual_seed(seed) # PyTorch dataloader data_loader = DataLoader( dataset = dataset diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py index e8eab7da9c94b06ed82f353ef3ba626e2879b9d7..594b97f06f2391c2524104f74d751b28a6ffa7db 100644 --- a/bob/ip/binseg/utils/plot.py +++ b/bob/ip/binseg/utils/plot.py @@ -190,7 +190,7 @@ def plot_overview(outputfolders,title): rows = outfile.readlines() lastrow = rows[-1] parameter = int(lastrow.split()[1].replace(',','')) - name = '[P={:.2f}M] {} {}'.format(parameter/100**3, modelname, datasetname) + name = '[P={:.2f}M] {} {}'.format(parameter/100**3, modelname, "") names.append(name) #title = folder.split('/')[-4] fig = precision_recall_f1iso(precisions,recalls,names,title) diff --git a/doc/benchmarkresults.rst b/doc/benchmarkresults.rst index fb30961cb203ea3f2fe1f83f96c34d99e26d2403..ad608b2e3d1e8c598057c8ff7751ff86fb9d8758 100644 --- a/doc/benchmarkresults.rst +++ b/doc/benchmarkresults.rst @@ -6,20 +6,62 @@ Benchmark Results ================== -Dice Scores +F1 Scores =========== * Benchmark results for models: DRIU, HED, M2UNet and U-Net. -* Train-Test split as indicated in :ref:`bob.ip.binseg.datasets` - -+--------+----------+--------+---------+--------+--------+ -| | CHASEDB1 | DRIVE | HRF1168 | IOSTAR | STARE | -+--------+----------+--------+---------+--------+--------+ -| DRIU | 0.8114 | 0.8226 | 0.7865 | 0.8273 | 0.8286 | -+--------+----------+--------+---------+--------+--------+ -| HED | 0.8111 | 0.8192 | 0.7868 | 0.8275 | 0.8250 | -+--------+----------+--------+---------+--------+--------+ -| M2UNet | 0.8035 | 0.8051 | 0.7838 | 0.8194 | 0.8174 | -+--------+----------+--------+---------+--------+--------+ -| UNet | 0.8136 | 0.8237 | 0.7941 | 0.8203 | 0.8306 | -+--------+----------+--------+---------+--------+--------+ +* Models are trained and tested on the same dataset using the train-test split as indicated in :ref:`bob.ip.binseg.datasets` +* standard-deviations across all test images are indicated in brakets + ++----------+-----------------+-----------------+-----------------+-----------------+-----------------+ +| F1 (std) | CHASEDB1 | DRIVE | HRF1168 | IOSTAR | STARE | ++----------+-----------------+-----------------+-----------------+-----------------+-----------------+ +| DRIU | 0.8114 (0.0206) | 0.8226 (0.0142) | 0.7865 (0.0545) | 0.8273 (0.0199) | 0.8286 (0.0368) | ++----------+-----------------+-----------------+-----------------+-----------------+-----------------+ +| HED | 0.8111 (0.0214) | 0.8192 (0.0136) | 0.7868 (0.0576) | 0.8275 (0.0201) | 0.8250 (0.0375) | ++----------+-----------------+-----------------+-----------------+-----------------+-----------------+ +| M2UNet | 0.8035 (0.0195) | 0.8051 (0.0141) | 0.7838 (0.0572) | 0.8194 (0.0201) | 0.8174 (0.0409) | ++----------+-----------------+-----------------+-----------------+-----------------+-----------------+ +| UNet | 0.8136 (0.0209) | 0.8237 (0.0145) | 0.7914 (0.0516) | 0.8203 (0.0190) | 0.8306 (0.0421) | ++----------+-----------------+-----------------+-----------------+-----------------+-----------------+ + + +.. figure:: img/pr_CHASEDB1.png + :scale: 30 % + :align: center + :alt: model comparisons + + CHASE_DB1: Precision vs Recall curve, F1 scores and + number of parameter of each model. + +.. figure:: img/pr_DRIVE.png + :scale: 30 % + :align: center + :alt: model comparisons + + DRIVE: Precision vs Recall curve, F1 scores and + number of parameter of each model. + +.. figure:: img/pr_HRF.png + :scale: 30 % + :align: center + :alt: model comparisons + + HRF: Precision vs Recall curve, F1 scores and + number of parameter of each model. + +.. figure:: img/pr_IOSTARVESSEL.png + :scale: 30 % + :align: center + :alt: model comparisons + + IOSTAR: Precision vs Recall curve, F1 scores and + number of parameter of each model. + +.. figure:: img/pr_STARE.png + :scale: 30 % + :align: center + :alt: model comparisons + + STARE: Precision vs Recall curve, F1 scores and + number of parameter of each model. diff --git a/doc/covdresults.rst b/doc/covdresults.rst index 05c48c75481c55b76f3a505c153ae38a15641bea..94ed0ab7c3871e8e985abe01896267f9a9511acc 100644 --- a/doc/covdresults.rst +++ b/doc/covdresults.rst @@ -6,7 +6,7 @@ COVD- and COVD-SLL Results ========================== -Dice Scores +F1 Scores =========== +-------------------+---------------+---------+ diff --git a/doc/img/pr_CHASEDB1.png b/doc/img/pr_CHASEDB1.png new file mode 100644 index 0000000000000000000000000000000000000000..7fe74f4e6178af9abc8fdda8c3d1142c992110c8 Binary files /dev/null and b/doc/img/pr_CHASEDB1.png differ diff --git a/doc/img/pr_DRIVE.png b/doc/img/pr_DRIVE.png new file mode 100644 index 0000000000000000000000000000000000000000..fc9e739e31c47bf319981dc6a561e335acfb261b Binary files /dev/null and b/doc/img/pr_DRIVE.png differ diff --git a/doc/img/pr_HRF.png b/doc/img/pr_HRF.png new file mode 100644 index 0000000000000000000000000000000000000000..ac6f870ece6c4fe9d439ba5c0d5e3914eea3bcbb Binary files /dev/null and b/doc/img/pr_HRF.png differ diff --git a/doc/img/pr_IOSTARVESSEL.png b/doc/img/pr_IOSTARVESSEL.png new file mode 100644 index 0000000000000000000000000000000000000000..97ed5c7a6b8f0d7ab6c0786db55588d7b163e9bb Binary files /dev/null and b/doc/img/pr_IOSTARVESSEL.png differ diff --git a/doc/img/pr_STARE.png b/doc/img/pr_STARE.png new file mode 100644 index 0000000000000000000000000000000000000000..14603d2d3782292e66c813685fc61bca60953976 Binary files /dev/null and b/doc/img/pr_STARE.png differ diff --git a/setup.py b/setup.py index 44a23e1f588c69e2d631a9666f54957a21b8f175..859a8ded5789b9c9e000e931151144083a10669a 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,9 @@ setup( #bob train configurations 'bob.ip.binseg.config': [ 'DRIU = bob.ip.binseg.configs.models.driu', + 'DRIUBN = bob.ip.binseg.configs.models.driubn', 'DRIUSSL = bob.ip.binseg.configs.models.driussl', + 'DRIUBNSSL = bob.ip.binseg.configs.models.driubnssl', 'DRIUOD = bob.ip.binseg.configs.models.driuod', 'HED = bob.ip.binseg.configs.models.hed', 'M2UNet = bob.ip.binseg.configs.models.m2unet',