From 99356e8279d3ac728ef1f0f1ac4464d6922fb609 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 8 May 2023 12:54:22 +0200
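Subject: [PATCH] Fix saving and loading model hyperparameters

Read the training and validation criteria from `self.hparams` in the
models' step methods instead of from the `self.criterion` and
`self.criterion_valid` attributes, which are removed from `__init__`.
PASA and SignsToTB now call `save_hyperparameters()` without an ignore
list. LogisticRegression and SignsToTB build their layers from
`self.hparams.input_size` / `self.hparams.hidden_size`.

train.py assigns the weighted criteria to `model.hparams` rather than to
model attributes. predict.py loads the checkpoint with `strict=False`
and no longer passes the criteria to `load_from_checkpoint`; the
`--weight` help text now refers to `.ckpt` files. The logging callback
casts resource-monitor values to `float` before logging them.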
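---
NOTE(review): Alexnet and LogisticRegression keep
save_hyperparameters(ignore=["criterion", "criterion_valid"]) as
unchanged context, while their step methods now read
self.hparams.criterion / self.hparams.criterion_valid. In this patch the
only assignment to those hparams entries is in train.py, under the
BCEWithLogitsLoss conditions -- please confirm the entries are always
populated before a step runs. The save_hyperparameters() calls of
Densenet and DensenetRS are not visible in these hunks; same check
applies. Also, configs/models/pasa.py now sets model.criterion and
model.criterion_valid, which none of the hunks below read after this
change -- confirm whether that assignment is still needed.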
 src/ptbench/configs/models/pasa.py        |  2 ++
 src/ptbench/engine/callbacks.py           |  2 +-
 src/ptbench/models/alexnet.py             |  7 ++-----
 src/ptbench/models/densenet.py            |  7 ++-----
 src/ptbench/models/densenet_rs.py         |  7 ++-----
 src/ptbench/models/logistic_regression.py |  9 +++------
 src/ptbench/models/pasa.py                |  9 +++------
 src/ptbench/models/signs_to_tb.py         | 17 +++++++----------
 src/ptbench/scripts/predict.py            |  7 +++----
 src/ptbench/scripts/train.py              |  4 ++--
 10 files changed, 27 insertions(+), 44 deletions(-)

diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 3ee0b921..c02d650d 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -28,3 +28,5 @@ criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
 # model
 model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
+model.criterion = criterion
+model.criterion_valid = criterion_valid
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 0080676a..b1d86c8f 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -70,7 +70,7 @@ class LoggingCallback(Callback):
         assert self.resource_monitor.q.empty()
 
         for metric_name, metric_value in self.resource_monitor.data:
-            self.log(metric_name, metric_value)
+            self.log(metric_name, float(metric_value))
 
         self.resource_monitor.data = None
 
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 59acba15..8b242277 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -28,9 +28,6 @@ class Alexnet(pl.LightningModule):
 
         self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.name = "AlexNet"
 
         # Load pretrained model
@@ -77,7 +74,7 @@ class Alexnet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -92,7 +89,7 @@ class Alexnet(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index b44dac93..f5c58ad6 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -31,9 +31,6 @@ class Densenet(pl.LightningModule):
 
         self.name = "Densenet"
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
 
         # Load pretrained model
@@ -78,7 +75,7 @@ class Densenet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -93,7 +90,7 @@ class Densenet(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py
index 997516a0..97cb9bda 100644
--- a/src/ptbench/models/densenet_rs.py
+++ b/src/ptbench/models/densenet_rs.py
@@ -26,9 +26,6 @@ class DensenetRS(pl.LightningModule):
 
         self.name = "DensenetRS"
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.normalizer = TorchVisionNormalizer()
 
         # Load pretrained model
@@ -72,7 +69,7 @@ class DensenetRS(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -87,7 +84,7 @@ class DensenetRS(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py
index ad56cb80..deda25aa 100644
--- a/src/ptbench/models/logistic_regression.py
+++ b/src/ptbench/models/logistic_regression.py
@@ -22,12 +22,9 @@ class LogisticRegression(pl.LightningModule):
 
         self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.name = "logistic_regression"
 
-        self.linear = nn.Linear(input_size, 1)
+        self.linear = nn.Linear(self.hparams.input_size, 1)
 
     def forward(self, x):
         """
@@ -60,7 +57,7 @@ class LogisticRegression(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -75,7 +72,7 @@ class LogisticRegression(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index af47d9e3..155aa7d8 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -38,13 +38,10 @@ class PASA(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+        self.save_hyperparameters()
 
         self.name = "pasa"
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.normalizer = TorchVisionNormalizer(nb_channels=1)
 
         # First convolution block
@@ -169,7 +166,7 @@ class PASA(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -184,7 +181,7 @@ class PASA(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py
index 0169a1b8..9db39f7a 100644
--- a/src/ptbench/models/signs_to_tb.py
+++ b/src/ptbench/models/signs_to_tb.py
@@ -20,18 +20,15 @@ class SignsToTB(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+        self.save_hyperparameters()
 
         self.name = "signs_to_tb"
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
+        self.fc1 = torch.nn.Linear(
+            self.hparams.input_size, self.hparams.hidden_size
+        )
         self.relu = torch.nn.ReLU()
-        self.fc2 = torch.nn.Linear(self.hidden_size, 1)
+        self.fc2 = torch.nn.Linear(self.hparams.hidden_size, 1)
 
     def forward(self, x):
         """
@@ -67,7 +64,7 @@ class SignsToTB(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -82,7 +79,7 @@ class SignsToTB(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 860d95b2..65336ac1 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -73,7 +73,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--weight",
     "-w",
-    help="Path or URL to pretrained model file (.pth extension)",
+    help="Path or URL to pretrained model file (.ckpt extension)",
     required=True,
     cls=ResourceOption,
 )
@@ -122,9 +122,8 @@ def predict(
 
     dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
 
-    model = model.load_from_checkpoint(
-        weight, criterion=model.criterion, criterion_valid=model.criterion_valid
-    )
+    logger.info(f"Loading checkpoint from {weight}")
+    model = model.load_from_checkpoint(weight, strict=False)
 
     # Logistic regressor weights
     if model.name == "logistic_regression":
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index f5bd3a7a..6d117c5f 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -349,7 +349,7 @@ def train(
     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
         positive_weights = get_positive_weights(use_dataset)
-        model.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
+        model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
         logger.warning("Weighted criterion not supported")
 
@@ -372,7 +372,7 @@ def train(
             or criterion_valid is None
         ):
             positive_weights = get_positive_weights(validation_dataset)
-            model.criterion_valid = BCEWithLogitsLoss(
+            model.hparams.criterion_valid = BCEWithLogitsLoss(
                 pos_weight=positive_weights
             )
         else:
-- 
GitLab