Skip to content
Snippets Groups Projects
Commit 5f0c48a1 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Supports for extra_valid loaders in all models

parent a321b3b3
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
...@@ -66,7 +66,7 @@ class Alexnet(pl.LightningModule): ...@@ -66,7 +66,7 @@ class Alexnet(pl.LightningModule):
return {"loss": training_loss} return {"loss": training_loss}
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[1] images = batch[1]
labels = batch[2] labels = batch[2]
...@@ -84,9 +84,12 @@ class Alexnet(pl.LightningModule): ...@@ -84,9 +84,12 @@ class Alexnet(pl.LightningModule):
) )
validation_loss = self.hparams.criterion_valid(outputs, labels.float()) validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss} if dataloader_idx == 0:
return {"validation_loss": validation_loss}
else:
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, grad_cams=False): def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
names = batch[0] names = batch[0]
images = batch[1] images = batch[1]
......
...@@ -66,7 +66,7 @@ class Densenet(pl.LightningModule): ...@@ -66,7 +66,7 @@ class Densenet(pl.LightningModule):
return {"loss": training_loss} return {"loss": training_loss}
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[1] images = batch[1]
labels = batch[2] labels = batch[2]
...@@ -84,9 +84,12 @@ class Densenet(pl.LightningModule): ...@@ -84,9 +84,12 @@ class Densenet(pl.LightningModule):
) )
validation_loss = self.hparams.criterion_valid(outputs, labels.float()) validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss} if dataloader_idx == 0:
return {"validation_loss": validation_loss}
else:
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, grad_cams=False): def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
names = batch[0] names = batch[0]
images = batch[1] images = batch[1]
......
...@@ -60,7 +60,7 @@ class DensenetRS(pl.LightningModule): ...@@ -60,7 +60,7 @@ class DensenetRS(pl.LightningModule):
return {"loss": training_loss} return {"loss": training_loss}
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[1] images = batch[1]
labels = batch[2] labels = batch[2]
...@@ -78,9 +78,12 @@ class DensenetRS(pl.LightningModule): ...@@ -78,9 +78,12 @@ class DensenetRS(pl.LightningModule):
) )
validation_loss = self.hparams.criterion_valid(outputs, labels.float()) validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss} if dataloader_idx == 0:
return {"validation_loss": validation_loss}
else:
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, grad_cams=False): def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
names = batch[0] names = batch[0]
images = batch[1] images = batch[1]
......
...@@ -49,7 +49,7 @@ class LogisticRegression(pl.LightningModule): ...@@ -49,7 +49,7 @@ class LogisticRegression(pl.LightningModule):
return {"loss": training_loss} return {"loss": training_loss}
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[1] images = batch[1]
labels = batch[2] labels = batch[2]
...@@ -67,9 +67,12 @@ class LogisticRegression(pl.LightningModule): ...@@ -67,9 +67,12 @@ class LogisticRegression(pl.LightningModule):
) )
validation_loss = self.hparams.criterion_valid(outputs, labels.float()) validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss} if dataloader_idx == 0:
return {"validation_loss": validation_loss}
else:
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, grad_cams=False): def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
names = batch[0] names = batch[0]
images = batch[1] images = batch[1]
......
...@@ -56,7 +56,7 @@ class SignsToTB(pl.LightningModule): ...@@ -56,7 +56,7 @@ class SignsToTB(pl.LightningModule):
return {"loss": training_loss} return {"loss": training_loss}
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[1] images = batch[1]
labels = batch[2] labels = batch[2]
...@@ -74,9 +74,12 @@ class SignsToTB(pl.LightningModule): ...@@ -74,9 +74,12 @@ class SignsToTB(pl.LightningModule):
) )
validation_loss = self.hparams.criterion_valid(outputs, labels.float()) validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss} if dataloader_idx == 0:
return {"validation_loss": validation_loss}
else:
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, grad_cams=False): def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
names = batch[0] names = batch[0]
images = batch[1] images = batch[1]
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment