Skip to content
Snippets Groups Projects
Commit a12a332a authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation.models] Update models to handle updated samples

parent f142d968
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -161,23 +161,23 @@ class DRIU(Model): ...@@ -161,23 +161,23 @@ class DRIU(Model):
super().set_normalizer(dataloader) super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(self._augmentation_transforms(images)) outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(images) outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks) return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1] output = self(batch[0]["image"])[1]
return torch.sigmoid(output) return torch.sigmoid(output)
def configure_optimizers(self): def configure_optimizers(self):
......
...@@ -164,23 +164,23 @@ class DRIUBN(Model): ...@@ -164,23 +164,23 @@ class DRIUBN(Model):
super().set_normalizer(dataloader) super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(self._augmentation_transforms(images)) outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(images) outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks) return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1] output = self(batch[0]["image"])[1]
return torch.sigmoid(output) return torch.sigmoid(output)
def configure_optimizers(self): def configure_optimizers(self):
......
...@@ -146,23 +146,23 @@ class DRIUOD(Model): ...@@ -146,23 +146,23 @@ class DRIUOD(Model):
super().set_normalizer(dataloader) super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(self._augmentation_transforms(images)) outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(images) outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks) return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1] output = self(batch[0]["image"])[1]
return torch.sigmoid(output) return torch.sigmoid(output)
def configure_optimizers(self): def configure_optimizers(self):
......
...@@ -150,23 +150,23 @@ class DRIUPix(Model): ...@@ -150,23 +150,23 @@ class DRIUPix(Model):
super().set_normalizer(dataloader) super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(self._augmentation_transforms(images)) outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(images) outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks) return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1] output = self(batch[0]["image"])[1]
return torch.sigmoid(output) return torch.sigmoid(output)
def configure_optimizers(self): def configure_optimizers(self):
......
...@@ -165,23 +165,23 @@ class HED(Model): ...@@ -165,23 +165,23 @@ class HED(Model):
super().set_normalizer(dataloader) super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(self._augmentation_transforms(images)) outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(images) outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks) return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1] output = self(batch[0]["image"])[1]
return torch.sigmoid(output) return torch.sigmoid(output)
def configure_optimizers(self): def configure_optimizers(self):
......
...@@ -366,23 +366,23 @@ class LittleWNet(Model): ...@@ -366,23 +366,23 @@ class LittleWNet(Model):
return x1, x2 return x1, x2
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(self._augmentation_transforms(images)) outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(images) outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks) return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1] output = self(batch[0]["image"])[1]
return torch.sigmoid(output) return torch.sigmoid(output)
def configure_optimizers(self): def configure_optimizers(self):
......
...@@ -213,23 +213,23 @@ class M2UNET(Model): ...@@ -213,23 +213,23 @@ class M2UNET(Model):
super().set_normalizer(dataloader) super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(self._augmentation_transforms(images)) outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(images) outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks) return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1] output = self(batch[0]["image"])[1]
return torch.sigmoid(output) return torch.sigmoid(output)
def configure_optimizers(self): def configure_optimizers(self):
......
...@@ -154,23 +154,23 @@ class Unet(Model): ...@@ -154,23 +154,23 @@ class Unet(Model):
super().set_normalizer(dataloader) super().set_normalizer(dataloader)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(self._augmentation_transforms(images)) outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
images = batch[0] images = batch[0]["image"]
ground_truths = batch[1]["target"] ground_truths = batch[0]["target"]
masks = batch[1]["mask"] masks = batch[0]["mask"]
outputs = self(images) outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks) return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1] output = self(batch[0]["image"])[1]
return torch.sigmoid(output) return torch.sigmoid(output)
def configure_optimizers(self): def configure_optimizers(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment