From 239ca0ff903b09a9f2483bf0343ab9eacfb1c26f Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 24 May 2023 15:01:50 +0200 Subject: [PATCH] Supports for extra_valid loaders in all models --- src/ptbench/models/alexnet.py | 9 ++++++--- src/ptbench/models/densenet.py | 9 ++++++--- src/ptbench/models/densenet_rs.py | 9 ++++++--- src/ptbench/models/logistic_regression.py | 9 ++++++--- src/ptbench/models/signs_to_tb.py | 9 ++++++--- 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index e871a982..ba9bf05f 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -66,7 +66,7 @@ class Alexnet(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -84,9 +84,12 @@ class Alexnet(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1] diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index ea6e623c..a7cf9d56 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -66,7 +66,7 @@ class Densenet(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -84,9 +84,12 @@ class Densenet(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1] diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py index a9d69e27..0fbf2b25 100644 --- a/src/ptbench/models/densenet_rs.py +++ b/src/ptbench/models/densenet_rs.py @@ -60,7 +60,7 @@ class DensenetRS(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -78,9 +78,12 @@ class DensenetRS(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1] diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index 6efd2a25..dfde8318 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -49,7 +49,7 @@ class LogisticRegression(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -67,9 +67,12 @@ class LogisticRegression(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1] diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py index aa228645..2f86ded5 100644 --- a/src/ptbench/models/signs_to_tb.py +++ b/src/ptbench/models/signs_to_tb.py @@ -56,7 +56,7 @@ class SignsToTB(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -74,9 +74,12 @@ class SignsToTB(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1] -- GitLab