diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py index c8ffcffed7ec2c5fc7265d6a509f355e5e25f37c..d51105571332a5728bacefe10446b437a56d4587 100644 --- a/src/mednet/libs/common/data/datamodule.py +++ b/src/mednet/libs/common/data/datamodule.py @@ -30,7 +30,7 @@ from .typing import ( logger = logging.getLogger(__name__) -def _sample_size_bytes(dataset: Sample): +def _sample_size_bytes(dataset: Dataset): """Recurse into the first sample of a dataset and figures out its total occupance in bytes. Parameters @@ -54,7 +54,7 @@ def _sample_size_bytes(dataset: Sample): """ logger.info(f"{list(t.shape)}@{t.dtype}") - return int(t.element_size() * torch.prod(torch.tensor(t.shape))) + return int(t.element_size() * t.shape.numel()) def _dict_size_bytes(d): """Return a dictionary size in bytes. diff --git a/src/mednet/libs/segmentation/config/models/m2unet.py b/src/mednet/libs/segmentation/config/models/m2unet.py index ccfae59b323b2e9c0042248c24c4029b6eb597d2..b7de862491a0112f18ce56c80be483506d5ba246 100644 --- a/src/mednet/libs/segmentation/config/models/m2unet.py +++ b/src/mednet/libs/segmentation/config/models/m2unet.py @@ -18,7 +18,7 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from mednet.libs.segmentation.engine.adabound import AdaBound from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss -from mednet.libs.segmentation.models.m2unet import M2UNET +from mednet.libs.segmentation.models.m2unet import M2Unet lr = 0.001 alpha = 0.7 @@ -32,7 +32,7 @@ amsbound = False resize_transform = ResizeMaxSide(512) -model = M2UNET( +model = M2Unet( loss_type=SoftJaccardBCELogitsLoss, loss_arguments=dict(alpha=alpha), optimizer_type=AdaBound, diff --git a/src/mednet/libs/segmentation/models/losses.py b/src/mednet/libs/segmentation/models/losses.py index 2a344dd461fb4294bab6cb83f0edfa60789486b3..e36ca1602003553ad4612d54fcc145afd5746a63 100644 --- a/src/mednet/libs/segmentation/models/losses.py +++ b/src/mednet/libs/segmentation/models/losses.py @@ -18,21 +18,16 @@ class WeightedBCELogitsLoss(torch.nn.Module): def __init__(self): super().__init__() - def forward( - self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- - tensor + input_ Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. - mask - Mask to be use for specifying the region of interest where to - compute the loss, with the shape ``[n, c, h, w]``. Returns ------- @@ -41,15 +36,12 @@ class WeightedBCELogitsLoss(torch.nn.Module): # calculates the proportion of negatives to the total number of pixels # available in the masked region - valid = mask > 0.5 - num_pos = target[valid].sum() - num_neg = valid.sum() - num_pos - pos_weight = num_neg / num_pos + num_pos = target.sum() return torch.nn.functional.binary_cross_entropy_with_logits( - tensor[valid], - target[valid], + input_, + target, reduction="mean", - pos_weight=pos_weight, + pos_weight=(input_.shape.numel() - num_pos) / num_pos, ) @@ -75,21 +67,16 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module): super().__init__() self.alpha = alpha - def forward( - self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- - tensor + input_ Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. - mask - Mask to be use for specifying the region of interest where to - compute the loss, with the shape ``[n, c, h, w]``. Returns ------- @@ -97,15 +84,14 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module): """ eps = 1e-8 - valid = mask > 0.5 - probabilities = torch.sigmoid(tensor[valid]) - intersection = (probabilities * target[valid]).sum() - sums = probabilities.sum() + target[valid].sum() + probabilities = torch.sigmoid(input_) + intersection = (probabilities * target).sum() + sums = probabilities.sum() + target.sum() j = intersection / (sums - intersection + eps) # this implements the support for looking just into the RoI h = torch.nn.functional.binary_cross_entropy_with_logits( - tensor[valid], target[valid], reduction="mean" + input_, target, reduction="mean" ) return (self.alpha * h) + ((1 - self.alpha) * (1 - j)) @@ -118,21 +104,16 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss): def __init__(self): super().__init__() - def forward( - self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- - tensor + input_ Value produced by the model to be evaluated, with the shape ``[L, n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. - mask - Mask to be use for specifying the region of interest where to - compute the loss, with the shape ``[n, c, h, w]``. Returns ------- @@ -140,18 +121,13 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss): """ return torch.cat( - [ - super(MultiWeightedBCELogitsLoss, self) - .forward(i, target, mask) - .unsqueeze(0) - for i in tensor - ] + [super().forward(i, target).unsqueeze(0) for i in input_] ).mean() class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): - """Implements Equation 3 in [IGLOVIKOV-2018]_ for the multi-output - networks such as HED or Little W-Net. + """Implement Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks + such as HED or Little W-Net. Parameters ---------- @@ -162,21 +138,16 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): def __init__(self, alpha: float = 0.7): super().__init__(alpha=alpha) - def forward( - self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- - tensor + input_ Value produced by the model to be evaluated, with the shape ``[L, n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. - mask - Mask to be use for specifying the region of interest where to - compute the loss, with the shape ``[n, c, h, w]``. Returns ------- @@ -184,88 +155,5 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): """ return torch.cat( - [ - super(MultiSoftJaccardBCELogitsLoss, self) - .forward(i, target, mask) - .unsqueeze(0) - for i in tensor - ] + [super().forward(i, target).unsqueeze(0) for i in input_] ).mean() - - -# class MixJacLoss(torch.nn.Module): -# """Implements Mix Jaccard Loss. - -# Parameters -# ---------- -# lambda_u -# Determines the weighting of SoftJaccard and BCE. -# jacalpha -# Determines the weighting of J and H. -# size_average -# By default, the losses are averaged over each loss element in the -# batch. Note that for some losses, there are multiple elements per -# sample. If the field `size_average` is set to ``False``, the losses -# are instead summed for each minibatch. Ignored when `reduce` is -# ``False``. Default: ``True``. -# reduce -# By default, the losses are averaged or summed over observations for -# each minibatch depending on `size_average`. When `reduce` is -# ``False``, returns a loss per batch element instead and ignores -# `size_average`. Default: ``True``. -# reduction -# Specifies the reduction to apply to the output: ``'none'`` | -# ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, -# ``'mean'``: the sum of the output will be divided by the number of -# elements in the output, ``'sum'``: the output will be summed. Note: -# `size_average` and `reduce` are in the process of being deprecated, -# and in the meantime, specifying either of those two args will -# override `reduction`. Default: ``'mean'``. -# """ - -# def __init__( -# self, -# lambda_u: int = 100, -# jacalpha=0.7, -# size_average=None, -# reduce=None, -# reduction="mean", -# ): -# super().__init__(size_average, reduce, reduction) -# self.lambda_u = lambda_u -# self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha) -# self.unlabeled_loss = torch.nn.BCEWithLogitsLoss() - -# def forward( -# self, -# tensor: torch.Tensor, -# target: torch.Tensor, -# unlabeled_tensor: torch.Tensor, -# unlabeled_target: torch.Tensor, -# ramp_up_factor: float, -# ) -> tuple: -# """Forward pass. - -# Parameters -# ---------- -# tensor -# Value produced by the model to be evaluated, with the shape ``[L, -# n, c, h, w]``. -# target -# Ground-truth information with the shape ``[n, c, h, w]``. - -# unlabeled_tensor - -# unlabeled_target - -# ramp_up_factor - -# Returns -# ------- -# list -# """ -# ll = self.labeled_loss(tensor, target) -# ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target) - -# loss = ll + self.lambda_u * ramp_up_factor * ul -# return loss, ll, ul diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index 5cd491b04d392bbf8579bbca3525e4ea6f46bd45..b2c0ac2f803ea77c8f6d2e7527d0bd2fc68c2293 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -353,3 +353,7 @@ class LittleWNet(SegmentationModel): x1 = self.unet1(xn) x2 = self.unet2(torch.cat([xn, x1], dim=1)) return x1, x2 + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + # prediction only returns the result of the last unet + return torch.sigmoid(self(batch[0]["image"])[1]) diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py index 8735956595b3737b7320a7f76ac26db37fe35c0b..5734a2a4e83252b3a2d956c850903b98de5ae9dd 100644 --- a/src/mednet/libs/segmentation/models/m2unet.py +++ b/src/mednet/libs/segmentation/models/m2unet.py @@ -122,8 +122,8 @@ class M2UNetHead(torch.nn.Module): return self.decode1(decode2, x[1]) # 30, 3 -class M2UNET(SegmentationModel): - """Implementation of the M2UNET model. +class M2Unet(SegmentationModel): + """Implementation of the M2Unet model. Parameters ---------- diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py index dae5cde663ad0aebef3302b8a68210efc396beac..8f3c5f5e2d1e79b147c3f7677e52738ba5f517e1 100644 --- a/src/mednet/libs/segmentation/models/segmentation_model.py +++ b/src/mednet/libs/segmentation/models/segmentation_model.py @@ -94,21 +94,20 @@ class SegmentationModel(Model): self.normalizer = make_z_normalizer(dataloader) def training_step(self, batch, _): - images = self.augmentation_transforms(batch[0]["image"]) - ground_truths = self.augmentation_transforms(batch[0]["target"]) masks = self.augmentation_transforms(batch[0]["mask"]) + images = self.augmentation_transforms(batch[0]["image"]) * masks + ground_truths = self.augmentation_transforms(batch[0]["target"]) * masks outputs = self(images) - return self._train_loss(outputs, ground_truths, masks) + return self._train_loss(outputs, ground_truths) def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] + images = batch[0]["image"] * batch[0]["mask"] + ground_truths = batch[0]["target"] * batch[0]["mask"] outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) + return self._validation_loss(outputs, ground_truths) def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] + output = self(batch[0]["image"]) return torch.sigmoid(output)