diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index e871a982ea393c919aab6c819ef2dd2bb70fcc96..ba9bf05f7428d759489bd744f8ec35c3b43bab02 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -66,7 +66,7 @@ class Alexnet(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -84,9 +84,12 @@ class Alexnet(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index ea6e623c3a9cdfc3f9d6896ea77a681f7a2f5cc7..a7cf9d567946899c874efce1664d0e7f65d5ace2 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -66,7 +66,7 @@ class Densenet(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -84,9 +84,12 @@ class Densenet(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py
index a9d69e27928d5ec9a3d525d1a043370deeacb119..0fbf2b258e3432fa3ae6099973e7cd56317eaedb 100644
--- a/src/ptbench/models/densenet_rs.py
+++ b/src/ptbench/models/densenet_rs.py
@@ -60,7 +60,7 @@ class DensenetRS(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -78,9 +78,12 @@ class DensenetRS(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py
index 6efd2a25c9726d5aeb081ae2a7ed22192b9befcd..dfde83181a466b3ca2d4b5fd71f02bf8e583836e 100644
--- a/src/ptbench/models/logistic_regression.py
+++ b/src/ptbench/models/logistic_regression.py
@@ -49,7 +49,7 @@ class LogisticRegression(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -67,9 +67,12 @@ class LogisticRegression(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py
index aa22864558aec7340cfd53b7e3b9622e72980e8a..2f86ded58e518520efc3441fe2c1f6fd74d7a5eb 100644
--- a/src/ptbench/models/signs_to_tb.py
+++ b/src/ptbench/models/signs_to_tb.py
@@ -56,7 +56,7 @@ class SignsToTB(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -74,9 +74,12 @@ class SignsToTB(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]