diff --git a/bob/ip/binseg/configs/models/hed.py b/bob/ip/binseg/configs/models/hed.py
index 7300dee510abe50e8567585505132eb72741c7c9..0abcac62490b0ed9c788c66f3eecf950e3d0519c 100644
--- a/bob/ip/binseg/configs/models/hed.py
+++ b/bob/ip/binseg/configs/models/hed.py
@@ -14,7 +14,7 @@ Reference: [XIE-2015]_
 from torch.optim.lr_scheduler import MultiStepLR
 
 from bob.ip.binseg.models.hed import hed
-from bob.ip.binseg.models.losses import HEDSoftJaccardBCELogitsLoss
+from bob.ip.binseg.models.losses import MultiSoftJaccardBCELogitsLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
@@ -45,7 +45,7 @@ optimizer = AdaBound(
     amsbound=amsbound,
 )
 
 # criterion
-criterion = HEDSoftJaccardBCELogitsLoss(alpha=0.7)
+criterion = MultiSoftJaccardBCELogitsLoss(alpha=0.7)
 # scheduler
 scheduler = MultiStepLR(
diff --git a/bob/ip/binseg/configs/models/lwnet.py b/bob/ip/binseg/configs/models/lwnet.py
index 660f873e68b95ded151ecb5d60bd940648a70c21..1454909d82ea6cc23c87804ebbfbd35745e4dcb5 100755
--- a/bob/ip/binseg/configs/models/lwnet.py
+++ b/bob/ip/binseg/configs/models/lwnet.py
@@ -11,8 +11,8 @@ Reference: [GALDRAN-2020]_
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.optim import Adam
 
-from torch.nn import BCEWithLogitsLoss
 from bob.ip.binseg.models.lwnet import lwnet
+from bob.ip.binseg.models.losses import MultiWeightedBCELogitsLoss
 
 ##### Config #####
 max_lr = 0.01 #start
@@ -21,7 +21,7 @@ cycle = 50 #epochs for a complete scheduling cycle
 
 model = lwnet()
 
-criterion = BCEWithLogitsLoss()
+criterion = MultiWeightedBCELogitsLoss()
 
 optimizer = Adam(
     model.parameters(),
diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
index 4615b76560ee760f3af18474458161d3bb34fe65..c7150cc78776f842fcdd892365d1346dae235b21 100644
--- a/bob/ip/binseg/engine/predictor.py
+++ b/bob/ip/binseg/engine/predictor.py
@@ -161,9 +161,9 @@ def run(model, data_loader, name, device, output_folder, overlayed_folder):
         start_time = time.perf_counter()
         outputs = model(images)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
+        # necessary check for HED/Little W-Net architectures that use
+        # several outputs for loss calculation instead of just the last one
+        if isinstance(outputs, (list, tuple)):
             outputs = outputs[-1]
 
         predictions = sigmoid(outputs)
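Reviewer note — the predictor change above is easiest to see with a toy
stand-in. The sketch below (hypothetical ``ToyMultiOutput``, random tensors;
not code from this repository) mimics a HED-style model that returns several
logit maps, of which only the last, fused one is kept for prediction:

```python
import torch


class ToyMultiOutput(torch.nn.Module):
    """Stand-in for HED/Little W-Net: returns several logit maps."""

    def forward(self, x):
        side1 = 0.5 * x  # placeholder for a real side-output
        side2 = 0.8 * x  # ditto
        fused = x        # placeholder for the fused/final map
        return (side1, side2, fused)


model = ToyMultiOutput()
outputs = model(torch.randn(1, 1, 8, 8))

# same check as in predictor.py: lists and tuples are both accepted now
if isinstance(outputs, (list, tuple)):
    outputs = outputs[-1]

predictions = torch.sigmoid(outputs)  # probabilities in [0, 1]
```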
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index f642c1f5ea5ec6fd5dd8767e2cafb37a98935731..f307534e49addf654dc518d0e31db1f70f86ab54 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -357,6 +357,7 @@
                 epoch, rampup_length=rampup_length
             )
+            # note: no support for masks...
             loss, ll, ul = criterion(
                 outputs,
                 ground_truths,
@@ -395,15 +396,16 @@
                     device=device,
                     non_blocking=torch.cuda.is_available(),
                 )
-                masks = None
-                if len(samples) == 4:
-                    masks = samples[-1].to(
+                masks = (
+                    torch.ones_like(ground_truths)
+                    if len(samples) < 4
+                    else samples[3].to(
                         device=device,
                         non_blocking=torch.cuda.is_available(),
                     )
+                )
 
                 outputs = model(images)
-
                 loss = criterion(outputs, ground_truths, masks)
                 valid_losses.update(loss)
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index d3717050f9c4ce3cd67c4c334a4e2d2230311fbf..372513cbc0429756cddde17aeb0b97ddfc02bf51 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -115,9 +115,9 @@ def run(
     if device.type == "cuda":
         # asserts we do have a GPU
-        assert bool(gpu_constants()), (
-            f"Device set to '{device}', but nvidia-smi is not installed"
-        )
+        assert bool(
+            gpu_constants()
+        ), f"Device set to '{device}', but nvidia-smi is not installed"
 
     os.makedirs(output_folder, exist_ok=True)
@@ -216,15 +216,15 @@
             ground_truths = samples[2].to(
                 device=device, non_blocking=torch.cuda.is_available()
             )
-            masks = None
-            if len(samples) == 4:
-                masks = samples[-1].to(
+            masks = (
+                torch.ones_like(ground_truths)
+                if len(samples) < 4
+                else samples[3].to(
                     device=device, non_blocking=torch.cuda.is_available()
                 )
+            )
 
             outputs = model(images)
-
-            # loss evaluation and learning (backward step)
             loss = criterion(outputs, ground_truths, masks)
             optimizer.zero_grad()
             loss.backward()
@@ -255,15 +255,16 @@
                     device=device,
                     non_blocking=torch.cuda.is_available(),
                 )
-                masks = None
-                if len(samples) == 4:
-                    masks = samples[-1].to(
+                masks = (
+                    torch.ones_like(ground_truths)
+                    if len(samples) < 4
+                    else samples[3].to(
                         device=device,
                         non_blocking=torch.cuda.is_available(),
                     )
+                )
 
                 outputs = model(images)
-
                 loss = criterion(outputs, ground_truths, masks)
                 valid_losses.update(loss)
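Reviewer note — the repeated ``masks = (...)`` hunks in both trainers follow a
single idiom: synthesize a neutral all-ones mask when the sample tuple carries
no mask entry, instead of threading ``None`` through the loss functions. A
minimal sketch (illustrative shapes only, not repository code):

```python
import torch

ground_truths = torch.randint(0, 2, (4, 1, 16, 16)).float()
samples = ("name", torch.randn(4, 1, 16, 16), ground_truths)  # no 4th entry

masks = (
    torch.ones_like(ground_truths)  # neutral mask: every pixel counts
    if len(samples) < 4
    else samples[3]
)

assert masks.shape == ground_truths.shape
assert bool((masks == 1).all())
```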
diff --git a/bob/ip/binseg/models/hed.py b/bob/ip/binseg/models/hed.py
index f1bc13ed1f4826e3a1b32c15e2122c6d5acd4743..e125bbffa0cbb6fe3eb1f43c94081b97008c6125 100644
--- a/bob/ip/binseg/models/hed.py
+++ b/bob/ip/binseg/models/hed.py
@@ -79,8 +79,7 @@ class HED(torch.nn.Module):
             conv1_2_16, upsample2, upsample4, upsample8, upsample16
         )
 
-        out = [upsample2, upsample4, upsample8, upsample16, concatfuse]
-        return out
+        return (upsample2, upsample4, upsample8, upsample16, concatfuse)
 
 
 def hed(pretrained_backbone=True, progress=True):
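Reviewer note — returning a tuple here pairs with the ``Multi*`` losses
introduced below: the same single-output loss is applied to each of the five
maps, and the results are averaged. A hedged sketch with toy tensors (the
``single_output_loss`` helper is a stand-in, not repository code):

```python
import torch


def single_output_loss(input, target, mask):
    # stand-in for WeightedBCELogitsLoss.forward, restricted to the RoI
    valid = mask > 0.5
    return torch.nn.functional.binary_cross_entropy_with_logits(
        input[valid], target[valid], reduction="mean"
    )


target = torch.randint(0, 2, (2, 1, 8, 8)).float()
mask = torch.ones_like(target)
# five entries, mimicking (upsample2, upsample4, upsample8, upsample16,
# concatfuse) returned by HED.forward()
outputs = tuple(torch.randn(2, 1, 8, 8) for _ in range(5))

loss = torch.cat(
    [single_output_loss(o, target, mask).unsqueeze(0) for o in outputs]
).mean()
print(loss.item())
```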
diff --git a/bob/ip/binseg/models/losses.py b/bob/ip/binseg/models/losses.py
index 2f435c14cbaa3d33453a6c3b4c6d82a6d0588821..31b2f3f1a6c01d32b69e75c87dd658b1be4e9756 100644
--- a/bob/ip/binseg/models/losses.py
+++ b/bob/ip/binseg/models/losses.py
@@ -3,243 +3,216 @@
 import torch
 from torch.nn.modules.loss import _Loss
 
-# Conditionally decorates a method if a decorator exists in PyTorch
-# This overcomes an import error with versions of PyTorch >= 1.2, where the
-# decorator ``weak_script_method`` is not anymore available. See:
-# https://github.com/pytorch/pytorch/commit/10c4b98ade8349d841518d22f19a653a939e260c#diff-ee07db084d958260fd24b4b02d4f078d
-# from July 4th, 2019.
-try:
-    from torch._jit_internal import weak_script_method
-except ImportError:
-
-    def weak_script_method(x):
-        return x
-
 
 class WeightedBCELogitsLoss(_Loss):
-    """
-    Implements Equation 1 in [MANINIS-2016]_. Based on
-    :py:class:`torch.nn.BCEWithLogitsLoss`.
+    """Calculates the weighted binary cross-entropy loss.
 
-    Calculate sum of weighted cross entropy loss.
+    Implements Equation 1 in [MANINIS-2016]_.  The weight depends on the
+    current proportion between negatives and positives in the ground-truth
+    sample being analyzed.
     """
 
-    def __init__(
-        self,
-        weight=None,
-        size_average=None,
-        reduce=None,
-        reduction="mean",
-        pos_weight=None,
-    ):
-        super(WeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
-        self.register_buffer("weight", weight)
-        self.register_buffer("pos_weight", pos_weight)
+    def __init__(self):
+        super(WeightedBCELogitsLoss, self).__init__()
 
-    @weak_script_method
-    def forward(self, input, target, masks=None):
+    def forward(self, input, target, mask):
         """
+
         Parameters
         ----------
+
         input : :py:class:`torch.Tensor`
+            Value produced by the model to be evaluated, with the shape
+            ``[n, c, h, w]``
+
         target : :py:class:`torch.Tensor`
-        masks : :py:class:`torch.Tensor`, optional
+            Ground-truth information with the shape ``[n, c, h, w]``
+
+        mask : :py:class:`torch.Tensor`
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``
+
 
         Returns
         -------
-        :py:class:`torch.Tensor`
+
+        loss : :py:class:`torch.Tensor`
+            The average loss for all input data
+
         """
-        n, c, h, w = target.shape
-        num_pos = (
-            torch.sum(target, dim=[1, 2, 3]).float().reshape(n, 1)
-        )  # torch.Size([n, 1])
-        if hasattr(masks, "dtype"):
-            num_mask_neg = c * h * w - torch.sum(masks, dim=[1, 2, 3]).float().reshape(
-                n, 1
-            )  # torch.Size([n, 1])
-            num_neg = c * h * w - num_pos - num_mask_neg
-        else:
-            num_neg = c * h * w - num_pos
-        numposnumtotal = torch.ones_like(target) * (
-            num_pos / (num_pos + num_neg)
-        ).unsqueeze(1).unsqueeze(2)
-        numnegnumtotal = torch.ones_like(target) * (
-            num_neg / (num_pos + num_neg)
-        ).unsqueeze(1).unsqueeze(2)
-        weight = torch.where((target <= 0.5), numposnumtotal, numnegnumtotal)
 
-        loss = torch.nn.functional.binary_cross_entropy_with_logits(
-            input, target, weight=weight, reduction=self.reduction
+        # calculates the proportion of negatives to the total number of pixels
+        # available in the masked region
+        valid = mask > 0.5
+        num_pos = target[valid].sum()
+        num_neg = valid.sum() - num_pos
+        pos_weight = num_neg / num_pos
+
+        return torch.nn.functional.binary_cross_entropy_with_logits(
+            input[valid], target[valid], reduction="mean", pos_weight=pos_weight
         )
-        return loss
 
 
 class SoftJaccardBCELogitsLoss(_Loss):
-    """
-    Implements Equation 3 in [IGLOVIKOV-2018]_. Based on
-    ``torch.nn.BCEWithLogitsLoss``.
+    r"""
+    Implements the generalized loss function of Equation (3) in
+    [IGLOVIKOV-2018]_, with J being the Jaccard distance, and H, the Binary
+    Cross-Entropy Loss:
+
+    .. math::
+
+       L = \alpha H + (1-\alpha)(1-J)
+
+
+    Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`.
+
 
     Attributes
     ----------
+
     alpha : float
-        determines the weighting of SoftJaccard and BCE. Default: ``0.7``
+        determines the weighting of J and H. Default: ``0.7``
+
     """
 
-    def __init__(
-        self,
-        alpha=0.7,
-        size_average=None,
-        reduce=None,
-        reduction="mean",
-        pos_weight=None,
-    ):
-        super(SoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction)
+    def __init__(self, alpha=0.7):
+        super(SoftJaccardBCELogitsLoss, self).__init__()
         self.alpha = alpha
 
-    @weak_script_method
-    def forward(self, input, target, masks=None):
+    def forward(self, input, target, mask):
         """
+
         Parameters
         ----------
+
         input : :py:class:`torch.Tensor`
+            Value produced by the model to be evaluated, with the shape
+            ``[n, c, h, w]``
+
         target : :py:class:`torch.Tensor`
-        masks : :py:class:`torch.Tensor`, optional
+            Ground-truth information with the shape ``[n, c, h, w]``
+
+        mask : :py:class:`torch.Tensor`
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``
+
 
         Returns
         -------
-        :py:class:`torch.Tensor`
-        """
-        eps = 1e-8
-        probabilities = torch.sigmoid(input)
-        intersection = (probabilities * target).sum()
-        sums = probabilities.sum() + target.sum()
-        softjaccard = intersection / (sums - intersection + eps)
 
-        bceloss = torch.nn.functional.binary_cross_entropy_with_logits(
-            input, target, weight=None, reduction=self.reduction
+        loss : :py:class:`torch.Tensor`
+            Loss, in a single entry
+
+        """
+
+        eps = 1e-8
+        valid = mask > 0.5
+        probabilities = torch.sigmoid(input[valid])
+        intersection = (probabilities * target[valid]).sum()
+        sums = probabilities.sum() + target[valid].sum()
+        J = intersection / (sums - intersection + eps)
+
+        # this implements the support for looking just into the RoI
+        H = torch.nn.functional.binary_cross_entropy_with_logits(
+            input[valid], target[valid], reduction="mean"
         )
-        loss = self.alpha * bceloss + (1 - self.alpha) * (1 - softjaccard)
-        return loss
+        return (self.alpha * H) + ((1 - self.alpha) * (1 - J))
 
 
-class HEDWeightedBCELogitsLoss(_Loss):
+class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
     """
-    Implements Equation 2 in [HE-2015]_. Based on
-    ``torch.nn.modules.loss.BCEWithLogitsLoss``.
-
-    Calculate sum of weighted cross entropy loss.
+    Weighted Binary Cross-Entropy Loss for multi-layered inputs (e.g. for
+    Holistically-Nested Edge Detection in [XIE-2015]_).
     """
 
-    def __init__(
-        self,
-        weight=None,
-        size_average=None,
-        reduce=None,
-        reduction="mean",
-        pos_weight=None,
-    ):
-        super(HEDWeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
-        self.register_buffer("weight", weight)
-        self.register_buffer("pos_weight", pos_weight)
+    def __init__(self):
+        super(MultiWeightedBCELogitsLoss, self).__init__()
 
-    @weak_script_method
-    def forward(self, inputlist, target, masks=None):
+    def forward(self, input, target, mask):
         """
         Parameters
         ----------
-        inputlist : list of :py:class:`torch.Tensor`
-            HED uses multiple side-output feature maps for the loss calculation
+
+        input : iterable over :py:class:`torch.Tensor`
+            Value produced by the model to be evaluated, with the shape
+            ``[L, n, c, h, w]``
+
         target : :py:class:`torch.Tensor`
-        masks : :py:class:`torch.Tensor`, optional
+            Ground-truth information with the shape ``[n, c, h, w]``
+
+        mask : :py:class:`torch.Tensor`
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``
+
 
         Returns
         -------
-        :py:class:`torch.Tensor`
+
+        loss : :py:class:`torch.Tensor`
+            The average loss for all input data
+
         """
-        loss_over_all_inputs = []
-        for input in inputlist:
-            n, c, h, w = target.shape
-            num_pos = (
-                torch.sum(target, dim=[1, 2, 3]).float().reshape(n, 1)
-            )  # torch.Size([n, 1])
-            if hasattr(masks, "dtype"):
-                num_mask_neg = c * h * w - torch.sum(
-                    masks, dim=[1, 2, 3]
-                ).float().reshape(
-                    n, 1
-                )  # torch.Size([n, 1])
-                num_neg = c * h * w - num_pos - num_mask_neg
-            else:
-                num_neg = c * h * w - num_pos  # torch.Size([n, 1])
-            numposnumtotal = torch.ones_like(target) * (
-                num_pos / (num_pos + num_neg)
-            ).unsqueeze(1).unsqueeze(2)
-            numnegnumtotal = torch.ones_like(target) * (
-                num_neg / (num_pos + num_neg)
-            ).unsqueeze(1).unsqueeze(2)
-            weight = torch.where((target <= 0.5), numposnumtotal, numnegnumtotal)
-            loss = torch.nn.functional.binary_cross_entropy_with_logits(
-                input, target, weight=weight, reduction=self.reduction
-            )
-            loss_over_all_inputs.append(loss.unsqueeze(0))
-        final_loss = torch.cat(loss_over_all_inputs).mean()
-        return final_loss
 
+        return torch.cat(
+            [
+                super(MultiWeightedBCELogitsLoss, self).forward(
+                    i, target, mask
+                ).unsqueeze(0)
+                for i in input
+            ]
+        ).mean()
 
-class HEDSoftJaccardBCELogitsLoss(_Loss):
+
+class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
     """
-    Implements Equation 3 in [IGLOVIKOV-2018]_ for the hed network. Based on
-    :py:class:`torch.nn.BCEWithLogitsLoss`.
+    Implements Equation 3 in [IGLOVIKOV-2018]_ for multi-output networks
+    such as HED or Little W-Net.
+
 
     Attributes
     ----------
+
     alpha : float
-        determines the weighting of SoftJaccard and BCE. Default: ``0.3``
+        determines the weighting of SoftJaccard and BCE. Default: ``0.7``
+
     """
 
-    def __init__(
-        self,
-        alpha=0.3,
-        size_average=None,
-        reduce=None,
-        reduction="mean",
-        pos_weight=None,
-    ):
-        super(HEDSoftJaccardBCELogitsLoss, self).__init__(
-            size_average, reduce, reduction
-        )
-        self.alpha = alpha
+    def __init__(self, alpha=0.7):
+        super(MultiSoftJaccardBCELogitsLoss, self).__init__(alpha=alpha)
 
-    @weak_script_method
-    def forward(self, inputlist, target, masks=None):
+    def forward(self, input, target, mask):
         """
         Parameters
         ----------
-        input : :py:class:`torch.Tensor`
+
+        input : iterable over :py:class:`torch.Tensor`
+            Value produced by the model to be evaluated, with the shape
+            ``[L, n, c, h, w]``
+
         target : :py:class:`torch.Tensor`
-        masks : :py:class:`torch.Tensor`, optional
+            Ground-truth information with the shape ``[n, c, h, w]``
+
+        mask : :py:class:`torch.Tensor`
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``
+
 
         Returns
         -------
-        :py:class:`torch.Tensor`
-        """
-        eps = 1e-8
-        loss_over_all_inputs = []
-        for input in inputlist:
-            probabilities = torch.sigmoid(input)
-            intersection = (probabilities * target).sum()
-            sums = probabilities.sum() + target.sum()
-            softjaccard = intersection / (sums - intersection + eps)
 
-            bceloss = torch.nn.functional.binary_cross_entropy_with_logits(
-                input, target, weight=None, reduction=self.reduction
-            )
-            loss = self.alpha * bceloss + (1 - self.alpha) * (1 - softjaccard)
-            loss_over_all_inputs.append(loss.unsqueeze(0))
-        final_loss = torch.cat(loss_over_all_inputs).mean()
-        return final_loss
+        loss : :py:class:`torch.Tensor`
+            The average loss for all input data
+
+        """
+
+        return torch.cat(
+            [
+                super(MultiSoftJaccardBCELogitsLoss, self).forward(
+                    i, target, mask
+                ).unsqueeze(0)
+                for i in input
+            ]
+        ).mean()
 
 
 class MixJacLoss(_Loss):
@@ -267,8 +240,9 @@ class MixJacLoss(_Loss):
         self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
         self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
 
-    @weak_script_method
-    def forward(self, input, target, unlabeled_input, unlabeled_traget, ramp_up_factor):
+    def forward(
+        self, input, target, unlabeled_input, unlabeled_target, ramp_up_factor
+    ):
         """
         Parameters
         ----------
@@ -276,7 +250,7 @@
         input : :py:class:`torch.Tensor`
         target : :py:class:`torch.Tensor`
         unlabeled_input : :py:class:`torch.Tensor`
-        unlabeled_traget : :py:class:`torch.Tensor`
+        unlabeled_target : :py:class:`torch.Tensor`
         ramp_up_factor : float
 
         Returns
@@ -286,7 +260,7 @@
         """
         ll = self.labeled_loss(input, target)
-        ul = self.unlabeled_loss(unlabeled_input, unlabeled_traget)
+        ul = self.unlabeled_loss(unlabeled_input, unlabeled_target)
 
         loss = ll + self.lambda_u * ramp_up_factor * ul
         return loss, ll, ul
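Reviewer note — a quick sanity check of the refactored Jaccard/BCE blend
(random data, not a test from this repository): at ``alpha=1`` the combined
loss reduces to plain BCE (H), and at ``alpha=0`` to the soft Jaccard
complement (1 - J), matching the formula in the SoftJaccardBCELogitsLoss
docstring above:

```python
import torch

eps = 1e-8
input = torch.randn(10)  # masked logits, already flattened
target = torch.randint(0, 2, (10,)).float()

probabilities = torch.sigmoid(input)
intersection = (probabilities * target).sum()
sums = probabilities.sum() + target.sum()
J = intersection / (sums - intersection + eps)
H = torch.nn.functional.binary_cross_entropy_with_logits(
    input, target, reduction="mean"
)

for alpha in (0.0, 0.7, 1.0):
    L = (alpha * H) + ((1 - alpha) * (1 - J))
    print(f"alpha={alpha}: L={L.item():.4f}")
```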
""" - def __init__( - self, - weight=None, - size_average=None, - reduce=None, - reduction="mean", - pos_weight=None, - ): - super(HEDWeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction) - self.register_buffer("weight", weight) - self.register_buffer("pos_weight", pos_weight) + def __init__(self): + super(MultiWeightedBCELogitsLoss, self).__init__() - @weak_script_method - def forward(self, inputlist, target, masks=None): + def forward(self, input, target, mask): """ Parameters ---------- - inputlist : list of :py:class:`torch.Tensor` - HED uses multiple side-output feature maps for the loss calculation + + input : iterable over :py:class:`torch.Tensor` + Value produced by the model to be evaluated, with the shape ``[L, + n, c, h, w]`` + target : :py:class:`torch.Tensor` - masks : :py:class:`torch.Tensor`, optional + Ground-truth information with the shape ``[n, c, h, w]`` + + mask : :py:class:`torch.Tensor` + Mask to be use for specifying the region of interest where to + compute the loss, with the shape ``[n, c, h, w]`` + + Returns ------- - :py:class:`torch.Tensor` + + loss : torch.Tensor + The average loss for all input data + """ - loss_over_all_inputs = [] - for input in inputlist: - n, c, h, w = target.shape - num_pos = ( - torch.sum(target, dim=[1, 2, 3]).float().reshape(n, 1) - ) # torch.Size([n, 1]) - if hasattr(masks, "dtype"): - num_mask_neg = c * h * w - torch.sum( - masks, dim=[1, 2, 3] - ).float().reshape( - n, 1 - ) # torch.Size([n, 1]) - num_neg = c * h * w - num_pos - num_mask_neg - else: - num_neg = c * h * w - num_pos # torch.Size([n, 1]) - numposnumtotal = torch.ones_like(target) * ( - num_pos / (num_pos + num_neg) - ).unsqueeze(1).unsqueeze(2) - numnegnumtotal = torch.ones_like(target) * ( - num_neg / (num_pos + num_neg) - ).unsqueeze(1).unsqueeze(2) - weight = torch.where((target <= 0.5), numposnumtotal, numnegnumtotal) - loss = torch.nn.functional.binary_cross_entropy_with_logits( - input, target, weight=weight, reduction=self.reduction - ) - loss_over_all_inputs.append(loss.unsqueeze(0)) - final_loss = torch.cat(loss_over_all_inputs).mean() - return final_loss - - -class HEDSoftJaccardBCELogitsLoss(_Loss): + + return torch.cat( + [ + super(MultiWeightedBCELogitsLoss, self).forward(i, target, + mask).unsqueeze(0) + for i in input + ] + ).mean() + + +class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): """ - Implements Equation 3 in [IGLOVIKOV-2018]_ for the hed network. Based on - :py:class:`torch.nn.BCEWithLogitsLoss`. + Implements Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks + such as HED or Little W-Net. + Attributes ---------- + alpha : float determines the weighting of SoftJaccard and BCE. 
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index afa00f950b674e8c579fe01dabb9d9520c883946..ac9357a39f3ba73f166f9d3bc8f087ed1664d65c 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -33,7 +33,9 @@ def setup_pytorch_device(name):
     ----------
 
     name : str
-        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on)
+        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
+        set a specific cuda device such as ``cuda:1``, we also set
+        ``CUDA_VISIBLE_DEVICES`` so PyTorch can find and use it.
 
 
     Returns
@@ -44,9 +46,10 @@
 
     """
 
-    if name.startswith("cuda"):
+    if name.startswith("cuda:"):
         # In case one has multiple devices, we must first set the one
         # we would like to use so pytorch can find it.
+        logger.info(f"User set device to '{name}' - trying to force device...")
        os.environ['CUDA_VISIBLE_DEVICES'] = name.split(":",1)[1]
         if not torch.cuda.is_available():
             raise RuntimeError(f"CUDA is not currently available, but " \
@@ -54,7 +57,11 @@
         # Let pytorch auto-select from environment variable
         return torch.device("cuda")
 
-    #cpu
+    elif name.startswith("cuda"):  # use default device
+        logger.info(f"User set device to '{name}' - using default CUDA device")
+        assert os.environ.get('CUDA_VISIBLE_DEVICES') is not None
+
+    # cuda or cpu
     return torch.device(name)
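Reviewer note — the device-selection logic above now distinguishes ``cuda:N``
(pin the device via ``CUDA_VISIBLE_DEVICES``) from bare ``cuda`` (defer to
whatever is already visible). A torch-free sketch of the same decision tree
(hypothetical ``pick()`` helper; runs without a GPU):

```python
import os


def pick(name):
    if name.startswith("cuda:"):
        # export the ordinal so a later torch.device("cuda") finds it
        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
        return "cuda"
    return name  # bare "cuda" (default device) or "cpu"


print(pick("cuda:1"), os.environ.get("CUDA_VISIBLE_DEVICES"))  # cuda 1
print(pick("cpu"))  # cpu
```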
diff --git a/doc/references.rst b/doc/references.rst
index a4a979be0ab424f4432b1857268225f2a0c477a5..df7bce250328e1b985b14051c5b20b4de74034a3 100644
--- a/doc/references.rst
+++ b/doc/references.rst
@@ -72,7 +72,7 @@
    Workshops (CVPRW), Salt Lake City, UT, 2018, pp. 228-2284.
    https://doi.org/10.1109/CVPRW.2018.00042
 
-.. [HE-2015] *S. Xie and Z. Tu*, **Holistically-Nested Edge Detection**, 2015
+.. [XIE-2015] *S. Xie and Z. Tu*, **Holistically-Nested Edge Detection**, 2015
    IEEE International Conference on Computer Vision (ICCV), Santiago, 2015,
    pp. 1395-1403. https://doi.org/10.1109/ICCV.2015.164
 
@@ -90,10 +90,6 @@
    MobileNetV2**, 2018. Last accessed: 21.03.2020.
    https://github.com/tonylins/pytorch-mobilenet-v2
 
-.. [XIE-2015] *S. Xie and Z. Tu*, **Holistically-Nested Edge Detection**, 2015
-   IEEE International Conference on Computer Vision (ICCV), Santiago, 2015, pp.
-   1395-1403. https://doi.org/10.1109/ICCV.2015.164
-
 .. [RONNEBERGER-2015] *O. Ronneberger, P. Fischer, T. Brox*, **U-Net:
    Convolutional Networks for Biomedical Image Segmentation**, 2015.
    https://arxiv.org/abs/1505.04597
diff --git a/setup.py b/setup.py
index 475b077666541fb9e3a4b18ade7ab6a46e04b084..b55da6c7a478c707678de611cf5e2398e8c00eb8 100644
--- a/setup.py
+++ b/setup.py
@@ -55,7 +55,7 @@
             "m2unet-ssl = bob.ip.binseg.configs.models.m2unet_ssl",
             "unet = bob.ip.binseg.configs.models.unet",
             "resunet = bob.ip.binseg.configs.models.resunet",
-            "lwnet = bob.ip.binseg.configs.models.lwunet",
+            "lwnet = bob.ip.binseg.configs.models.lwnet",
             # example datasets
             "csv-dataset-example = bob.ip.binseg.configs.datasets.csv",