From 029a57a9fe305f96b57c56fe9fed8bdf4740e863 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Apr 2023 14:49:09 +0200
Subject: [PATCH] Properly save/load criterion

---
 src/ptbench/models/alexnet.py             | 2 +-
 src/ptbench/models/densenet.py            | 2 +-
 src/ptbench/models/logistic_regression.py | 2 +-
 src/ptbench/models/pasa.py                | 2 +-
 src/ptbench/models/signs_to_tb.py         | 2 +-
 src/ptbench/scripts/predict.py            | 4 +++-
 6 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 7aaaccb6..59acba15 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -26,7 +26,7 @@ class Alexnet(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
         self.criterion = criterion
         self.criterion_valid = criterion_valid
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 4e5b34c0..b44dac93 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -27,7 +27,7 @@ class Densenet(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
         self.name = "Densenet"
 
diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py
index 684155b4..ad56cb80 100644
--- a/src/ptbench/models/logistic_regression.py
+++ b/src/ptbench/models/logistic_regression.py
@@ -20,7 +20,7 @@ class LogisticRegression(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
         self.criterion = criterion
         self.criterion_valid = criterion_valid
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 8c9705e6..af47d9e3 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -38,7 +38,7 @@ class PASA(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
         self.name = "pasa"
 
diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py
index 653b590a..0169a1b8 100644
--- a/src/ptbench/models/signs_to_tb.py
+++ b/src/ptbench/models/signs_to_tb.py
@@ -20,7 +20,7 @@ class SignsToTB(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
         self.name = "signs_to_tb"
 
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 82939f25..860d95b2 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -122,7 +122,9 @@ def predict(
 
     dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
 
-    model = model.load_from_checkpoint(weight)
+    model = model.load_from_checkpoint(
+        weight, criterion=model.criterion, criterion_valid=model.criterion_valid
+    )
 
     # Logistic regressor weights
     if model.name == "logistic_regression":
-- 
GitLab