From 029a57a9fe305f96b57c56fe9fed8bdf4740e863 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 11 Apr 2023 14:49:09 +0200 Subject: [PATCH] Properly save/load criterion --- src/ptbench/models/alexnet.py | 2 +- src/ptbench/models/densenet.py | 2 +- src/ptbench/models/logistic_regression.py | 2 +- src/ptbench/models/pasa.py | 2 +- src/ptbench/models/signs_to_tb.py | 2 +- src/ptbench/scripts/predict.py | 4 +++- 6 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 7aaaccb6..59acba15 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -26,7 +26,7 @@ class Alexnet(pl.LightningModule): ): super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) self.criterion = criterion self.criterion_valid = criterion_valid diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 4e5b34c0..b44dac93 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -27,7 +27,7 @@ class Densenet(pl.LightningModule): ): super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) self.name = "Densenet" diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index 684155b4..ad56cb80 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -20,7 +20,7 @@ class LogisticRegression(pl.LightningModule): ): super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) self.criterion = criterion self.criterion_valid = criterion_valid diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 8c9705e6..af47d9e3 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -38,7 +38,7 @@ class PASA(pl.LightningModule): ): super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) self.name = "pasa" diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py index 653b590a..0169a1b8 100644 --- a/src/ptbench/models/signs_to_tb.py +++ b/src/ptbench/models/signs_to_tb.py @@ -20,7 +20,7 @@ class SignsToTB(pl.LightningModule): ): super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) self.name = "signs_to_tb" diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 82939f25..860d95b2 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -122,7 +122,9 @@ def predict( dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) - model = model.load_from_checkpoint(weight) + model = model.load_from_checkpoint( + weight, criterion=model.criterion, criterion_valid=model.criterion_valid + ) # Logistic regressor weights if model.name == "logistic_regression": -- GitLab