diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 3ee0b92164b5531b65049b94e71b01b07e2ad27e..c02d650dcaa03bd4c9aef47ca2d34b858ecd3e73 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -28,3 +28,5 @@ criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
 # model
 model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
+model.criterion = criterion
+model.criterion_valid = criterion_valid
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 0080676a2ba8b5ac63649e6a30f39593e11cfc12..b1d86c8f41bedb77389da4343756e99cf96b6d3b 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -70,7 +70,7 @@ class LoggingCallback(Callback):
         assert self.resource_monitor.q.empty()
 
         for metric_name, metric_value in self.resource_monitor.data:
-            self.log(metric_name, metric_value)
+            self.log(metric_name, float(metric_value))
 
         self.resource_monitor.data = None
 
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 59acba158565723ab28a483e53b87b1395de742f..8b24227753e704281e5c00e979dfbf2192760675 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -28,9 +28,6 @@ class Alexnet(pl.LightningModule):
 
         self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.name = "AlexNet"
 
         # Load pretrained model
@@ -77,7 +74,7 @@ class Alexnet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -92,7 +89,7 @@ class Alexnet(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index b44dac93f46447ef9fa8f3fbf0ca8c9f13163812..f5c58ad68affdabe6c005616599979afb994834c 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -31,9 +31,6 @@ class Densenet(pl.LightningModule):
 
         self.name = "Densenet"
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
 
         # Load pretrained model
@@ -78,7 +75,7 @@ class Densenet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -93,7 +90,7 @@ class Densenet(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py
index 997516a02bcdb5f2b7fdbe04e10fd48077d51092..97cb9bdace46772cd39a829b864eab2116aa1429 100644
--- a/src/ptbench/models/densenet_rs.py
+++ b/src/ptbench/models/densenet_rs.py
@@ -26,9 +26,6 @@ class DensenetRS(pl.LightningModule):
 
         self.name = "DensenetRS"
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.normalizer = TorchVisionNormalizer()
 
         # Load pretrained model
@@ -72,7 +69,7 @@ class DensenetRS(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -87,7 +84,7 @@ class DensenetRS(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py
index ad56cb80530b3721e7aae20f6d3ebf03e6c19250..deda25aaac71289aa210cb88c4817d5344f57f38 100644
--- a/src/ptbench/models/logistic_regression.py
+++ b/src/ptbench/models/logistic_regression.py
@@ -22,12 +22,9 @@ class LogisticRegression(pl.LightningModule):
 
         self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.name = "logistic_regression"
 
-        self.linear = nn.Linear(input_size, 1)
+        self.linear = nn.Linear(self.hparams.input_size, 1)
 
     def forward(self, x):
         """
@@ -60,7 +57,7 @@ class LogisticRegression(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -75,7 +72,7 @@ class LogisticRegression(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index af47d9e3d96afecde176ea193c9e0d449f341ee5..155aa7d89775868710bfbdc3f0d7a7f9f49df699 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -38,13 +38,10 @@ class PASA(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+        self.save_hyperparameters()
 
         self.name = "pasa"
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
         self.normalizer = TorchVisionNormalizer(nb_channels=1)
 
         # First convolution block
@@ -169,7 +166,7 @@ class PASA(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -184,7 +181,7 @@ class PASA(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py
index 0169a1b8fa008786829a1f301260efe3d695df7e..9db39f7a967fd1bfa9c04d158120e7b866ed5a88 100644
--- a/src/ptbench/models/signs_to_tb.py
+++ b/src/ptbench/models/signs_to_tb.py
@@ -20,18 +20,15 @@ class SignsToTB(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+        self.save_hyperparameters()
 
         self.name = "signs_to_tb"
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
+        self.fc1 = torch.nn.Linear(
+            self.hparams.input_size, self.hparams.hidden_size
+        )
         self.relu = torch.nn.ReLU()
-        self.fc2 = torch.nn.Linear(self.hidden_size, 1)
+        self.fc2 = torch.nn.Linear(self.hparams.hidden_size, 1)
 
     def forward(self, x):
         """
@@ -67,7 +64,7 @@ class SignsToTB(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -82,7 +79,7 @@ class SignsToTB(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 860d95b293f22895e62ebc825a03550d24018806..65336ac138bd58cf30f9879c562c32a5614591e1 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -73,7 +73,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--weight",
     "-w",
-    help="Path or URL to pretrained model file (.pth extension)",
+    help="Path or URL to pretrained model file (.ckpt extension)",
     required=True,
     cls=ResourceOption,
 )
@@ -122,9 +122,8 @@ def predict(
 
     dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
 
-    model = model.load_from_checkpoint(
-        weight, criterion=model.criterion, criterion_valid=model.criterion_valid
-    )
+    logger.info(f"Loading checkpoint from {weight}")
+    model = model.load_from_checkpoint(weight, strict=False)
 
     # Logistic regressor weights
     if model.name == "logistic_regression":
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index f5bd3a7afa0b19e8ad8650cb325c7fc5ba79a166..6d117c5fea978c54fd659dac8de36a2fb0841233 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -349,7 +349,7 @@ def train(
     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
         positive_weights = get_positive_weights(use_dataset)
-        model.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
+        model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
         logger.warning("Weighted criterion not supported")
 
@@ -372,7 +372,7 @@ def train(
             or criterion_valid is None
         ):
             positive_weights = get_positive_weights(validation_dataset)
-            model.criterion_valid = BCEWithLogitsLoss(
+            model.hparams.criterion_valid = BCEWithLogitsLoss(
                 pos_weight=positive_weights
             )
         else: