From 5f0c48a13d94b9228dcdcf2ca5f163063d53b856 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 24 May 2023 15:01:50 +0200
Subject: [PATCH] Add support for extra_valid loaders in all models

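Every model now accepts the dataloader_idx argument that Lightning passes to
validation_step() and predict_step() when val_dataloader() returns a list of
loaders. Losses from the main loader keep the "validation_loss" key, while each
extra loader reports under "extra_validation_loss_<idx>".

A minimal sketch of a datamodule that wires up extra validation loaders
(ExtraValidDataModule and its constructor arguments are illustrative names,
not part of this patch):

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader

    class ExtraValidDataModule(pl.LightningDataModule):
        def __init__(self, train_set, valid_set, extra_valid_sets, batch_size=8):
            super().__init__()
            self.train_set = train_set
            self.valid_set = valid_set
            self.extra_valid_sets = extra_valid_sets  # list of extra datasets
            self.batch_size = batch_size

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

        def val_dataloader(self):
            # Index 0 is the main validation loader ("validation_loss");
            # every following loader is reported as "extra_validation_loss_<idx>".
            loaders = [DataLoader(self.valid_set, batch_size=self.batch_size)]
            loaders += [DataLoader(d, batch_size=self.batch_size) for d in self.extra_valid_sets]
            return loaders

Passing such a datamodule to trainer.fit(model, datamodule=...) is enough for
Lightning to call validation_step() once per loader with the corresponding
dataloader_idx; models with a single validation loader are unaffected since
the argument defaults to 0.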
---
 src/ptbench/models/alexnet.py             | 9 ++++++---
 src/ptbench/models/densenet.py            | 9 ++++++---
 src/ptbench/models/densenet_rs.py         | 9 ++++++---
 src/ptbench/models/logistic_regression.py | 9 ++++++---
 src/ptbench/models/signs_to_tb.py         | 9 ++++++---
 5 files changed, 30 insertions(+), 15 deletions(-)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index e871a982..ba9bf05f 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -66,7 +66,7 @@ class Alexnet(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -84,9 +84,12 @@ class Alexnet(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index ea6e623c..a7cf9d56 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -66,7 +66,7 @@ class Densenet(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -84,9 +84,12 @@ class Densenet(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py
index a9d69e27..0fbf2b25 100644
--- a/src/ptbench/models/densenet_rs.py
+++ b/src/ptbench/models/densenet_rs.py
@@ -60,7 +60,7 @@ class DensenetRS(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -78,9 +78,12 @@ class DensenetRS(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py
index 6efd2a25..dfde8318 100644
--- a/src/ptbench/models/logistic_regression.py
+++ b/src/ptbench/models/logistic_regression.py
@@ -49,7 +49,7 @@ class LogisticRegression(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -67,9 +67,12 @@ class LogisticRegression(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py
index aa228645..2f86ded5 100644
--- a/src/ptbench/models/signs_to_tb.py
+++ b/src/ptbench/models/signs_to_tb.py
@@ -56,7 +56,7 @@ class SignsToTB(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -74,9 +74,12 @@ class SignsToTB(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
-- 
GitLab