Skip to content
Snippets Groups Projects
Commit 6cd6ac63 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[libs.segmentation.models.losses] Simplify loss calculation; Specialize prediction step on lwnet

parent c5e36064
No related branches found
No related tags found
1 merge request!46Create common library
Pipeline #89334 failed
...@@ -30,7 +30,7 @@ from .typing import ( ...@@ -30,7 +30,7 @@ from .typing import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _sample_size_bytes(dataset: Sample): def _sample_size_bytes(dataset: Dataset):
"""Recurse into the first sample of a dataset and figures out its total occupance in bytes. """Recurse into the first sample of a dataset and figures out its total occupance in bytes.
Parameters Parameters
...@@ -54,7 +54,7 @@ def _sample_size_bytes(dataset: Sample): ...@@ -54,7 +54,7 @@ def _sample_size_bytes(dataset: Sample):
""" """
logger.info(f"{list(t.shape)}@{t.dtype}") logger.info(f"{list(t.shape)}@{t.dtype}")
return int(t.element_size() * torch.prod(torch.tensor(t.shape))) return int(t.element_size() * t.shape.numel())
def _dict_size_bytes(d): def _dict_size_bytes(d):
"""Return a dictionary size in bytes. """Return a dictionary size in bytes.
......
...@@ -18,7 +18,7 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_ ...@@ -18,7 +18,7 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_
from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from mednet.libs.segmentation.engine.adabound import AdaBound from mednet.libs.segmentation.engine.adabound import AdaBound
from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
from mednet.libs.segmentation.models.m2unet import M2UNET from mednet.libs.segmentation.models.m2unet import M2Unet
lr = 0.001 lr = 0.001
alpha = 0.7 alpha = 0.7
...@@ -32,7 +32,7 @@ amsbound = False ...@@ -32,7 +32,7 @@ amsbound = False
resize_transform = ResizeMaxSide(512) resize_transform = ResizeMaxSide(512)
model = M2UNET( model = M2Unet(
loss_type=SoftJaccardBCELogitsLoss, loss_type=SoftJaccardBCELogitsLoss,
loss_arguments=dict(alpha=alpha), loss_arguments=dict(alpha=alpha),
optimizer_type=AdaBound, optimizer_type=AdaBound,
......
...@@ -18,21 +18,16 @@ class WeightedBCELogitsLoss(torch.nn.Module): ...@@ -18,21 +18,16 @@ class WeightedBCELogitsLoss(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward( def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Forward pass. """Forward pass.
Parameters Parameters
---------- ----------
tensor input_
Value produced by the model to be evaluated, with the shape ``[n, c, Value produced by the model to be evaluated, with the shape ``[n, c,
h, w]``. h, w]``.
target target
Ground-truth information with the shape ``[n, c, h, w]``. Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be use for specifying the region of interest where to
compute the loss, with the shape ``[n, c, h, w]``.
Returns Returns
------- -------
...@@ -41,15 +36,12 @@ class WeightedBCELogitsLoss(torch.nn.Module): ...@@ -41,15 +36,12 @@ class WeightedBCELogitsLoss(torch.nn.Module):
# calculates the proportion of negatives to the total number of pixels # calculates the proportion of negatives to the total number of pixels
# available in the masked region # available in the masked region
valid = mask > 0.5 num_pos = target.sum()
num_pos = target[valid].sum()
num_neg = valid.sum() - num_pos
pos_weight = num_neg / num_pos
return torch.nn.functional.binary_cross_entropy_with_logits( return torch.nn.functional.binary_cross_entropy_with_logits(
tensor[valid], input_,
target[valid], target,
reduction="mean", reduction="mean",
pos_weight=pos_weight, pos_weight=(input_.shape.numel() - num_pos) / num_pos,
) )
...@@ -75,21 +67,16 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module): ...@@ -75,21 +67,16 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module):
super().__init__() super().__init__()
self.alpha = alpha self.alpha = alpha
def forward( def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Forward pass. """Forward pass.
Parameters Parameters
---------- ----------
tensor input_
Value produced by the model to be evaluated, with the shape ``[n, c, Value produced by the model to be evaluated, with the shape ``[n, c,
h, w]``. h, w]``.
target target
Ground-truth information with the shape ``[n, c, h, w]``. Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be use for specifying the region of interest where to
compute the loss, with the shape ``[n, c, h, w]``.
Returns Returns
------- -------
...@@ -97,15 +84,14 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module): ...@@ -97,15 +84,14 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module):
""" """
eps = 1e-8 eps = 1e-8
valid = mask > 0.5 probabilities = torch.sigmoid(input_)
probabilities = torch.sigmoid(tensor[valid]) intersection = (probabilities * target).sum()
intersection = (probabilities * target[valid]).sum() sums = probabilities.sum() + target.sum()
sums = probabilities.sum() + target[valid].sum()
j = intersection / (sums - intersection + eps) j = intersection / (sums - intersection + eps)
# this implements the support for looking just into the RoI # this implements the support for looking just into the RoI
h = torch.nn.functional.binary_cross_entropy_with_logits( h = torch.nn.functional.binary_cross_entropy_with_logits(
tensor[valid], target[valid], reduction="mean" input_, target, reduction="mean"
) )
return (self.alpha * h) + ((1 - self.alpha) * (1 - j)) return (self.alpha * h) + ((1 - self.alpha) * (1 - j))
...@@ -118,21 +104,16 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss): ...@@ -118,21 +104,16 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward( def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Forward pass. """Forward pass.
Parameters Parameters
---------- ----------
tensor input_
Value produced by the model to be evaluated, with the shape ``[L, Value produced by the model to be evaluated, with the shape ``[L,
n, c, h, w]``. n, c, h, w]``.
target target
Ground-truth information with the shape ``[n, c, h, w]``. Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be use for specifying the region of interest where to
compute the loss, with the shape ``[n, c, h, w]``.
Returns Returns
------- -------
...@@ -140,18 +121,13 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss): ...@@ -140,18 +121,13 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
""" """
return torch.cat( return torch.cat(
[ [super().forward(i, target).unsqueeze(0) for i in input_]
super(MultiWeightedBCELogitsLoss, self)
.forward(i, target, mask)
.unsqueeze(0)
for i in tensor
]
).mean() ).mean()
class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
"""Implements Equation 3 in [IGLOVIKOV-2018]_ for the multi-output """Implement Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks
networks such as HED or Little W-Net. such as HED or Little W-Net.
Parameters Parameters
---------- ----------
...@@ -162,21 +138,16 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): ...@@ -162,21 +138,16 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
def __init__(self, alpha: float = 0.7): def __init__(self, alpha: float = 0.7):
super().__init__(alpha=alpha) super().__init__(alpha=alpha)
def forward( def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Forward pass. """Forward pass.
Parameters Parameters
---------- ----------
tensor input_
Value produced by the model to be evaluated, with the shape ``[L, Value produced by the model to be evaluated, with the shape ``[L,
n, c, h, w]``. n, c, h, w]``.
target target
Ground-truth information with the shape ``[n, c, h, w]``. Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be use for specifying the region of interest where to
compute the loss, with the shape ``[n, c, h, w]``.
Returns Returns
------- -------
...@@ -184,88 +155,5 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): ...@@ -184,88 +155,5 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
""" """
return torch.cat( return torch.cat(
[ [super().forward(i, target).unsqueeze(0) for i in input_]
super(MultiSoftJaccardBCELogitsLoss, self)
.forward(i, target, mask)
.unsqueeze(0)
for i in tensor
]
).mean() ).mean()
# class MixJacLoss(torch.nn.Module):
# """Implements Mix Jaccard Loss.
# Parameters
# ----------
# lambda_u
# Determines the weighting of SoftJaccard and BCE.
# jacalpha
# Determines the weighting of J and H.
# size_average
# By default, the losses are averaged over each loss element in the
# batch. Note that for some losses, there are multiple elements per
# sample. If the field `size_average` is set to ``False``, the losses
# are instead summed for each minibatch. Ignored when `reduce` is
# ``False``. Default: ``True``.
# reduce
# By default, the losses are averaged or summed over observations for
# each minibatch depending on `size_average`. When `reduce` is
# ``False``, returns a loss per batch element instead and ignores
# `size_average`. Default: ``True``.
# reduction
# Specifies the reduction to apply to the output: ``'none'`` |
# ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
# ``'mean'``: the sum of the output will be divided by the number of
# elements in the output, ``'sum'``: the output will be summed. Note:
# `size_average` and `reduce` are in the process of being deprecated,
# and in the meantime, specifying either of those two args will
# override `reduction`. Default: ``'mean'``.
# """
# def __init__(
# self,
# lambda_u: int = 100,
# jacalpha=0.7,
# size_average=None,
# reduce=None,
# reduction="mean",
# ):
# super().__init__(size_average, reduce, reduction)
# self.lambda_u = lambda_u
# self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
# self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
# def forward(
# self,
# tensor: torch.Tensor,
# target: torch.Tensor,
# unlabeled_tensor: torch.Tensor,
# unlabeled_target: torch.Tensor,
# ramp_up_factor: float,
# ) -> tuple:
# """Forward pass.
# Parameters
# ----------
# tensor
# Value produced by the model to be evaluated, with the shape ``[L,
# n, c, h, w]``.
# target
# Ground-truth information with the shape ``[n, c, h, w]``.
# unlabeled_tensor
# unlabeled_target
# ramp_up_factor
# Returns
# -------
# list
# """
# ll = self.labeled_loss(tensor, target)
# ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target)
# loss = ll + self.lambda_u * ramp_up_factor * ul
# return loss, ll, ul
...@@ -353,3 +353,7 @@ class LittleWNet(SegmentationModel): ...@@ -353,3 +353,7 @@ class LittleWNet(SegmentationModel):
x1 = self.unet1(xn) x1 = self.unet1(xn)
x2 = self.unet2(torch.cat([xn, x1], dim=1)) x2 = self.unet2(torch.cat([xn, x1], dim=1))
return x1, x2 return x1, x2
def predict_step(self, batch, batch_idx, dataloader_idx=0):
# prediction only returns the result of the last unet
return torch.sigmoid(self(batch[0]["image"])[1])
...@@ -122,8 +122,8 @@ class M2UNetHead(torch.nn.Module): ...@@ -122,8 +122,8 @@ class M2UNetHead(torch.nn.Module):
return self.decode1(decode2, x[1]) # 30, 3 return self.decode1(decode2, x[1]) # 30, 3
class M2UNET(SegmentationModel): class M2Unet(SegmentationModel):
"""Implementation of the M2UNET model. """Implementation of the M2Unet model.
Parameters Parameters
---------- ----------
......
...@@ -94,21 +94,20 @@ class SegmentationModel(Model): ...@@ -94,21 +94,20 @@ class SegmentationModel(Model):
self.normalizer = make_z_normalizer(dataloader) self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, _): def training_step(self, batch, _):
images = self.augmentation_transforms(batch[0]["image"])
ground_truths = self.augmentation_transforms(batch[0]["target"])
masks = self.augmentation_transforms(batch[0]["mask"]) masks = self.augmentation_transforms(batch[0]["mask"])
images = self.augmentation_transforms(batch[0]["image"]) * masks
ground_truths = self.augmentation_transforms(batch[0]["target"]) * masks
outputs = self(images) outputs = self(images)
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths)
def validation_step(self, batch, batch_idx, dataloader_idx=0): def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[0]["image"] images = batch[0]["image"] * batch[0]["mask"]
ground_truths = batch[0]["target"] ground_truths = batch[0]["target"] * batch[0]["mask"]
masks = batch[0]["mask"]
outputs = self(images) outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks) return self._validation_loss(outputs, ground_truths)
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0]["image"])[1] output = self(batch[0]["image"])
return torch.sigmoid(output) return torch.sigmoid(output)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment