diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index f3a2e0d2d38866f670f051e6b1ea685ab2c3fd1f..76608301c9b2dd717cda3c6dd6996947887caf45 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -161,23 +161,23 @@ class DRIU(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 3bb93ba9ee9221938a65af1c2d6fd62a1104d55d..f14f911bed17f3bd889b4e21ea148c2787dbd425 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -164,23 +164,23 @@ class DRIUBN(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index 98e596236bddf19fcc1789f136f1e162ed76ecf3..d810c471fdc6d567e3a63d3b0c8534c86cea1681 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -146,23 +146,23 @@ class DRIUOD(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index 6846da5a8f2700ecc4f2370cd485875016c1749b..a85ac3ab1753172b012b974933cd4f6abb5cac33 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -150,23 +150,23 @@ class DRIUPix(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 7e0b770513d0f8d5d94515827160f029cbf3346c..80e4665eea454ae6c8ef269ebde913bade71081b 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -165,23 +165,23 @@ class HED(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 28bdf498fd74705546bf39a57835b378e4144f4e..43bcb819d315849e376ba4e7e10b5d6407d3b3ec 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -366,23 +366,23 @@ class LittleWNet(Model):
         return x1, x2
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index b371540945d8e8b3c481c84abf7ea6c3309995fa..ccc94c2991be732d3de83bb7c923b25231329cdf 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -213,23 +213,23 @@ class M2UNET(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 98578d1d437bbe53fbab48f31764a1580d5ad5eb..36136bcb8dfddb85d62285c3cf555776b8271cff 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -154,23 +154,23 @@ class Unet(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):