diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py index 51cabb0b0c3b5e1c8ca762e654f92f6e557023c4..f930b625e96361182b8b9240e2fd2c82194328f5 100644 --- a/src/mednet/libs/classification/models/alexnet.py +++ b/src/mednet/libs/classification/models/alexnet.py @@ -117,40 +117,3 @@ class Alexnet(ClassificationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, _): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # Forward pass on the network - outputs = self(self.augmentation_transforms(images)) - - return self._train_loss(outputs, labels.float()) - - def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # debug code to inspect images by eye: - # from torchvision.transforms.functional import to_pil_image - # for k in images: - # to_pil_image(k).show() - # __import__("pdb").set_trace() - - # data forwarding on the existing network - outputs = self(images) - return self._validation_loss(outputs, labels.float()) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - outputs = self(batch[0]) - return torch.sigmoid(outputs) diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py index 30dd0a54d08fb3fe193699274a4d82ea7b402167..e9654b480b67c7dc8b2ec249ed1921822383d33b 100644 --- a/src/mednet/libs/classification/models/classification_model.py +++ b/src/mednet/libs/classification/models/classification_model.py @@ -79,3 +79,40 @@ class ClassificationModel(Model): f"computing z-norm factors from train dataloader.", ) self.normalizer = make_z_normalizer(dataloader) + + def training_step(self, batch, _): + images = batch[0] + labels = batch[1]["target"] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # Forward pass on the network + outputs = self(self.augmentation_transforms(images)) + + return self._train_loss(outputs, labels.float()) + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + images = batch[0] + labels = batch[1]["target"] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # debug code to inspect images by eye: + # from torchvision.transforms.functional import to_pil_image + # for k in images: + # to_pil_image(k).show() + # __import__("pdb").set_trace() + + # data forwarding on the existing network + outputs = self(images) + return self._validation_loss(outputs, labels.float()) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + outputs = self(batch[0]) + return torch.sigmoid(outputs) diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py index 83969d2c30ce4daaa08680591faaea7154db290f..c40cc957c5ca8f284b6c3f4832bb1b7d54052bda 100644 --- a/src/mednet/libs/classification/models/densenet.py +++ b/src/mednet/libs/classification/models/densenet.py @@ -120,35 +120,3 @@ class Densenet(ClassificationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, _): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # Forward pass on the network - outputs = self(self.augmentation_transforms(images)) - - return self._train_loss(outputs, labels.float()) - - def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # data forwarding on the existing network - outputs = self(images) - - return self._validation_loss(outputs, labels.float()) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - outputs = self(batch[0]) - return torch.sigmoid(outputs) diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py index e7e15ddca050c661c118c1964bb2ff04cecd2e8f..b17af1724e97c18e6b58b1c25f9db3b3e8c167cb 100644 --- a/src/mednet/libs/classification/models/pasa.py +++ b/src/mednet/libs/classification/models/pasa.py @@ -192,34 +192,3 @@ class Pasa(ClassificationModel): return self.dense(x) # x = F.log_softmax(x, dim=1) # 0 is batch size - - def training_step(self, batch, _): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # Forward pass on the network - outputs = self(self.augmentation_transforms(images)) - - return self._train_loss(outputs, labels.float()) - - def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[0] - labels = batch[1]["target"] - - # Increase label dimension if too low - # Allows single and multiclass usage - if labels.ndim == 1: - labels = torch.reshape(labels, (labels.shape[0], 1)) - - # data forwarding on the existing network - outputs = self(images) - return self._validation_loss(outputs, labels.float()) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - outputs = self(batch[0]) - return torch.sigmoid(outputs) diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py index 41f0d58b65f78607097f311b379d964082d32b0a..e12aac71517ff42a6caf6b2ce052da2be8f32db0 100644 --- a/src/mednet/libs/segmentation/models/driu.py +++ b/src/mednet/libs/segmentation/models/driu.py @@ -159,26 +159,3 @@ class DRIU(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py index 7cfd18178b40a7b49c62da8337ca200b3ff87d49..555d4ef1ffc03e6d88ba81c60326df0cfc745a74 100644 --- a/src/mednet/libs/segmentation/models/driu_bn.py +++ b/src/mednet/libs/segmentation/models/driu_bn.py @@ -162,26 +162,3 @@ class DRIUBN(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py index 9a2cffac96ccfa6f1afac8cbee9376cef0fde66c..3b0cff036c843835a54e566206c045569e4bcd9c 100644 --- a/src/mednet/libs/segmentation/models/driu_od.py +++ b/src/mednet/libs/segmentation/models/driu_od.py @@ -144,26 +144,3 @@ class DRIUOD(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py index 8003317650389e3b1c1e31e3cb13304feddcda15..75adce0c91da0fd53ff9386b3ffa08bc187fa72f 100644 --- a/src/mednet/libs/segmentation/models/driu_pix.py +++ b/src/mednet/libs/segmentation/models/driu_pix.py @@ -148,26 +148,3 @@ class DRIUPix(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py index 7ba67339a412fc5499f28404dbbece60680ea209..97a663c51cc3fa649f70f1c0902fad712babb916 100644 --- a/src/mednet/libs/segmentation/models/hed.py +++ b/src/mednet/libs/segmentation/models/hed.py @@ -163,26 +163,3 @@ class HED(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index e1948a594dd31af9180d4b713474894ac11f6243..adb17512bcd3953efd65c6dc42b382d0b99f12ea 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -348,26 +348,3 @@ class LittleWNet(SegmentationModel): x2 = self.unet2(torch.cat([xn, x1], dim=1)) return x1, x2 - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py index 5dfb333126d9266aad049447d8140aa19a6ff008..cf7ce5c436b8af5796f119c06940d8098c3d238b 100644 --- a/src/mednet/libs/segmentation/models/m2unet.py +++ b/src/mednet/libs/segmentation/models/m2unet.py @@ -211,26 +211,3 @@ class M2UNET(SegmentationModel): self.normalizer = make_imagenet_normalizer() else: super().set_normalizer(dataloader) - - def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) - - def validation_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] - - outputs = self(images) - return self._validation_loss(outputs, ground_truths, masks) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output = self(batch[0]["image"])[1] - return torch.sigmoid(output) - - def configure_optimizers(self): - return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py index 1bc7135dd7f285e92c318856af323e50ff329da4..8d8aa3ca42757bc415fc9784d309119f6f4599c4 100644 --- a/src/mednet/libs/segmentation/models/segmentation_model.py +++ b/src/mednet/libs/segmentation/models/segmentation_model.py @@ -80,3 +80,23 @@ class SegmentationModel(Model): f"computing z-norm factors from train dataloader.", ) self.normalizer = make_z_normalizer(dataloader) + + def training_step(self, batch, batch_idx): + images = batch[0]["image"] + ground_truths = batch[0]["target"] + masks = batch[0]["mask"] + + outputs = self(self._augmentation_transforms(images)) + return self._train_loss(outputs, ground_truths, masks) + + def validation_step(self, batch, batch_idx): + images = batch[0]["image"] + ground_truths = batch[0]["target"] + masks = batch[0]["mask"] + + outputs = self(images) + return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0]["image"])[1] + return torch.sigmoid(output)