diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index e871a982ea393c919aab6c819ef2dd2bb70fcc96..ba9bf05f7428d759489bd744f8ec35c3b43bab02 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -66,7 +66,7 @@ class Alexnet(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -84,9 +84,12 @@ class Alexnet(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1] diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index ea6e623c3a9cdfc3f9d6896ea77a681f7a2f5cc7..a7cf9d567946899c874efce1664d0e7f65d5ace2 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -66,7 +66,7 @@ class Densenet(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -84,9 +84,12 @@ class Densenet(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1] diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py index a9d69e27928d5ec9a3d525d1a043370deeacb119..0fbf2b258e3432fa3ae6099973e7cd56317eaedb 100644 --- a/src/ptbench/models/densenet_rs.py +++ b/src/ptbench/models/densenet_rs.py @@ -60,7 +60,7 @@ class DensenetRS(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -78,9 +78,12 @@ class DensenetRS(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1] diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index 6efd2a25c9726d5aeb081ae2a7ed22192b9befcd..dfde83181a466b3ca2d4b5fd71f02bf8e583836e 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -49,7 +49,7 @@ class LogisticRegression(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -67,9 +67,12 @@ class LogisticRegression(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1] diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py index aa22864558aec7340cfd53b7e3b9622e72980e8a..2f86ded58e518520efc3441fe2c1f6fd74d7a5eb 100644 --- a/src/ptbench/models/signs_to_tb.py +++ b/src/ptbench/models/signs_to_tb.py @@ -56,7 +56,7 @@ class SignsToTB(pl.LightningModule): return {"loss": training_loss} - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[1] labels = batch[2] @@ -74,9 +74,12 @@ class SignsToTB(pl.LightningModule): ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) - return {"validation_loss": validation_loss} + if dataloader_idx == 0: + return {"validation_loss": validation_loss} + else: + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): names = batch[0] images = batch[1]