diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py
index c8ffcffed7ec2c5fc7265d6a509f355e5e25f37c..d51105571332a5728bacefe10446b437a56d4587 100644
--- a/src/mednet/libs/common/data/datamodule.py
+++ b/src/mednet/libs/common/data/datamodule.py
@@ -30,7 +30,7 @@ from .typing import (
 logger = logging.getLogger(__name__)
 
 
-def _sample_size_bytes(dataset: Sample):
+def _sample_size_bytes(dataset: Dataset):
     """Recurse into the first sample of a dataset and figures out its total occupance in bytes.
 
     Parameters
@@ -54,7 +54,7 @@ def _sample_size_bytes(dataset: Sample):
         """
 
         logger.info(f"{list(t.shape)}@{t.dtype}")
-        return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
+        return int(t.element_size() * t.shape.numel())
 
     def _dict_size_bytes(d):
         """Return a dictionary size in bytes.
diff --git a/src/mednet/libs/segmentation/config/models/m2unet.py b/src/mednet/libs/segmentation/config/models/m2unet.py
index ccfae59b323b2e9c0042248c24c4029b6eb597d2..b7de862491a0112f18ce56c80be483506d5ba246 100644
--- a/src/mednet/libs/segmentation/config/models/m2unet.py
+++ b/src/mednet/libs/segmentation/config/models/m2unet.py
@@ -18,7 +18,7 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_
 from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 from mednet.libs.segmentation.engine.adabound import AdaBound
 from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
-from mednet.libs.segmentation.models.m2unet import M2UNET
+from mednet.libs.segmentation.models.m2unet import M2Unet
 
 lr = 0.001
 alpha = 0.7
@@ -32,7 +32,7 @@ amsbound = False
 
 resize_transform = ResizeMaxSide(512)
 
-model = M2UNET(
+model = M2Unet(
     loss_type=SoftJaccardBCELogitsLoss,
     loss_arguments=dict(alpha=alpha),
     optimizer_type=AdaBound,
diff --git a/src/mednet/libs/segmentation/models/losses.py b/src/mednet/libs/segmentation/models/losses.py
index 2a344dd461fb4294bab6cb83f0edfa60789486b3..e36ca1602003553ad4612d54fcc145afd5746a63 100644
--- a/src/mednet/libs/segmentation/models/losses.py
+++ b/src/mednet/libs/segmentation/models/losses.py
@@ -18,21 +18,16 @@ class WeightedBCELogitsLoss(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(
-        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """Forward pass.
 
         Parameters
         ----------
-        tensor
+        input_
             Value produced by the model to be evaluated, with the shape ``[n, c,
             h, w]``.
         target
             Ground-truth information with the shape ``[n, c, h, w]``.
-        mask
-            Mask to be use for specifying the region of interest where to
-            compute the loss, with the shape ``[n, c, h, w]``.
 
         Returns
         -------
@@ -41,15 +36,12 @@ class WeightedBCELogitsLoss(torch.nn.Module):
 
-        # calculates the proportion of negatives to the total number of pixels
-        # available in the masked region
+        # rebalances the BCE term: positives are weighted by the ratio of
+        # negative to positive pixels in the target
-        valid = mask > 0.5
-        num_pos = target[valid].sum()
-        num_neg = valid.sum() - num_pos
-        pos_weight = num_neg / num_pos
+        num_pos = target.sum()
         return torch.nn.functional.binary_cross_entropy_with_logits(
-            tensor[valid],
-            target[valid],
+            input_,
+            target,
             reduction="mean",
-            pos_weight=pos_weight,
+            pos_weight=(input_.shape.numel() - num_pos) / num_pos,
         )
 
 
@@ -75,21 +67,16 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module):
         super().__init__()
         self.alpha = alpha
 
-    def forward(
-        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """Forward pass.
 
         Parameters
         ----------
-        tensor
+        input_
             Value produced by the model to be evaluated, with the shape ``[n, c,
             h, w]``.
         target
             Ground-truth information with the shape ``[n, c, h, w]``.
-        mask
-            Mask to be use for specifying the region of interest where to
-            compute the loss, with the shape ``[n, c, h, w]``.
 
         Returns
         -------
@@ -97,15 +84,14 @@ class SoftJaccardBCELogitsLoss(torch.nn.Module):
         """
 
         eps = 1e-8
-        valid = mask > 0.5
-        probabilities = torch.sigmoid(tensor[valid])
-        intersection = (probabilities * target[valid]).sum()
-        sums = probabilities.sum() + target[valid].sum()
+        probabilities = torch.sigmoid(input_)
+        intersection = (probabilities * target).sum()
+        sums = probabilities.sum() + target.sum()
         j = intersection / (sums - intersection + eps)
 
-        # this implements the support for looking just into the RoI
+        # standard BCE term (region-of-interest masking is applied upstream)
         h = torch.nn.functional.binary_cross_entropy_with_logits(
-            tensor[valid], target[valid], reduction="mean"
+            input_, target, reduction="mean"
         )
         return (self.alpha * h) + ((1 - self.alpha) * (1 - j))
 
@@ -118,21 +104,16 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
     def __init__(self):
         super().__init__()
 
-    def forward(
-        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """Forward pass.
 
         Parameters
         ----------
-        tensor
+        input_
             Value produced by the model to be evaluated, with the shape ``[L,
             n, c, h, w]``.
         target
             Ground-truth information with the shape ``[n, c, h, w]``.
-        mask
-            Mask to be use for specifying the region of interest where to
-            compute the loss, with the shape ``[n, c, h, w]``.
 
         Returns
         -------
@@ -140,18 +121,13 @@ class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
         """
 
         return torch.cat(
-            [
-                super(MultiWeightedBCELogitsLoss, self)
-                .forward(i, target, mask)
-                .unsqueeze(0)
-                for i in tensor
-            ]
+            [super().forward(i, target).unsqueeze(0) for i in input_]
         ).mean()
 
 
 class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
-    """Implements  Equation 3 in [IGLOVIKOV-2018]_ for the multi-output
-    networks such as HED or Little W-Net.
+    """Implement Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks
+    such as HED or Little W-Net.
 
     Parameters
     ----------
@@ -162,21 +138,16 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
     def __init__(self, alpha: float = 0.7):
         super().__init__(alpha=alpha)
 
-    def forward(
-        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """Forward pass.
 
         Parameters
         ----------
-        tensor
+        input_
             Value produced by the model to be evaluated, with the shape ``[L,
             n, c, h, w]``.
         target
             Ground-truth information with the shape ``[n, c, h, w]``.
-        mask
-            Mask to be use for specifying the region of interest where to
-            compute the loss, with the shape ``[n, c, h, w]``.
 
         Returns
         -------
@@ -184,88 +155,5 @@ class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
         """
 
         return torch.cat(
-            [
-                super(MultiSoftJaccardBCELogitsLoss, self)
-                .forward(i, target, mask)
-                .unsqueeze(0)
-                for i in tensor
-            ]
+            [super().forward(i, target).unsqueeze(0) for i in input_]
         ).mean()
-
-
-# class MixJacLoss(torch.nn.Module):
-#     """Implements Mix Jaccard Loss.
-
-#     Parameters
-#     ----------
-#     lambda_u
-#         Determines the weighting of SoftJaccard and BCE.
-#     jacalpha
-#         Determines the weighting of J and H.
-#     size_average
-#         By default, the losses are averaged over each loss element in the
-#         batch. Note that for some losses, there are multiple elements per
-#         sample. If the field `size_average` is set to ``False``, the losses
-#         are instead summed for each minibatch. Ignored when `reduce` is
-#         ``False``. Default: ``True``.
-#     reduce
-#         By default, the losses are averaged or summed over observations for
-#         each minibatch depending on `size_average`. When `reduce` is
-#         ``False``, returns a loss per batch element instead and ignores
-#         `size_average`. Default: ``True``.
-#     reduction
-#         Specifies the reduction to apply to the output: ``'none'`` |
-#         ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
-#         ``'mean'``: the sum of the output will be divided by the number of
-#         elements in the output, ``'sum'``: the output will be summed. Note:
-#         `size_average` and `reduce` are in the process of being deprecated,
-#         and in the meantime, specifying either of those two args will
-#         override `reduction`. Default: ``'mean'``.
-#     """
-
-#     def __init__(
-#         self,
-#         lambda_u: int = 100,
-#         jacalpha=0.7,
-#         size_average=None,
-#         reduce=None,
-#         reduction="mean",
-#     ):
-#         super().__init__(size_average, reduce, reduction)
-#         self.lambda_u = lambda_u
-#         self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
-#         self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
-
-#     def forward(
-#         self,
-#         tensor: torch.Tensor,
-#         target: torch.Tensor,
-#         unlabeled_tensor: torch.Tensor,
-#         unlabeled_target: torch.Tensor,
-#         ramp_up_factor: float,
-#     ) -> tuple:
-#         """Forward pass.
-
-#         Parameters
-#         ----------
-#         tensor
-#             Value produced by the model to be evaluated, with the shape ``[L,
-#             n, c, h, w]``.
-#         target
-#             Ground-truth information with the shape ``[n, c, h, w]``.
-
-#         unlabeled_tensor
-
-#         unlabeled_target
-
-#         ramp_up_factor
-
-#         Returns
-#         -------
-#         list
-#         """
-#         ll = self.labeled_loss(tensor, target)
-#         ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target)
-
-#         loss = ll + self.lambda_u * ramp_up_factor * ul
-#         return loss, ll, ul
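
With the ``mask`` argument gone, the losses assume the region of interest has already been folded into their inputs (see the ``SegmentationModel.training_step`` change further below), and ``WeightedBCELogitsLoss`` now derives ``pos_weight`` from the whole tensor. A small sketch of both points, with made-up shapes and random data:

    import torch
    from mednet.libs.segmentation.models.losses import (
        SoftJaccardBCELogitsLoss,
        WeightedBCELogitsLoss,
    )

    # hypothetical batch: 2 samples, 1 channel, 16x16 pixels
    logits = torch.randn(2, 1, 16, 16)                  # raw model outputs
    target = (torch.rand(2, 1, 16, 16) > 0.8).float()   # sparse binary ground truth

    num_pos = target.sum()
    pos_weight = (target.shape.numel() - num_pos) / num_pos  # negatives per positive

    bce = WeightedBCELogitsLoss()(logits, target)        # applies pos_weight internally
    jac = SoftJaccardBCELogitsLoss(alpha=0.7)(logits, target)

The two-argument ``(input, target)`` signature now mirrors standard PyTorch criteria such as ``torch.nn.BCEWithLogitsLoss``.
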
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 5cd491b04d392bbf8579bbca3525e4ea6f46bd45..b2c0ac2f803ea77c8f6d2e7527d0bd2fc68c2293 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -353,3 +353,7 @@ class LittleWNet(SegmentationModel):
         x1 = self.unet1(xn)
         x2 = self.unet2(torch.cat([xn, x1], dim=1))
         return x1, x2
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        # prediction only returns the result of the last unet
+        return torch.sigmoid(self(batch[0]["image"])[1])
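
``LittleWNet.forward`` returns the outputs of both cascaded U-Nets; the new ``predict_step`` keeps only the second, final one and maps its logits to probabilities. A minimal sketch, assuming an instantiated ``model`` and the ``batch[0]["image"]`` layout used above:

    images = batch[0]["image"]           # assumed shape [n, c, h, w]
    x1, x2 = model(images)               # forward() yields both unet outputs
    probabilities = torch.sigmoid(x2)    # what the new predict_step returns
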
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index 8735956595b3737b7320a7f76ac26db37fe35c0b..5734a2a4e83252b3a2d956c850903b98de5ae9dd 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -122,8 +122,8 @@ class M2UNetHead(torch.nn.Module):
         return self.decode1(decode2, x[1])  # 30, 3
 
 
-class M2UNET(SegmentationModel):
-    """Implementation of the M2UNET model.
+class M2Unet(SegmentationModel):
+    """Implementation of the M2Unet model.
 
     Parameters
     ----------
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index dae5cde663ad0aebef3302b8a68210efc396beac..8f3c5f5e2d1e79b147c3f7677e52738ba5f517e1 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -94,21 +94,20 @@ class SegmentationModel(Model):
         self.normalizer = make_z_normalizer(dataloader)
 
     def training_step(self, batch, _):
-        images = self.augmentation_transforms(batch[0]["image"])
-        ground_truths = self.augmentation_transforms(batch[0]["target"])
         masks = self.augmentation_transforms(batch[0]["mask"])
+        images = self.augmentation_transforms(batch[0]["image"]) * masks
+        ground_truths = self.augmentation_transforms(batch[0]["target"]) * masks
 
         outputs = self(images)
-        return self._train_loss(outputs, ground_truths, masks)
+        return self._train_loss(outputs, ground_truths)
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
+        images = batch[0]["image"] * batch[0]["mask"]
+        ground_truths = batch[0]["target"] * batch[0]["mask"]
 
         outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
+        return self._validation_loss(outputs, ground_truths)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
+        output = self(batch[0]["image"])
         return torch.sigmoid(output)
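
Training and validation now fold the region of interest into the data itself: images and ground truths are multiplied by the mask before the forward pass, so the loss functions need no mask argument. A sketch of the masking step in isolation, with illustrative shapes:

    import torch

    image = torch.rand(2, 3, 64, 64)                    # hypothetical RGB batch
    target = (torch.rand(2, 1, 64, 64) > 0.5).float()   # binary ground truth
    mask = torch.zeros(2, 1, 64, 64)
    mask[..., 16:48, 16:48] = 1.0                       # keep only a central RoI

    masked_image = image * mask      # broadcasts over the channel dimension
    masked_target = target * mask    # pixels outside the RoI become 0

Unlike the previous ``tensor[valid]`` indexing, pixels outside the mask still enter the loss average (with a target of 0); that is the trade-off for the simpler, mask-free loss interface.
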