Skip to content
Snippets Groups Projects
Commit 6cd6ac63 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[libs.segmentation.models.losses] Simplify loss calculation; Specialize prediction step on lwnet

parent c5e36064
No related branches found
No related tags found
1 merge request!46Create common library
Pipeline #89334 failed
......@@ -30,7 +30,7 @@ from .typing import (
logger = logging.getLogger(__name__)
def _sample_size_bytes(dataset: Sample):
def _sample_size_bytes(dataset: Dataset):
"""Recurse into the first sample of a dataset and figures out its total occupance in bytes.
Parameters
......@@ -54,7 +54,7 @@ def _sample_size_bytes(dataset: Sample):
"""
logger.info(f"{list(t.shape)}@{t.dtype}")
return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
return int(t.element_size() * t.shape.numel())
def _dict_size_bytes(d):
"""Return a dictionary size in bytes.
......
......@@ -18,7 +18,7 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_
from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from mednet.libs.segmentation.engine.adabound import AdaBound
from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
from mednet.libs.segmentation.models.m2unet import M2UNET
from mednet.libs.segmentation.models.m2unet import M2Unet
lr = 0.001
alpha = 0.7
......@@ -32,7 +32,7 @@ amsbound = False
resize_transform = ResizeMaxSide(512)
model = M2UNET(
model = M2Unet(
loss_type=SoftJaccardBCELogitsLoss,
loss_arguments=dict(alpha=alpha),
optimizer_type=AdaBound,
......
......@@ -18,21 +18,16 @@ class WeightedBCELogitsLoss(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Parameters
----------
tensor
input_
Value produced by the model to be evaluated, with the shape ``[n, c,
h, w]``.
target
Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be use for specifying the region of interest where to
compute the loss, with the shape ``[n, c, h, w]``.
Returns
-------
......@@ -41,15 +36,12 @@ class WeightedBCELogitsLoss(torch.nn.Module):
# calculates the proportion of negatives to the total number of pixels
# available in the masked region
valid = mask > 0.5
num_pos = target[valid].sum()
num_neg = valid.sum() - num_pos
pos_weight = num_neg / num_pos
num_pos = target.sum()
return torch.nn.functional.binary_cross_entropy_with_logits(
tensor[valid],
target[valid],
input_,
target,
reduction="mean",
pos_weight=pos_weight,
pos_weight=(input_.shape.numel() - num_pos) / num_pos,
)
......@@ -75,21 +67,16 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module):
super().__init__()
self.alpha = alpha
def forward(
self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Parameters
----------
tensor
input_
Value produced by the model to be evaluated, with the shape ``[n, c,
h, w]``.
target
Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be use for specifying the region of interest where to
compute the loss, with the shape ``[n, c, h, w]``.
Returns
-------
......@@ -97,15 +84,14 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module):
"""
eps = 1e-8
valid = mask > 0.5
probabilities = torch.sigmoid(tensor[valid])
intersection = (probabilities * target[valid]).sum()
sums = probabilities.sum() + target[valid].sum()
probabilities = torch.sigmoid(input_)
intersection = (probabilities * target).sum()
sums = probabilities.sum() + target.sum()
j = intersection / (sums - intersection + eps)
# this implements the support for looking just into the RoI
h = torch.nn.functional.binary_cross_entropy_with_logits(
tensor[valid], target[valid], reduction="mean"
input_, target, reduction="mean"
)
return (self.alpha * h) + ((1 - self.alpha) * (1 - j))
......@@ -118,21 +104,16 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
def __init__(self):
super().__init__()
def forward(
self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Parameters
----------
tensor
input_
Value produced by the model to be evaluated, with the shape ``[L,
n, c, h, w]``.
target
Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be use for specifying the region of interest where to
compute the loss, with the shape ``[n, c, h, w]``.
Returns
-------
......@@ -140,18 +121,13 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
"""
return torch.cat(
[
super(MultiWeightedBCELogitsLoss, self)
.forward(i, target, mask)
.unsqueeze(0)
for i in tensor
]
[super().forward(i, target).unsqueeze(0) for i in input_]
).mean()
class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
"""Implements Equation 3 in [IGLOVIKOV-2018]_ for the multi-output
networks such as HED or Little W-Net.
"""Implement Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks
such as HED or Little W-Net.
Parameters
----------
......@@ -162,21 +138,16 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
def __init__(self, alpha: float = 0.7):
super().__init__(alpha=alpha)
def forward(
self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Parameters
----------
tensor
input_
Value produced by the model to be evaluated, with the shape ``[L,
n, c, h, w]``.
target
Ground-truth information with the shape ``[n, c, h, w]``.
mask
Mask to be use for specifying the region of interest where to
compute the loss, with the shape ``[n, c, h, w]``.
Returns
-------
......@@ -184,88 +155,5 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
"""
return torch.cat(
[
super(MultiSoftJaccardBCELogitsLoss, self)
.forward(i, target, mask)
.unsqueeze(0)
for i in tensor
]
[super().forward(i, target).unsqueeze(0) for i in input_]
).mean()
# class MixJacLoss(torch.nn.Module):
# """Implements Mix Jaccard Loss.
# Parameters
# ----------
# lambda_u
# Determines the weighting of SoftJaccard and BCE.
# jacalpha
# Determines the weighting of J and H.
# size_average
# By default, the losses are averaged over each loss element in the
# batch. Note that for some losses, there are multiple elements per
# sample. If the field `size_average` is set to ``False``, the losses
# are instead summed for each minibatch. Ignored when `reduce` is
# ``False``. Default: ``True``.
# reduce
# By default, the losses are averaged or summed over observations for
# each minibatch depending on `size_average`. When `reduce` is
# ``False``, returns a loss per batch element instead and ignores
# `size_average`. Default: ``True``.
# reduction
# Specifies the reduction to apply to the output: ``'none'`` |
# ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
# ``'mean'``: the sum of the output will be divided by the number of
# elements in the output, ``'sum'``: the output will be summed. Note:
# `size_average` and `reduce` are in the process of being deprecated,
# and in the meantime, specifying either of those two args will
# override `reduction`. Default: ``'mean'``.
# """
# def __init__(
# self,
# lambda_u: int = 100,
# jacalpha=0.7,
# size_average=None,
# reduce=None,
# reduction="mean",
# ):
# super().__init__(size_average, reduce, reduction)
# self.lambda_u = lambda_u
# self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
# self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
# def forward(
# self,
# tensor: torch.Tensor,
# target: torch.Tensor,
# unlabeled_tensor: torch.Tensor,
# unlabeled_target: torch.Tensor,
# ramp_up_factor: float,
# ) -> tuple:
# """Forward pass.
# Parameters
# ----------
# tensor
# Value produced by the model to be evaluated, with the shape ``[L,
# n, c, h, w]``.
# target
# Ground-truth information with the shape ``[n, c, h, w]``.
# unlabeled_tensor
# unlabeled_target
# ramp_up_factor
# Returns
# -------
# list
# """
# ll = self.labeled_loss(tensor, target)
# ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target)
# loss = ll + self.lambda_u * ramp_up_factor * ul
# return loss, ll, ul
......@@ -353,3 +353,7 @@ class LittleWNet(SegmentationModel):
x1 = self.unet1(xn)
x2 = self.unet2(torch.cat([xn, x1], dim=1))
return x1, x2
def predict_step(self, batch, batch_idx, dataloader_idx=0):
# prediction only returns the result of the last unet
return torch.sigmoid(self(batch[0]["image"])[1])
......@@ -122,8 +122,8 @@ class M2UNetHead(torch.nn.Module):
return self.decode1(decode2, x[1]) # 30, 3
class M2UNET(SegmentationModel):
"""Implementation of the M2UNET model.
class M2Unet(SegmentationModel):
"""Implementation of the M2Unet model.
Parameters
----------
......
......@@ -94,21 +94,20 @@ class SegmentationModel(Model):
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, _):
images = self.augmentation_transforms(batch[0]["image"])
ground_truths = self.augmentation_transforms(batch[0]["target"])
masks = self.augmentation_transforms(batch[0]["mask"])
images = self.augmentation_transforms(batch[0]["image"]) * masks
ground_truths = self.augmentation_transforms(batch[0]["target"]) * masks
outputs = self(images)
return self._train_loss(outputs, ground_truths, masks)
return self._train_loss(outputs, ground_truths)
def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[0]["image"]
ground_truths = batch[0]["target"]
masks = batch[0]["mask"]
images = batch[0]["image"] * batch[0]["mask"]
ground_truths = batch[0]["target"] * batch[0]["mask"]
outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks)
return self._validation_loss(outputs, ground_truths)
def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0]["image"])[1]
output = self(batch[0]["image"])
return torch.sigmoid(output)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment