diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py
index 51cabb0b0c3b5e1c8ca762e654f92f6e557023c4..f930b625e96361182b8b9240e2fd2c82194328f5 100644
--- a/src/mednet/libs/classification/models/alexnet.py
+++ b/src/mednet/libs/classification/models/alexnet.py
@@ -117,40 +117,3 @@ class Alexnet(ClassificationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, _):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        outputs = self(self.augmentation_transforms(images))
-
-        return self._train_loss(outputs, labels.float())
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # debug code to inspect images by eye:
-        # from torchvision.transforms.functional import to_pil_image
-        # for k in images:
-        #    to_pil_image(k).show()
-        #    __import__("pdb").set_trace()
-
-        # data forwarding on the existing network
-        outputs = self(images)
-        return self._validation_loss(outputs, labels.float())
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
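Note on the `set_normalizer` branch retained above: when ImageNet-pretrained weights are selected, the model keeps the ImageNet input normalizer instead of computing z-norm statistics from the train dataloader. Assuming `make_imagenet_normalizer()` applies the standard ImageNet channel statistics (the helper itself is not shown in this diff), a minimal torchvision sketch would be:

```python
from torchvision import transforms

# Hypothetical stand-in for make_imagenet_normalizer(); assumes it applies the
# standard ImageNet channel statistics used by torchvision's pretrained models
imagenet_normalizer = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
```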
diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py
index 30dd0a54d08fb3fe193699274a4d82ea7b402167..e9654b480b67c7dc8b2ec249ed1921822383d33b 100644
--- a/src/mednet/libs/classification/models/classification_model.py
+++ b/src/mednet/libs/classification/models/classification_model.py
@@ -79,3 +79,34 @@ class ClassificationModel(Model):
             f"computing z-norm factors from train dataloader.",
         )
         self.normalizer = make_z_normalizer(dataloader)
+
+    def training_step(self, batch, _):
+        images = batch[0]
+        labels = batch[1]["target"]
+
+        # Add a class dimension to 1D label tensors so that single-label
+        # and multi-label targets share the same code path
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network, with train-time augmentations applied
+        outputs = self(self.augmentation_transforms(images))
+
+        return self._train_loss(outputs, labels.float())
+
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+        images = batch[0]
+        labels = batch[1]["target"]
+
+        # Add a class dimension to 1D label tensors so that single-label
+        # and multi-label targets share the same code path
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network (no augmentation at validation time)
+        outputs = self(images)
+        return self._validation_loss(outputs, labels.float())
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        outputs = self(batch[0])
+        return torch.sigmoid(outputs)
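The label handling consolidated into `ClassificationModel` above lifts 1D single-label targets to shape `(N, 1)` so that the same loss call also accepts `(N, C)` multi-label targets. A standalone sketch of that behavior (the actual loss is wired elsewhere in the model; `binary_cross_entropy_with_logits` is only an illustrative stand-in for `self._train_loss`):

```python
import torch

# Single-label targets arrive as a 1D tensor of shape (N,)
labels = torch.tensor([0, 1, 1, 0])

# The consolidated steps lift them to (N, 1), matching the (N, C) layout
# that multi-label targets already have
if labels.ndim == 1:
    labels = torch.reshape(labels, (labels.shape[0], 1))
assert labels.shape == (4, 1)

# The loss is then computed on float targets; BCE-with-logits is one example
outputs = torch.randn(4, 1)  # raw logits from the network
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    outputs, labels.float()
)
print(float(loss))
```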
diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py
index 83969d2c30ce4daaa08680591faaea7154db290f..c40cc957c5ca8f284b6c3f4832bb1b7d54052bda 100644
--- a/src/mednet/libs/classification/models/densenet.py
+++ b/src/mednet/libs/classification/models/densenet.py
@@ -120,35 +120,3 @@ class Densenet(ClassificationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, _):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        outputs = self(self.augmentation_transforms(images))
-
-        return self._train_loss(outputs, labels.float())
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        outputs = self(images)
-
-        return self._validation_loss(outputs, labels.float())
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index e7e15ddca050c661c118c1964bb2ff04cecd2e8f..b17af1724e97c18e6b58b1c25f9db3b3e8c167cb 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -192,34 +192,3 @@ class Pasa(ClassificationModel):
         return self.dense(x)
 
         # x = F.log_softmax(x, dim=1) # 0 is batch size
-
-    def training_step(self, batch, _):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        outputs = self(self.augmentation_transforms(images))
-
-        return self._train_loss(outputs, labels.float())
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[0]
-        labels = batch[1]["target"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        outputs = self(images)
-        return self._validation_loss(outputs, labels.float())
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index 41f0d58b65f78607097f311b379d964082d32b0a..e12aac71517ff42a6caf6b2ce052da2be8f32db0 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -159,26 +159,3 @@ class DRIU(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 7cfd18178b40a7b49c62da8337ca200b3ff87d49..555d4ef1ffc03e6d88ba81c60326df0cfc745a74 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -162,26 +162,3 @@ class DRIUBN(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index 9a2cffac96ccfa6f1afac8cbee9376cef0fde66c..3b0cff036c843835a54e566206c045569e4bcd9c 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -144,26 +144,3 @@ class DRIUOD(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index 8003317650389e3b1c1e31e3cb13304feddcda15..75adce0c91da0fd53ff9386b3ffa08bc187fa72f 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -148,26 +148,3 @@ class DRIUPix(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 7ba67339a412fc5499f28404dbbece60680ea209..97a663c51cc3fa649f70f1c0902fad712babb916 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -163,26 +163,3 @@ class HED(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index e1948a594dd31af9180d4b713474894ac11f6243..adb17512bcd3953efd65c6dc42b382d0b99f12ea 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -348,26 +348,3 @@ class LittleWNet(SegmentationModel):
         x2 = self.unet2(torch.cat([xn, x1], dim=1))
 
         return x1, x2
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index 5dfb333126d9266aad049447d8140aa19a6ff008..cf7ce5c436b8af5796f119c06940d8098c3d238b 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -211,26 +211,3 @@ class M2UNET(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index 1bc7135dd7f285e92c318856af323e50ff329da4..8d8aa3ca42757bc415fc9784d309119f6f4599c4 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -80,3 +80,27 @@ class SegmentationModel(Model):
             f"computing z-norm factors from train dataloader.",
         )
         self.normalizer = make_z_normalizer(dataloader)
+
+    def training_step(self, batch, batch_idx):
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
+
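+        # Forward pass on the network, with train-time augmentations applied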
+        outputs = self(self._augmentation_transforms(images))
+        return self._train_loss(outputs, ground_truths, masks)
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[0]["image"]
+        ground_truths = batch[0]["target"]
+        masks = batch[0]["mask"]
+
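+        # Forward pass on the network (no augmentation at validation time)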
+        outputs = self(images)
+        return self._validation_loss(outputs, ground_truths, masks)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
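+        # The forward pass returns multiple maps; keep the second one and
+        # convert its logits to per-pixel probabilities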
+        output = self(batch[0]["image"])[1]
+        return torch.sigmoid(output)
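The consolidated segmentation steps assume the batch layout visible throughout this diff: a single dict in `batch[0]` holding image, target, and mask, unlike the classification layout where `batch[0]` is the image tensor and `batch[1]` carries the targets. A small sketch with dummy tensors (shapes are illustrative only):

```python
import torch

# Classification layout: batch[0] is the image tensor,
# batch[1] is a metadata dict carrying the targets
clf_batch = (
    torch.rand(4, 3, 224, 224),              # images
    {"target": torch.tensor([0, 1, 1, 0])},  # labels
)

# Segmentation layout: everything lives in the dict at batch[0]
seg_batch = (
    {
        "image": torch.rand(4, 3, 256, 256),                      # inputs
        "target": torch.randint(0, 2, (4, 1, 256, 256)).float(),  # ground truth
        "mask": torch.ones(4, 1, 256, 256),                       # ROI mask
    },
)

# predict_step keeps the second output of the forward pass and maps its
# logits to per-pixel probabilities; a stand-in for self(images)[1]:
logits = torch.randn(4, 1, 256, 256)
probabilities = torch.sigmoid(logits)
assert 0.0 <= probabilities.min() and probabilities.max() <= 1.0
```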