Skip to content
Snippets Groups Projects
Commit 029a57a9 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Properly save/load criterion

parent e7446664
No related branches found
No related tags found
1 merge request!4Moved code to lightning
......@@ -26,7 +26,7 @@ class Alexnet(pl.LightningModule):
):
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
self.criterion = criterion
self.criterion_valid = criterion_valid
......
......@@ -27,7 +27,7 @@ class Densenet(pl.LightningModule):
):
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
self.name = "Densenet"
......
......@@ -20,7 +20,7 @@ class LogisticRegression(pl.LightningModule):
):
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
self.criterion = criterion
self.criterion_valid = criterion_valid
......
......@@ -38,7 +38,7 @@ class PASA(pl.LightningModule):
):
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
self.name = "pasa"
......
......@@ -20,7 +20,7 @@ class SignsToTB(pl.LightningModule):
):
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
self.name = "signs_to_tb"
......
......@@ -122,7 +122,9 @@ def predict(
dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
model = model.load_from_checkpoint(weight)
model = model.load_from_checkpoint(
weight, criterion=model.criterion, criterion_valid=model.criterion_valid
)
# Logistic regressor weights
if model.name == "logistic_regression":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment