Skip to content
Snippets Groups Projects
Commit 38d54072 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[models] Move repeated functions into specialized models

parent d3511abe
No related branches found
No related tags found
1 merge request: !46 "Create common library"
Showing
with 57 additions and 261 deletions
......@@ -117,40 +117,3 @@ class Alexnet(ClassificationModel):
self.normalizer = make_imagenet_normalizer()
else:
super().set_normalizer(dataloader)
def training_step(self, batch, _):
    """Forward one augmented training batch and return the training loss."""
    data, metadata = batch[0], batch[1]
    targets = metadata["target"]
    # Promote 1-D target vectors to column shape (batch, 1) so the
    # single-label and multi-label cases share one code path.
    if targets.ndim == 1:
        targets = targets.reshape(-1, 1)
    predictions = self(self.augmentation_transforms(data))
    return self._train_loss(predictions, targets.float())

def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    data, metadata = batch[0], batch[1]
    targets = metadata["target"]
    # Same target-shape normalization as in ``training_step``.
    if targets.ndim == 1:
        targets = targets.reshape(-1, 1)
    predictions = self(data)
    return self._validation_loss(predictions, targets.float())

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-class probabilities: a sigmoid over the network logits."""
    return torch.sigmoid(self(batch[0]))
......@@ -79,3 +79,40 @@ class ClassificationModel(Model):
f"computing z-norm factors from train dataloader.",
)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, _):
    """Forward one augmented training batch and return the training loss."""
    data, metadata = batch[0], batch[1]
    targets = metadata["target"]
    # Promote 1-D target vectors to column shape (batch, 1) so the
    # single-label and multi-label cases share one code path.
    if targets.ndim == 1:
        targets = targets.reshape(-1, 1)
    predictions = self(self.augmentation_transforms(data))
    return self._train_loss(predictions, targets.float())

def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    data, metadata = batch[0], batch[1]
    targets = metadata["target"]
    # Same target-shape normalization as in ``training_step``.
    if targets.ndim == 1:
        targets = targets.reshape(-1, 1)
    predictions = self(data)
    return self._validation_loss(predictions, targets.float())

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-class probabilities: a sigmoid over the network logits."""
    return torch.sigmoid(self(batch[0]))
......@@ -120,35 +120,3 @@ class Densenet(ClassificationModel):
self.normalizer = make_imagenet_normalizer()
else:
super().set_normalizer(dataloader)
def training_step(self, batch, _):
    """Forward one augmented training batch and return the training loss."""
    data, metadata = batch[0], batch[1]
    targets = metadata["target"]
    # Promote 1-D target vectors to column shape (batch, 1) so the
    # single-label and multi-label cases share one code path.
    if targets.ndim == 1:
        targets = targets.reshape(-1, 1)
    predictions = self(self.augmentation_transforms(data))
    return self._train_loss(predictions, targets.float())

def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    data, metadata = batch[0], batch[1]
    targets = metadata["target"]
    # Same target-shape normalization as in ``training_step``.
    if targets.ndim == 1:
        targets = targets.reshape(-1, 1)
    predictions = self(data)
    return self._validation_loss(predictions, targets.float())

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-class probabilities: a sigmoid over the network logits."""
    return torch.sigmoid(self(batch[0]))
......@@ -192,34 +192,3 @@ class Pasa(ClassificationModel):
return self.dense(x)
# x = F.log_softmax(x, dim=1) # 0 is batch size
def training_step(self, batch, _):
    """Forward one augmented training batch and return the training loss."""
    data, metadata = batch[0], batch[1]
    targets = metadata["target"]
    # Promote 1-D target vectors to column shape (batch, 1) so the
    # single-label and multi-label cases share one code path.
    if targets.ndim == 1:
        targets = targets.reshape(-1, 1)
    predictions = self(self.augmentation_transforms(data))
    return self._train_loss(predictions, targets.float())

def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    data, metadata = batch[0], batch[1]
    targets = metadata["target"]
    # Same target-shape normalization as in ``training_step``.
    if targets.ndim == 1:
        targets = targets.reshape(-1, 1)
    predictions = self(data)
    return self._validation_loss(predictions, targets.float())

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-class probabilities: a sigmoid over the network logits."""
    return torch.sigmoid(self(batch[0]))
......@@ -159,26 +159,3 @@ class DRIU(SegmentationModel):
self.normalizer = make_imagenet_normalizer()
else:
super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx):
    """Forward one augmented training batch and return the training loss."""
    sample = batch[0]
    augmented = self._augmentation_transforms(sample["image"])
    predictions = self(augmented)
    return self._train_loss(predictions, sample["target"], sample["mask"])

def validation_step(self, batch, batch_idx):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    sample = batch[0]
    predictions = self(sample["image"])
    return self._validation_loss(predictions, sample["target"], sample["mask"])

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-pixel probabilities for one batch.

    The forward pass yields multiple outputs; the one at index 1 is
    passed through a sigmoid to produce probabilities.
    """
    logits = self(batch[0]["image"])[1]
    return torch.sigmoid(logits)

def configure_optimizers(self):
    """Instantiate the configured optimizer over this model's parameters."""
    make_optimizer = self._optimizer_type
    return make_optimizer(self.parameters(), **self._optimizer_arguments)
......@@ -162,26 +162,3 @@ class DRIUBN(SegmentationModel):
self.normalizer = make_imagenet_normalizer()
else:
super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx):
    """Forward one augmented training batch and return the training loss."""
    sample = batch[0]
    augmented = self._augmentation_transforms(sample["image"])
    predictions = self(augmented)
    return self._train_loss(predictions, sample["target"], sample["mask"])

def validation_step(self, batch, batch_idx):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    sample = batch[0]
    predictions = self(sample["image"])
    return self._validation_loss(predictions, sample["target"], sample["mask"])

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-pixel probabilities for one batch.

    The forward pass yields multiple outputs; the one at index 1 is
    passed through a sigmoid to produce probabilities.
    """
    logits = self(batch[0]["image"])[1]
    return torch.sigmoid(logits)

def configure_optimizers(self):
    """Instantiate the configured optimizer over this model's parameters."""
    make_optimizer = self._optimizer_type
    return make_optimizer(self.parameters(), **self._optimizer_arguments)
......@@ -144,26 +144,3 @@ class DRIUOD(SegmentationModel):
self.normalizer = make_imagenet_normalizer()
else:
super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx):
    """Forward one augmented training batch and return the training loss."""
    sample = batch[0]
    augmented = self._augmentation_transforms(sample["image"])
    predictions = self(augmented)
    return self._train_loss(predictions, sample["target"], sample["mask"])

def validation_step(self, batch, batch_idx):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    sample = batch[0]
    predictions = self(sample["image"])
    return self._validation_loss(predictions, sample["target"], sample["mask"])

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-pixel probabilities for one batch.

    The forward pass yields multiple outputs; the one at index 1 is
    passed through a sigmoid to produce probabilities.
    """
    logits = self(batch[0]["image"])[1]
    return torch.sigmoid(logits)

def configure_optimizers(self):
    """Instantiate the configured optimizer over this model's parameters."""
    make_optimizer = self._optimizer_type
    return make_optimizer(self.parameters(), **self._optimizer_arguments)
......@@ -148,26 +148,3 @@ class DRIUPix(SegmentationModel):
self.normalizer = make_imagenet_normalizer()
else:
super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx):
    """Forward one augmented training batch and return the training loss."""
    sample = batch[0]
    augmented = self._augmentation_transforms(sample["image"])
    predictions = self(augmented)
    return self._train_loss(predictions, sample["target"], sample["mask"])

def validation_step(self, batch, batch_idx):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    sample = batch[0]
    predictions = self(sample["image"])
    return self._validation_loss(predictions, sample["target"], sample["mask"])

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-pixel probabilities for one batch.

    The forward pass yields multiple outputs; the one at index 1 is
    passed through a sigmoid to produce probabilities.
    """
    logits = self(batch[0]["image"])[1]
    return torch.sigmoid(logits)

def configure_optimizers(self):
    """Instantiate the configured optimizer over this model's parameters."""
    make_optimizer = self._optimizer_type
    return make_optimizer(self.parameters(), **self._optimizer_arguments)
......@@ -163,26 +163,3 @@ class HED(SegmentationModel):
self.normalizer = make_imagenet_normalizer()
else:
super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx):
    """Forward one augmented training batch and return the training loss."""
    sample = batch[0]
    augmented = self._augmentation_transforms(sample["image"])
    predictions = self(augmented)
    return self._train_loss(predictions, sample["target"], sample["mask"])

def validation_step(self, batch, batch_idx):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    sample = batch[0]
    predictions = self(sample["image"])
    return self._validation_loss(predictions, sample["target"], sample["mask"])

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-pixel probabilities for one batch.

    The forward pass yields multiple outputs; the one at index 1 is
    passed through a sigmoid to produce probabilities.
    """
    logits = self(batch[0]["image"])[1]
    return torch.sigmoid(logits)

def configure_optimizers(self):
    """Instantiate the configured optimizer over this model's parameters."""
    make_optimizer = self._optimizer_type
    return make_optimizer(self.parameters(), **self._optimizer_arguments)
......@@ -348,26 +348,3 @@ class LittleWNet(SegmentationModel):
x2 = self.unet2(torch.cat([xn, x1], dim=1))
return x1, x2
def training_step(self, batch, batch_idx):
    """Forward one augmented training batch and return the training loss."""
    sample = batch[0]
    augmented = self._augmentation_transforms(sample["image"])
    predictions = self(augmented)
    return self._train_loss(predictions, sample["target"], sample["mask"])

def validation_step(self, batch, batch_idx):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    sample = batch[0]
    predictions = self(sample["image"])
    return self._validation_loss(predictions, sample["target"], sample["mask"])

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-pixel probabilities for one batch.

    The forward pass yields multiple outputs; the one at index 1 is
    passed through a sigmoid to produce probabilities.
    """
    logits = self(batch[0]["image"])[1]
    return torch.sigmoid(logits)

def configure_optimizers(self):
    """Instantiate the configured optimizer over this model's parameters."""
    make_optimizer = self._optimizer_type
    return make_optimizer(self.parameters(), **self._optimizer_arguments)
......@@ -211,26 +211,3 @@ class M2UNET(SegmentationModel):
self.normalizer = make_imagenet_normalizer()
else:
super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx):
    """Forward one augmented training batch and return the training loss."""
    sample = batch[0]
    augmented = self._augmentation_transforms(sample["image"])
    predictions = self(augmented)
    return self._train_loss(predictions, sample["target"], sample["mask"])

def validation_step(self, batch, batch_idx):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    sample = batch[0]
    predictions = self(sample["image"])
    return self._validation_loss(predictions, sample["target"], sample["mask"])

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-pixel probabilities for one batch.

    The forward pass yields multiple outputs; the one at index 1 is
    passed through a sigmoid to produce probabilities.
    """
    logits = self(batch[0]["image"])[1]
    return torch.sigmoid(logits)

def configure_optimizers(self):
    """Instantiate the configured optimizer over this model's parameters."""
    make_optimizer = self._optimizer_type
    return make_optimizer(self.parameters(), **self._optimizer_arguments)
......@@ -80,3 +80,23 @@ class SegmentationModel(Model):
f"computing z-norm factors from train dataloader.",
)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, batch_idx):
    """Forward one augmented training batch and return the training loss."""
    sample = batch[0]
    augmented = self._augmentation_transforms(sample["image"])
    predictions = self(augmented)
    return self._train_loss(predictions, sample["target"], sample["mask"])

def validation_step(self, batch, batch_idx):
    """Forward one validation batch (no augmentation) and return the
    validation loss.
    """
    sample = batch[0]
    predictions = self(sample["image"])
    return self._validation_loss(predictions, sample["target"], sample["mask"])

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return per-pixel probabilities for one batch.

    The forward pass yields multiple outputs; the one at index 1 is
    passed through a sigmoid to produce probabilities.
    """
    logits = self(batch[0]["image"])[1]
    return torch.sigmoid(logits)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment