diff --git a/bob/ip/binseg/configs/datasets/drivetest.py b/bob/ip/binseg/configs/datasets/drivetest.py
index 67d21cac20df4095716c2735665ee418eb671a88..230598dce92a39276e05dd4b4f842643428546b4 100644
--- a/bob/ip/binseg/configs/datasets/drivetest.py
+++ b/bob/ip/binseg/configs/datasets/drivetest.py
@@ -7,8 +7,9 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([
-    ToTensor()
+transforms = Compose([
+    CenterCrop((544,544))
+    ,ToTensor()
 ])
 
 # bob.db.dataset init
diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py
index b75a7dde55eaec9a2d94b5ba712e5154ba29b6d9..f212c4fe680969aaef27a300d08ca6f6ec35a50c 100644
--- a/bob/ip/binseg/data/binsegdataset.py
+++ b/bob/ip/binseg/data/binsegdataset.py
@@ -21,6 +21,11 @@ class BinSegDataset(Dataset):
         self.transform = transform
         self.split = split
 
+    @property
+    def mask(self):
+        # check if first sample contains a mask
+        return hasattr(self.database[0], 'mask')
+
     def __len__(self):
         """
         Returns
@@ -39,16 +44,19 @@
         Returns
         -------
         list
-            dataitem [img, gt, mask, img_name]
+            dataitem [img_name, img, gt, mask]
         """
         img = self.database[index].img.pil_image()
         gt = self.database[index].gt.pil_image()
-        mask = self.database[index].mask.pil_image() if hasattr(self.database[index], 'mask') else None
         img_name = self.database[index].img.basename
+        sample = [img, gt]
+        if self.mask:
+            mask = self.database[index].mask.pil_image()
+            sample.append(mask)
+
+        if self.transform:
+            sample = self.transform(*sample)
+
+        sample.insert(0, img_name)
 
-        if self.transform and mask:
-            img, gt, mask = self.transform(img, gt, mask)
-        else:
-            img, gt = self.transform(img, gt)
-
-        return img, gt, mask, img_name
+        return sample
diff --git a/bob/ip/binseg/data/transforms.py b/bob/ip/binseg/data/transforms.py
index 34b95ca7b8900cffd9fd6b45807c51f1dda06861..99e941b65099cdd1f85ba31374807f90b471131e 100644
--- a/bob/ip/binseg/data/transforms.py
+++ b/bob/ip/binseg/data/transforms.py
@@ -10,7 +10,7 @@ from torchvision.transforms.transforms import Compose as TorchVisionCompose
 import math
 from math import floor
 import warnings
-
+import collections
 
 _pil_interpolation_to_str = {
     Image.NEAREST: 'PIL.Image.NEAREST',
@@ -20,6 +20,7 @@ _pil_interpolation_to_str = {
     Image.HAMMING: 'PIL.Image.HAMMING',
     Image.BOX: 'PIL.Image.BOX',
 }
+Iterable = collections.abc.Iterable
 
 
 # Compose
@@ -440,4 +441,32 @@ class Distortion:
                 imgs.append(img)
             return imgs
         else:
-            return args
\ No newline at end of file
+            return args
+
+
+class Resize:
+    """Resize the input to the given size.
+
+    Attributes
+    ----------
+    size : tuple or int
+        Desired output size. If size is a sequence like
+        (h, w), output size will be matched to this. If size is an int,
+        smaller edge of the image will be matched to this number.
+        i.e., if height > width, then image will be rescaled to
+        (size * height / width, size)
+    interpolation : int
+        Desired interpolation. Default is ``PIL.Image.BILINEAR``
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, *args):
+        return [VF.resize(img, self.size, self.interpolation) for img in args]
+
+    def __repr__(self):
+        interpolate_str = _pil_interpolation_to_str[self.interpolation]
+        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py
index 3221277c11ff8545f0d3182952c329a82488ed50..bfdf9c1ba297b27e01206c0b7a77d25f0afd4d15 100644
--- a/bob/ip/binseg/engine/inferencer.py
+++ b/bob/ip/binseg/engine/inferencer.py
@@ -13,10 +13,11 @@
 from tqdm import tqdm
 
 from bob.ip.binseg.utils.metric import SmoothedValue, base_metrics
 from bob.ip.binseg.utils.plot import precision_recall_f1iso
+from bob.ip.binseg.utils.summary import summary
 
 
-def batch_metrics(predictions, ground_truths, masks, names, output_folder, logger):
+def batch_metrics(predictions, ground_truths, names, output_folder, logger):
     """
     Calculates metrics on the batch and saves it to disc
 
@@ -26,8 +27,6 @@ def batch_metrics(predictions, ground_truths, masks, names, output_folder, logge
         tensor with pixel-wise probabilities
     ground_truths : :py:class:`torch.Tensor`
         tensor with binary ground-truth
-    mask : :py:class:`torch.Tensor`
-        tensor with mask
     names : list
         list of file names
     output_folder : str
@@ -73,7 +72,6 @@ def batch_metrics(predictions, ground_truths, masks, names, output_folder, logge
         # true negatives
         tn_tensor = equals - tp_tensor
         tn_count = torch.sum(tn_tensor).item()
-        # TODO: Substract masks from True negatives?
 
         # false negatives
         fn_tensor = notequals - fp_tensor
@@ -152,9 +150,10 @@ def do_inference(
 
     # Collect overall metrics
     metrics = []
 
-    for images, ground_truths, masks, names in tqdm(data_loader):
-        images = images.to(device)
-        ground_truths = ground_truths.to(device)
+    for samples in tqdm(data_loader):
+        names = samples[0]
+        images = samples[1].to(device)
+        ground_truths = samples[2].to(device)
         with torch.no_grad():
             start_time = time.perf_counter()
@@ -166,15 +165,12 @@ def do_inference(
                 outputs = outputs[-1]
 
             probabilities = sigmoid(outputs)
-            if hasattr(masks,'dtype'):
-                masks = masks.to(device)
-                probabilities = probabilities * masks
 
             batch_time = time.perf_counter() - start_time
             times.append(batch_time)
             logger.info("Batch time: {:.5f} s".format(batch_time))
 
-            b_metrics = batch_metrics(probabilities, ground_truths, masks, names,results_subfolder, logger)
+            b_metrics = batch_metrics(probabilities, ground_truths, names, results_subfolder, logger)
             metrics.extend(b_metrics)
 
             # Create probability images
@@ -209,7 +205,7 @@ def do_inference(
 
     # Plotting
    np_avg_metrics = avg_metrics.to_numpy().T
-    fig_name = "precision_recall.pdf".format(model.name)
+    fig_name = "precision_recall.pdf"
     logger.info("saving {}".format(fig_name))
     fig = precision_recall_f1iso([np_avg_metrics[0]],[np_avg_metrics[1]], [model.name,None])
     fig_filename = os.path.join(results_subfolder, fig_name)
@@ -222,7 +218,7 @@ def do_inference(
 
     logger.info("Average batch inference time: {:.5f}s".format(average_batch_inference_time))
 
-    times_file = "Times.txt".format(model.name)
+    times_file = "Times.txt"
     logger.info("saving {}".format(times_file))
 
     with open (os.path.join(results_subfolder,times_file), "w+") as outfile:
@@ -232,4 +228,12 @@ def do_inference(
         outfile.write("Average batch inference time: {} \n".format(average_batch_inference_time))
         outfile.write("Total inference time: {} \n".format(total_inference_time))
 
+    # Save model summary
+    summary_file = 'ModelSummary.txt'
+    logger.info("saving {}".format(summary_file))
+
+    with open (os.path.join(results_subfolder,summary_file), "w+") as outfile:
+        summary(model, outfile)
+
+
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index ac24b55b5030ed5a51d2787fae1db5d5ef9af5e0..d703ca3060761cac010b9d99821c7e951566e349 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -70,12 +70,13 @@ def do_train(
         # Epoch time
         start_epoch_time = time.time()
 
-        for images, ground_truths, masks, _ in tqdm(data_loader):
+        for samples in tqdm(data_loader):
 
-            images = images.to(device)
-            ground_truths = ground_truths.to(device)
-            if hasattr(masks,'dtype'):
-                masks = masks.to(device)
+            images = samples[1].to(device)
+            ground_truths = samples[2].to(device)
+            masks = None
+            if len(samples) == 4:
+                masks = samples[-1].to(device)
 
             outputs = model(images)
 
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py
index c57bf7fe1ab851669de7b878b3f2fa3b5383637b..3e32b0904f2eb53ac145fa59b0c5537392210ef8 100644
--- a/bob/ip/binseg/modeling/losses.py
+++ b/bob/ip/binseg/modeling/losses.py
@@ -51,7 +51,7 @@ class SoftJaccardBCELogitsLoss(_Loss):
     alpha : float
         determines the weighting of SoftJaccard and BCE. Default: ``0.3``
     """
-    def __init__(self, alpha=0.1, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+    def __init__(self, alpha=0.3, size_average=None, reduce=None, reduction='mean', pos_weight=None):
         super(SoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction)
         self.alpha = alpha
 
@@ -118,4 +118,46 @@ class HEDWeightedBCELogitsLoss(_Loss):
             loss = torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, reduction=self.reduction)
             loss_over_all_inputs.append(loss.unsqueeze(0))
         final_loss = torch.cat(loss_over_all_inputs).mean()
-        return final_loss
\ No newline at end of file
+        return final_loss
+
+
+class HEDSoftJaccardBCELogitsLoss(_Loss):
+    """
+    Implements Equation 6 in [SAT17]_ for the HED network. Based on torch.nn.modules.loss.BCEWithLogitsLoss.
+
+    Attributes
+    ----------
+    alpha : float
+        determines the weighting of SoftJaccard and BCE. Default: ``0.3``
+    """
+    def __init__(self, alpha=0.3, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+        super(HEDSoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction)
+        self.alpha = alpha
+
+    @weak_script_method
+    def forward(self, inputlist, target, masks=None):
+        """
+        Parameters
+        ----------
+        inputlist : list of :py:class:`torch.Tensor`
+        target : :py:class:`torch.Tensor`
+        masks : :py:class:`torch.Tensor`, optional
+
+        Returns
+        -------
+        :py:class:`torch.Tensor`
+        """
+        eps = 1e-8
+        loss_over_all_inputs = []
+        for input in inputlist:
+            probabilities = torch.sigmoid(input)
+            intersection = (probabilities * target).sum()
+            sums = probabilities.sum() + target.sum()
+
+            softjaccard = intersection/(sums - intersection + eps)
+
+            bceloss = torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=None, reduction=self.reduction)
+            loss = self.alpha * bceloss + (1 - self.alpha) * (1-softjaccard)
+            loss_over_all_inputs.append(loss.unsqueeze(0))
+        final_loss = torch.cat(loss_over_all_inputs).mean()
+        return final_loss
\ No newline at end of file
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index 8fc667b1bdd3d4c22f926116f82c61b4841f2579..7adedbae04c2b373b57f937b7c8cd26ee694f7a8 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -297,11 +297,12 @@ def testcheckpoints(model
     required=True,
     )
 @verbosity_option(cls=ResourceOption)
-def compare(output_path_list, output_path,**kwargs):
+def compare(output_path_list, output_path, **kwargs):
     """ Compares multiple metrics files that are stored in the format mymodel/results/Metrics.csv """
     logger.debug("Output paths: {}".format(output_path_list))
     logger.info('Plotting precision vs recall curves for {}'.format(output_path_list))
     fig = plot_overview(output_path_list)
+    if not os.path.exists(output_path): os.makedirs(output_path)
     fig_filename = os.path.join(output_path, 'precision_recall_comparison.pdf')
     logger.info('saving {}'.format(fig_filename))
     fig.savefig(fig_filename)
diff --git a/bob/ip/binseg/test/test_batchmetrics.py b/bob/ip/binseg/test/test_batchmetrics.py
index 93d573e809e3ee17f2af4a61cb71a36ab49849c1..4988cab6ea8a7ffbc907f2e580ebe46b48ad9611 100644
--- a/bob/ip/binseg/test/test_batchmetrics.py
+++ b/bob/ip/binseg/test/test_batchmetrics.py
@@ -20,7 +20,6 @@ class Tester(unittest.TestCase):
         self.fn = random.randint(1, 100)
         self.predictions = torch.rand(size=(2,1,420,420))
         self.ground_truths = torch.randint(low=0, high=2, size=(2,1,420,420))
-        self.masks = None
         self.names = ['Bob','Tim']
         self.output_folder = tempfile.mkdtemp()
         self.logger = logging.getLogger(__name__)
@@ -30,7 +29,7 @@ class Tester(unittest.TestCase):
         shutil.rmtree(self.output_folder)
 
     def test_batch_metrics(self):
-        bm = batch_metrics(self.predictions, self.ground_truths, self.masks, self.names, self.output_folder, self.logger)
+        bm = batch_metrics(self.predictions, self.ground_truths, self.names, self.output_folder, self.logger)
         self.assertEqual(len(bm),2*100)
         for metric in bm:
             # check whether f1 score agree
diff --git a/bob/ip/binseg/test/test_models.py b/bob/ip/binseg/test/test_models.py
index a5f37d3fb3ef05c1241544bef53a79e1166a6b34..35e39db8c2e3641b22a8a4e563651b70c968b4dd 100644
--- a/bob/ip/binseg/test/test_models.py
+++ b/bob/ip/binseg/test/test_models.py
@@ -5,6 +5,7 @@ import torch
 import unittest
 import numpy as np
 from bob.ip.binseg.modeling.driu import build_driu
+from bob.ip.binseg.modeling.driuod import build_driuod
 from bob.ip.binseg.modeling.hed import build_hed
 from bob.ip.binseg.modeling.unet import build_unet
 from bob.ip.binseg.modeling.resunet import build_res50unet
@@ -24,6 +25,12 @@ class Tester(unittest.TestCase):
         self.assertEqual(self.hw.all(), out_hw.all())
 
+    def test_driuod(self):
+        model = build_driuod()
+        out = model(self.x)
+        out_hw = np.array(out.shape)[[2,3]]
+        self.assertEqual(self.hw.all(), out_hw.all())
+
     def test_hed(self):
         model = build_hed()
         out = model(self.x)
diff --git a/bob/ip/binseg/test/test_summary.py b/bob/ip/binseg/test/test_summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..7faabf796674db6b7914d631ba41f9160c08a623
--- /dev/null
+++ b/bob/ip/binseg/test/test_summary.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import torch
+import unittest
+import numpy as np
+from bob.ip.binseg.modeling.driu import build_driu
+from bob.ip.binseg.modeling.driuod import build_driuod
+from bob.ip.binseg.modeling.hed import build_hed
+from bob.ip.binseg.modeling.unet import build_unet
+from bob.ip.binseg.modeling.resunet import build_res50unet
+from bob.ip.binseg.utils.summary import summary
+
+class Tester(unittest.TestCase):
+    """
+    Unit tests for model summaries
+    """
+    def test_summary_driu(self):
+        model = build_driu()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+
+    def test_summary_driuod(self):
+        model = build_driuod()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+
+    def test_summary_hed(self):
+        model = build_hed()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+    def test_summary_unet(self):
+        model = build_unet()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+    def test_summary_resunet(self):
+        model = build_res50unet()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index 5a2677d5b055a9ed9581d08ee50bdc45c329d221..5e6fa29d35658c42cfb0191018ac7bc6d80123f7 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -24,7 +24,7 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
         of the system's recall coordinates.
     names : :py:class:`list`
         An iterable over the names of each of the systems along the rows of
-        ``precision`` and ``recall`` 
+        ``precision`` and ``recall``
     title : :py:class:`str`, optional
         A title for the plot. If not set, omits the title
     human_perf_bsds500 : :py:class:`bool`, optional
@@ -38,7 +38,10 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
     import matplotlib
     matplotlib.use('agg')
     import matplotlib.pyplot as plt
+    from itertools import cycle
     fig, ax1 = plt.subplots(1)
+    lines = ["-","--","-.",":"]
+    linecycler = cycle(lines)
     for p, r, n in zip(precision, recall, names):
         # Plots only from the point where recall reaches its maximum, otherwise, we
         # don't see a curve...
@@ -47,8 +50,13 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
         ri = r[i:]
         valid = (pi+ri) > 0
         f1 = 2 * (pi[valid]*ri[valid]) / (pi[valid]+ri[valid])
+        # optimal point along the curve
+        argmax = f1.argmax()
+        opi = pi[valid][argmax]
+        ori = ri[valid][argmax]
         # Plot Recall/Precision as threshold changes
-        ax1.plot(ri[pi>0], pi[pi>0], label='[F={:.3f}] {}'.format(f1.max(), n))
+        ax1.plot(ri[pi>0], pi[pi>0], next(linecycler), label='[F={:.3f}] {}'.format(f1.max(), n),)
+        ax1.plot(ori, opi, marker='o', linestyle='None', markersize=3, color='black')
         ax1.grid(linestyle='--', linewidth=1, color='gray', alpha=0.2)
         if len(names) > 1:
             plt.legend(loc='lower left', framealpha=0.5)
@@ -152,8 +160,7 @@ def plot_overview(outputfolders):
     Arguments
     ---------
     outputfolder : list
-        list containing output paths of all evaluated models (e.g. ['output/model1', 'output/model2'])
-
+        list containing output paths of all evaluated models (e.g. ['DRIVE/model1', 'DRIVE/model2'])
     Returns
     -------
     fig : matplotlib.figure.Figure
@@ -161,11 +168,24 @@ def plot_overview(outputfolders):
     precisions = []
     recalls = []
     names = []
+    params = []
     for folder in outputfolders:
+        # metrics
         metrics_path = os.path.join(folder,'results/Metrics.csv')
         pr, re = read_metricscsv(metrics_path)
         precisions.append(pr)
         recalls.append(re)
-        names.append(folder)
-    fig = precision_recall_f1iso(precisions,recalls,names)
+        modelname = folder.split('/')[-1]
+        # parameters
+        summary_path = os.path.join(folder,'results/ModelSummary.txt')
+        with open (summary_path, "r") as outfile:
+            rows = outfile.readlines()
+            lastrow = rows[-1]
+            parameter = int(lastrow.split()[1].replace(',',''))
+        name = '[P={:.2f}M] {}'.format(parameter/1e6, modelname)
+        names.append(name)
+    title = folder.split('/')[-2]
+    fig = precision_recall_f1iso(precisions,recalls,names,title)
     return fig
+
+
diff --git a/bob/ip/binseg/utils/summary.py b/bob/ip/binseg/utils/summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..127c5e66a3d7c97629ef95acebccb4f138ba43b6
--- /dev/null
+++ b/bob/ip/binseg/utils/summary.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
+import sys
+import logging
+from functools import reduce
+
+from torch.nn.modules.module import _addindent
+from bob.ip.binseg.modeling.driu import build_driu
+
+
+def summary(model, file=sys.stderr):
+    """Counts the number of parameters in each layer
+
+    Parameters
+    ----------
+    model : :py:class:`torch.nn.Module`
+
+    Returns
+    -------
+    int
+        number of parameters
+    """
+    def repr(model):
+        # We treat the extra repr like the sub-module, one item per line
+        extra_lines = []
+        extra_repr = model.extra_repr()
+        # empty string will be split into list ['']
+        if extra_repr:
+            extra_lines = extra_repr.split('\n')
+        child_lines = []
+        total_params = 0
+        for key, module in model._modules.items():
+            mod_str, num_params = repr(module)
+            mod_str = _addindent(mod_str, 2)
+            child_lines.append('(' + key + '): ' + mod_str)
+            total_params += num_params
+        lines = extra_lines + child_lines
+
+        for name, p in model._parameters.items():
+            if hasattr(p,'dtype'):
+                total_params += reduce(lambda x, y: x * y, p.shape)
+
+        main_str = model._get_name() + '('
+        if lines:
+            # simple one-liner info, which most builtin Modules will use
+            if len(extra_lines) == 1 and not child_lines:
+                main_str += extra_lines[0]
+            else:
+                main_str += '\n  ' + '\n  '.join(lines) + '\n'
+
+        main_str += ')'
+        if file is sys.stderr:
+            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
+        else:
+            main_str += ', {:,} params'.format(total_params)
+        return main_str, total_params
+
+    string, count = repr(model)
+    if file is not None:
+        print(string, file=file)
+    return count
\ No newline at end of file
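
Reviewer note, usage sketch (not part of the patch): the new summary() utility both prints a per-layer breakdown and returns the total parameter count, and plot_overview() later re-parses the last line of the saved ModelSummary.txt to recover that count for the plot legend. A minimal round-trip, assuming only modules added or touched by this diff; the output filename mirrors the one hard-coded in engine/inferencer.py:

    #!/usr/bin/env python
    # Sketch: write and reuse a model summary, as engine/inferencer.py now does.
    from bob.ip.binseg.modeling.driu import build_driu
    from bob.ip.binseg.utils.summary import summary

    model = build_driu()

    # With no file argument the summary goes to stderr (ANSI-coloured);
    # the integer return value is the total parameter count.
    n_params = summary(model)
    print('DRIU: {:,} parameters'.format(n_params))

    # Plain-text variant; the last line ends in '..., N params', which
    # plot_overview() splits to label each curve with its model size.
    with open('ModelSummary.txt', 'w+') as outfile:
        summary(model, outfile)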