From a12a332a55fbccbe7cd7048f29ab3f87d7750101 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jun 2024 10:27:54 +0200
Subject: [PATCH] [segmentation.models] Update models to handle the new sample format

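Samples are now delivered as a single dictionary carrying the image together
with its target and mask, instead of the previous (tensor, dict) pair. The
training_step(), validation_step() and predict_step() implementations of all
segmentation models are adjusted accordingly; a representative before/after
sketch (mirroring the hunks below) is:

    # before: batch is (images, {"target": ..., "mask": ...})
    images = batch[0]
    ground_truths = batch[1]["target"]
    masks = batch[1]["mask"]

    # after: batch[0] is {"image": ..., "target": ..., "mask": ...}
    images = batch[0]["image"]
    ground_truths = batch[0]["target"]
    masks = batch[0]["mask"]
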
---
 src/mednet/libs/segmentation/models/driu.py     | 14 +++++++-------
 src/mednet/libs/segmentation/models/driu_bn.py  | 14 +++++++-------
 src/mednet/libs/segmentation/models/driu_od.py  | 14 +++++++-------
 src/mednet/libs/segmentation/models/driu_pix.py | 14 +++++++-------
 src/mednet/libs/segmentation/models/hed.py      | 14 +++++++-------
 src/mednet/libs/segmentation/models/lwnet.py    | 14 +++++++-------
 src/mednet/libs/segmentation/models/m2unet.py   | 14 +++++++-------
 src/mednet/libs/segmentation/models/unet.py     | 14 +++++++-------
 8 files changed, 56 insertions(+), 56 deletions(-)

diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index f3a2e0d2..76608301 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -161,23 +161,23 @@ class DRIU(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 3bb93ba9..f14f911b 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -164,23 +164,23 @@ class DRIUBN(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index 98e59623..d810c471 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -146,23 +146,23 @@ class DRIUOD(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index 6846da5a..a85ac3ab 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -150,23 +150,23 @@ class DRIUPix(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 7e0b7705..80e4665e 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -165,23 +165,23 @@ class HED(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 28bdf498..43bcb819 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -366,23 +366,23 @@ class LittleWNet(Model):
         return x1, x2
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index b3715409..ccc94c29 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -213,23 +213,23 @@ class M2UNET(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 98578d1d..36136bcb 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -154,23 +154,23 @@ class Unet(Model):
             super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
-        images = batch[0]
-        ground_truths = batch[1]["target"]
-        masks = batch[1]["mask"]
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0])[1]
+        output = self(batch[0]["image"])[1]
         return torch.sigmoid(output)
 
     def configure_optimizers(self):
-- 
GitLab