From 38d5407281f5796e33ccbcbf3ffd1058e18d5bc4 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Fri, 28 Jun 2024 10:25:36 +0200 Subject: [PATCH] [models] Move repeated functions into specialized models --- .../libs/classification/models/alexnet.py | 37 ------------------- .../models/classification_model.py | 37 +++++++++++++++++++ .../libs/classification/models/densenet.py | 32 ---------------- src/mednet/libs/classification/models/pasa.py | 31 ---------------- src/mednet/libs/segmentation/models/driu.py | 23 ------------ .../libs/segmentation/models/driu_bn.py | 23 ------------ .../libs/segmentation/models/driu_od.py | 23 ------------ .../libs/segmentation/models/driu_pix.py | 23 ------------ src/mednet/libs/segmentation/models/hed.py | 23 ------------ src/mednet/libs/segmentation/models/lwnet.py | 23 ------------ src/mednet/libs/segmentation/models/m2unet.py | 23 ------------ .../segmentation/models/segmentation_model.py | 20 ++++++++++ 12 files changed, 57 insertions(+), 261 deletions(-) diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py index 51cabb0b..f930b625 100644 --- a/src/mednet/libs/classification/models/alexnet.py +++ b/src/mednet/libs/classification/models/alexnet.py @@ -117,40 +117,3 @@ class Alexnet(ClassificationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, _): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # Forward pass on the network - outputs = self(self.augmentation_transforms(images)) - - return self._train_loss(outputs, labels.float()) - - def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # debug code to inspect images by eye: - # from torchvision.transforms.functional import to_pil_image - # for k in images: - # to_pil_image(k).show() - # __import__("pdb").set_trace() - - # data forwarding on the existing network - outputs = self(images) - return self._validation_loss(outputs, labels.float()) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - outputs = self(batch[0]) - return torch.sigmoid(outputs) diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py index 30dd0a54..e9654b48 100644 --- a/src/mednet/libs/classification/models/classification_model.py +++ b/src/mednet/libs/classification/models/classification_model.py @@ -79,3 +79,40 @@ class ClassificationModel(Model): f"computing z-norm factors from train dataloader.", ) self.normalizer = make_z_normalizer(dataloader) + + def training_step(self, batch, _): + images = batch[0] + labels = batch[1]["target"] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # Forward pass on the network + outputs = self(self.augmentation_transforms(images)) + + return self._train_loss(outputs, labels.float()) + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + images = batch[0] + labels = batch[1]["target"] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # debug code to inspect images by eye: + # from torchvision.transforms.functional import to_pil_image + # for k in images: + # to_pil_image(k).show() + # __import__("pdb").set_trace() + + # data forwarding on the existing network + outputs = self(images) + return self._validation_loss(outputs, labels.float()) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + outputs = self(batch[0]) + return torch.sigmoid(outputs) diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py index 83969d2c..c40cc957 100644 --- a/src/mednet/libs/classification/models/densenet.py +++ b/src/mednet/libs/classification/models/densenet.py @@ -120,35 +120,3 @@ class Densenet(ClassificationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, _): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # Forward pass on the network - outputs = self(self.augmentation_transforms(images)) - - return self._train_loss(outputs, labels.float()) - - def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # data forwarding on the existing network - outputs = self(images) - - return self._validation_loss(outputs, labels.float()) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - outputs = self(batch[0]) - return torch.sigmoid(outputs) diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py index e7e15ddc..b17af172 100644 --- a/src/mednet/libs/classification/models/pasa.py +++ b/src/mednet/libs/classification/models/pasa.py @@ -192,34 +192,3 @@ class Pasa(ClassificationModel): return self.dense(x) # x = F.log_softmax(x, dim=1) # 0 is batch size - - def training_step(self, batch, _): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # Forward pass on the network - outputs = self(self.augmentation_transforms(images)) - - return self._train_loss(outputs, labels.float()) - - def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # data forwarding on the existing network - outputs = self(images) - return self._validation_loss(outputs, labels.float()) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - outputs = self(batch[0]) - return torch.sigmoid(outputs) diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py index 41f0d58b..e12aac71 100644 --- a/src/mednet/libs/segmentation/models/driu.py +++ b/src/mednet/libs/segmentation/models/driu.py @@ -159,26 +159,3 @@ class DRIU(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py index 7cfd1817..555d4ef1 100644 --- a/src/mednet/libs/segmentation/models/driu_bn.py +++ b/src/mednet/libs/segmentation/models/driu_bn.py @@ -162,26 +162,3 @@ class DRIUBN(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py index 9a2cffac..3b0cff03 100644 --- a/src/mednet/libs/segmentation/models/driu_od.py +++ b/src/mednet/libs/segmentation/models/driu_od.py @@ -144,26 +144,3 @@ class DRIUOD(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py index 80033176..75adce0c 100644 --- a/src/mednet/libs/segmentation/models/driu_pix.py +++ b/src/mednet/libs/segmentation/models/driu_pix.py @@ -148,26 +148,3 @@ class DRIUPix(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py index 7ba67339..97a663c5 100644 --- a/src/mednet/libs/segmentation/models/hed.py +++ b/src/mednet/libs/segmentation/models/hed.py @@ -163,26 +163,3 @@ class HED(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index e1948a59..adb17512 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -348,26 +348,3 @@ class LittleWNet(SegmentationModel): x2 = self.unet2(torch.cat([xn, x1], dim=1)) return x1, x2 - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py index 5dfb3331..cf7ce5c4 100644 --- a/src/mednet/libs/segmentation/models/m2unet.py +++ b/src/mednet/libs/segmentation/models/m2unet.py @@ -211,26 +211,3 @@ class M2UNET(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py index 1bc7135d..8d8aa3ca 100644 --- a/src/mednet/libs/segmentation/models/segmentation_model.py +++ b/src/mednet/libs/segmentation/models/segmentation_model.py @@ -80,3 +80,23 @@ class SegmentationModel(Model): f"computing z-norm factors from train dataloader.", ) self.normalizer = make_z_normalizer(dataloader) + + def training_step(self, batch, batch_idx): + images = batch[0]["image"] + ground_truths = batch[0]["target"] + masks = batch[0]["mask"] + + outputs = self(self._augmentation_transforms(images)) + return self._train_loss(outputs, ground_truths, masks) + + def validation_step(self, batch, batch_idx): + images = batch[0]["image"] + ground_truths = batch[0]["target"] + masks = batch[0]["mask"] + + outputs = self(images) + return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0]["image"])[1] + return torch.sigmoid(output) -- GitLab