From 6cd6ac6313ad089558abd37e24fb5ec1a823fc3e Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 3 Jul 2024 13:25:02 +0200 Subject: [PATCH] [libs.segmentation.models.losses] Simplify loss calculation; Specialize prediction step on lwnet --- src/mednet/libs/common/data/datamodule.py | 4 +- .../libs/segmentation/config/models/m2unet.py | 4 +- src/mednet/libs/segmentation/models/losses.py | 152 +++--------------- src/mednet/libs/segmentation/models/lwnet.py | 4 + src/mednet/libs/segmentation/models/m2unet.py | 4 +- .../segmentation/models/segmentation_model.py | 15 +- 6 files changed, 37 insertions(+), 146 deletions(-) diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py index c8ffcffe..d5110557 100644 --- a/src/mednet/libs/common/data/datamodule.py +++ b/src/mednet/libs/common/data/datamodule.py @@ -30,7 +30,7 @@ from .typing import ( logger = logging.getLogger(__name__) -def _sample_size_bytes(dataset: Sample): +def _sample_size_bytes(dataset: Dataset): """Recurse into the first sample of a dataset and figures out its total occupance in bytes. Parameters @@ -54,7 +54,7 @@ def _sample_size_bytes(dataset: Sample): """ logger.info(f"{list(t.shape)}@{t.dtype}") - return int(t.element_size() * torch.prod(torch.tensor(t.shape))) + return int(t.element_size() * t.shape.numel()) def _dict_size_bytes(d): """Return a dictionary size in bytes. diff --git a/src/mednet/libs/segmentation/config/models/m2unet.py b/src/mednet/libs/segmentation/config/models/m2unet.py index ccfae59b..b7de8624 100644 --- a/src/mednet/libs/segmentation/config/models/m2unet.py +++ b/src/mednet/libs/segmentation/config/models/m2unet.py @@ -18,7 +18,7 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from mednet.libs.segmentation.engine.adabound import AdaBound from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss -from mednet.libs.segmentation.models.m2unet import M2UNET +from mednet.libs.segmentation.models.m2unet import M2Unet lr = 0.001 alpha = 0.7 @@ -32,7 +32,7 @@ amsbound = False resize_transform = ResizeMaxSide(512) -model = M2UNET( +model = M2Unet( loss_type=SoftJaccardBCELogitsLoss, loss_arguments=dict(alpha=alpha), optimizer_type=AdaBound, diff --git a/src/mednet/libs/segmentation/models/losses.py b/src/mednet/libs/segmentation/models/losses.py index 2a344dd4..e36ca160 100644 --- a/src/mednet/libs/segmentation/models/losses.py +++ b/src/mednet/libs/segmentation/models/losses.py @@ -18,21 +18,16 @@ class WeightedBCELogitsLoss(torch.nn.Module): def __init__(self): super().__init__() - def forward( - self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- - tensor + input_ Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. - mask - Mask to be use for specifying the region of interest where to - compute the loss, with the shape ``[n, c, h, w]``. Returns ------- @@ -41,15 +36,12 @@ class WeightedBCELogitsLoss(torch.nn.Module): # calculates the proportion of negatives to the total number of pixels # available in the masked region - valid = mask > 0.5 - num_pos = target[valid].sum() - num_neg = valid.sum() - num_pos - pos_weight = num_neg / num_pos + num_pos = target.sum() return torch.nn.functional.binary_cross_entropy_with_logits( - tensor[valid], - target[valid], + input_, + target, reduction="mean", - pos_weight=pos_weight, + pos_weight=(input_.shape.numel() - num_pos) / num_pos, ) @@ -75,21 +67,16 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module): super().__init__() self.alpha = alpha - def forward( - self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- - tensor + input_ Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. - mask - Mask to be use for specifying the region of interest where to - compute the loss, with the shape ``[n, c, h, w]``. Returns ------- @@ -97,15 +84,14 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module): """ eps = 1e-8 - valid = mask > 0.5 - probabilities = torch.sigmoid(tensor[valid]) - intersection = (probabilities * target[valid]).sum() - sums = probabilities.sum() + target[valid].sum() + probabilities = torch.sigmoid(input_) + intersection = (probabilities * target).sum() + sums = probabilities.sum() + target.sum() j = intersection / (sums - intersection + eps) # this implements the support for looking just into the RoI h = torch.nn.functional.binary_cross_entropy_with_logits( - tensor[valid], target[valid], reduction="mean" + input_, target, reduction="mean" ) return (self.alpha * h) + ((1 - self.alpha) * (1 - j)) @@ -118,21 +104,16 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss): def __init__(self): super().__init__() - def forward( - self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- - tensor + input_ Value produced by the model to be evaluated, with the shape ``[L, n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. - mask - Mask to be use for specifying the region of interest where to - compute the loss, with the shape ``[n, c, h, w]``. Returns ------- @@ -140,18 +121,13 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss): """ return torch.cat( - [ - super(MultiWeightedBCELogitsLoss, self) - .forward(i, target, mask) - .unsqueeze(0) - for i in tensor - ] + [super().forward(i, target).unsqueeze(0) for i in input_] ).mean() class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): - """Implements Equation 3 in [IGLOVIKOV-2018]_ for the multi-output - networks such as HED or Little W-Net. + """Implement Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks + such as HED or Little W-Net. Parameters ---------- @@ -162,21 +138,16 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): def __init__(self, alpha: float = 0.7): super().__init__(alpha=alpha) - def forward( - self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- - tensor + input_ Value produced by the model to be evaluated, with the shape ``[L, n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. - mask - Mask to be use for specifying the region of interest where to - compute the loss, with the shape ``[n, c, h, w]``. Returns ------- @@ -184,88 +155,5 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): """ return torch.cat( - [ - super(MultiSoftJaccardBCELogitsLoss, self) - .forward(i, target, mask) - .unsqueeze(0) - for i in tensor - ] + [super().forward(i, target).unsqueeze(0) for i in input_] ).mean() - - -# class MixJacLoss(torch.nn.Module): -# """Implements Mix Jaccard Loss. - -# Parameters -# ---------- -# lambda_u -# Determines the weighting of SoftJaccard and BCE. -# jacalpha -# Determines the weighting of J and H. -# size_average -# By default, the losses are averaged over each loss element in the -# batch. Note that for some losses, there are multiple elements per -# sample. If the field `size_average` is set to ``False``, the losses -# are instead summed for each minibatch. Ignored when `reduce` is -# ``False``. Default: ``True``. -# reduce -# By default, the losses are averaged or summed over observations for -# each minibatch depending on `size_average`. When `reduce` is -# ``False``, returns a loss per batch element instead and ignores -# `size_average`. Default: ``True``. -# reduction -# Specifies the reduction to apply to the output: ``'none'`` | -# ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, -# ``'mean'``: the sum of the output will be divided by the number of -# elements in the output, ``'sum'``: the output will be summed. Note: -# `size_average` and `reduce` are in the process of being deprecated, -# and in the meantime, specifying either of those two args will -# override `reduction`. Default: ``'mean'``. -# """ - -# def __init__( -# self, -# lambda_u: int = 100, -# jacalpha=0.7, -# size_average=None, -# reduce=None, -# reduction="mean", -# ): -# super().__init__(size_average, reduce, reduction) -# self.lambda_u = lambda_u -# self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha) -# self.unlabeled_loss = torch.nn.BCEWithLogitsLoss() - -# def forward( -# self, -# tensor: torch.Tensor, -# target: torch.Tensor, -# unlabeled_tensor: torch.Tensor, -# unlabeled_target: torch.Tensor, -# ramp_up_factor: float, -# ) -> tuple: -# """Forward pass. - -# Parameters -# ---------- -# tensor -# Value produced by the model to be evaluated, with the shape ``[L, -# n, c, h, w]``. -# target -# Ground-truth information with the shape ``[n, c, h, w]``. - -# unlabeled_tensor - -# unlabeled_target - -# ramp_up_factor - -# Returns -# ------- -# list -# """ -# ll = self.labeled_loss(tensor, target) -# ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target) - -# loss = ll + self.lambda_u * ramp_up_factor * ul -# return loss, ll, ul diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index 5cd491b0..b2c0ac2f 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -353,3 +353,7 @@ class LittleWNet(SegmentationModel): x1 = self.unet1(xn) x2 = self.unet2(torch.cat([xn, x1], dim=1)) return x1, x2 + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + # prediction only returns the result of the last unet + return torch.sigmoid(self(batch[0]["image"])[1]) diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py index 87359565..5734a2a4 100644 --- a/src/mednet/libs/segmentation/models/m2unet.py +++ b/src/mednet/libs/segmentation/models/m2unet.py @@ -122,8 +122,8 @@ class M2UNetHead(torch.nn.Module): return self.decode1(decode2, x[1]) # 30, 3 -class M2UNET(SegmentationModel): - """Implementation of the M2UNET model. +class M2Unet(SegmentationModel): + """Implementation of the M2Unet model. Parameters ---------- diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py index dae5cde6..8f3c5f5e 100644 --- a/src/mednet/libs/segmentation/models/segmentation_model.py +++ b/src/mednet/libs/segmentation/models/segmentation_model.py @@ -94,21 +94,20 @@ class SegmentationModel(Model): self.normalizer = make_z_normalizer(dataloader) def training_step(self, batch, _): - images = self.augmentation_transforms(batch[0]["image"]) - ground_truths = self.augmentation_transforms(batch[0]["target"]) masks = self.augmentation_transforms(batch[0]["mask"]) + images = self.augmentation_transforms(batch[0]["image"]) * masks + ground_truths = self.augmentation_transforms(batch[0]["target"]) * masks outputs = self(images) - return self._train_loss(outputs, ground_truths, masks) + return self._train_loss(outputs, ground_truths) def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] + images = batch[0]["image"] * batch[0]["mask"] + ground_truths = batch[0]["target"] * batch[0]["mask"] outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) + return self._validation_loss(outputs, ground_truths) def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] + output = self(batch[0]["image"]) return torch.sigmoid(output) -- GitLab