From 38d5407281f5796e33ccbcbf3ffd1058e18d5bc4 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 28 Jun 2024 10:25:36 +0200
Subject: [PATCH] [models] Move repeated functions into specialized models

---
 .../libs/classification/models/alexnet.py     | 37 -------------------
 .../models/classification_model.py            | 37 +++++++++++++++++++
 .../libs/classification/models/densenet.py    | 32 ----------------
 src/mednet/libs/classification/models/pasa.py | 31 ----------------
 src/mednet/libs/segmentation/models/driu.py   | 23 ------------
 .../libs/segmentation/models/driu_bn.py       | 23 ------------
 .../libs/segmentation/models/driu_od.py       | 23 ------------
 .../libs/segmentation/models/driu_pix.py      | 23 ------------
 src/mednet/libs/segmentation/models/hed.py    | 23 ------------
 src/mednet/libs/segmentation/models/lwnet.py  | 23 ------------
 src/mednet/libs/segmentation/models/m2unet.py | 23 ------------
 .../segmentation/models/segmentation_model.py | 20 ++++++++++
 12 files changed, 57 insertions(+), 261 deletions(-)

diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py
index 51cabb0b..f930b625 100644
--- a/src/mednet/libs/classification/models/alexnet.py
+++ b/src/mednet/libs/classification/models/alexnet.py
@@ -117,40 +117,3 @@ class Alexnet(ClassificationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, _):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        outputs = self(self.augmentation_transforms(images))
-
-        return self._train_loss(outputs, labels.float())
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # debug code to inspect images by eye:
-        # from torchvision.transforms.functional import to_pil_image
-        # for k in images:
-        #    to_pil_image(k).show()
-        #    __import__("pdb").set_trace()
-
-        # data forwarding on the existing network
-        outputs = self(images)
-        return self._validation_loss(outputs, labels.float())
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py
index 30dd0a54..e9654b48 100644
--- a/src/mednet/libs/classification/models/classification_model.py
+++ b/src/mednet/libs/classification/models/classification_model.py
@@ -79,3 +79,40 @@ class ClassificationModel(Model):
             f"computing z-norm factors from train dataloader.",
         )
         self.normalizer = make_z_normalizer(dataloader)
+
+    def training_step(self, batch, _):
+        images = batch[0]
+        labels = batch[1]["target"]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(self.augmentation_transforms(images))
+
+        return self._train_loss(outputs, labels.float())
+
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+        images = batch[0]
+        labels = batch[1]["target"]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # debug code to inspect images by eye:
+        # from torchvision.transforms.functional import to_pil_image
+        # for k in images:
+        #    to_pil_image(k).show()
+        #    __import__("pdb").set_trace()
+
+        # data forwarding on the existing network
+        outputs = self(images)
+        return self._validation_loss(outputs, labels.float())
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        outputs = self(batch[0])
+        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py
index 83969d2c..c40cc957 100644
--- a/src/mednet/libs/classification/models/densenet.py
+++ b/src/mednet/libs/classification/models/densenet.py
@@ -120,35 +120,3 @@ class Densenet(ClassificationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, _):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        outputs = self(self.augmentation_transforms(images))
-
-        return self._train_loss(outputs, labels.float())
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        outputs = self(images)
-
-        return self._validation_loss(outputs, labels.float())
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index e7e15ddc..b17af172 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -192,34 +192,3 @@ class Pasa(ClassificationModel):
         return self.dense(x)
 
         # x = F.log_softmax(x, dim=1) # 0 is batch size
-
-    def training_step(self, batch, _):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        outputs = self(self.augmentation_transforms(images))
-
-        return self._train_loss(outputs, labels.float())
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        outputs = self(images)
-        return self._validation_loss(outputs, labels.float())
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index 41f0d58b..e12aac71 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -159,26 +159,3 @@ class DRIU(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 7cfd1817..555d4ef1 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -162,26 +162,3 @@ class DRIUBN(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index 9a2cffac..3b0cff03 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -144,26 +144,3 @@ class DRIUOD(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index 80033176..75adce0c 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -148,26 +148,3 @@ class DRIUPix(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 7ba67339..97a663c5 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -163,26 +163,3 @@ class HED(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index e1948a59..adb17512 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -348,26 +348,3 @@ class LittleWNet(SegmentationModel):
         x2 = self.unet2(torch.cat([xn, x1], dim=1))
 
         return x1, x2
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index 5dfb3331..cf7ce5c4 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -211,26 +211,3 @@ class M2UNET(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index 1bc7135d..8d8aa3ca 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -80,3 +80,23 @@ class SegmentationModel(Model):
             f"computing z-norm factors from train dataloader.",
         )
         self.normalizer = make_z_normalizer(dataloader)
+
+    def training_step(self, batch, batch_idx):
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
+
+        outputs = self(self._augmentation_transforms(images))
+        return self._train_loss(outputs, ground_truths, masks)
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
+
+        outputs = self(images)
+        return self._validation_loss(outputs, ground_truths, masks)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        output = self(batch[0]["image"])[1]
+        return torch.sigmoid(output)
-- 
GitLab