Commit 99356e82 authored by Daniel CARRON

Fix saving and loading model hyperparameters

parent d711a4fa
1 merge request: !4 Moved code to lightning
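For context, the diffs below lean on PyTorch Lightning's built-in hyperparameter handling: save_hyperparameters() records the constructor arguments in self.hparams and writes them into the checkpoint, so load_from_checkpoint() can rebuild the module without the caller re-passing them. A minimal sketch of that pattern, using an illustrative TinyModel that is not part of this repository:

import torch
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Illustrative module, not from this repository."""

    def __init__(self, input_size=16, criterion=None):
        super().__init__()
        # Record every constructor argument in self.hparams and in the
        # checkpoint, so load_from_checkpoint() can restore them later.
        self.save_hyperparameters()
        self.linear = torch.nn.Linear(self.hparams.input_size, 1)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # The loss function lives in hparams, mirroring the diffs below.
        return self.hparams.criterion(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

Because the criteria end up in hparams, a checkpoint restored with TinyModel.load_from_checkpoint(path) carries them along, which is what lets the predict command further down drop its explicit criterion=/criterion_valid= arguments.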
@@ -28,3 +28,5 @@ criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 # model
 model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
+model.criterion = criterion
+model.criterion_valid = criterion_valid
@@ -70,7 +70,7 @@ class LoggingCallback(Callback):
         assert self.resource_monitor.q.empty()
         for metric_name, metric_value in self.resource_monitor.data:
-            self.log(metric_name, metric_value)
+            self.log(metric_name, float(metric_value))
         self.resource_monitor.data = None
@@ -28,9 +28,6 @@ class Alexnet(pl.LightningModule):
         self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
         self.name = "AlexNet"
         # Load pretrained model
@@ -77,7 +74,7 @@ class Alexnet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
         return {"loss": training_loss}
@@ -92,7 +89,7 @@ class Alexnet(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
         return {"validation_loss": validation_loss}
@@ -31,9 +31,6 @@ class Densenet(pl.LightningModule):
         self.name = "Densenet"
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
         self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
         # Load pretrained model
@@ -78,7 +75,7 @@ class Densenet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
         return {"loss": training_loss}
@@ -93,7 +90,7 @@ class Densenet(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
         return {"validation_loss": validation_loss}
@@ -26,9 +26,6 @@ class DensenetRS(pl.LightningModule):
         self.name = "DensenetRS"
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
         self.normalizer = TorchVisionNormalizer()
         # Load pretrained model
@@ -72,7 +69,7 @@ class DensenetRS(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
         return {"loss": training_loss}
@@ -87,7 +84,7 @@ class DensenetRS(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
         return {"validation_loss": validation_loss}
@@ -22,12 +22,9 @@ class LogisticRegression(pl.LightningModule):
         self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
         self.name = "logistic_regression"
-        self.linear = nn.Linear(input_size, 1)
+        self.linear = nn.Linear(self.hparams.input_size, 1)
     def forward(self, x):
         """
@@ -60,7 +57,7 @@ class LogisticRegression(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
         return {"loss": training_loss}
@@ -75,7 +72,7 @@ class LogisticRegression(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
         return {"validation_loss": validation_loss}
@@ -38,13 +38,10 @@ class PASA(pl.LightningModule):
     ):
         super().__init__()
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+        self.save_hyperparameters()
         self.name = "pasa"
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
         self.normalizer = TorchVisionNormalizer(nb_channels=1)
         # First convolution block
@@ -169,7 +166,7 @@ class PASA(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
         return {"loss": training_loss}
@@ -184,7 +181,7 @@ class PASA(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
         return {"validation_loss": validation_loss}
@@ -20,18 +20,15 @@ class SignsToTB(pl.LightningModule):
     ):
         super().__init__()
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+        self.save_hyperparameters()
         self.name = "signs_to_tb"
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
+        self.fc1 = torch.nn.Linear(
+            self.hparams.input_size, self.hparams.hidden_size
+        )
         self.relu = torch.nn.ReLU()
-        self.fc2 = torch.nn.Linear(self.hidden_size, 1)
+        self.fc2 = torch.nn.Linear(self.hparams.hidden_size, 1)
     def forward(self, x):
         """
@@ -67,7 +64,7 @@ class SignsToTB(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
-        training_loss = self.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.double())
         return {"loss": training_loss}
@@ -82,7 +79,7 @@ class SignsToTB(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
         return {"validation_loss": validation_loss}
@@ -73,7 +73,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--weight",
     "-w",
-    help="Path or URL to pretrained model file (.pth extension)",
+    help="Path or URL to pretrained model file (.ckpt extension)",
     required=True,
     cls=ResourceOption,
 )
@@ -122,9 +122,8 @@ def predict(
     dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
-    model = model.load_from_checkpoint(
-        weight, criterion=model.criterion, criterion_valid=model.criterion_valid
-    )
+    logger.info(f"Loading checkpoint from {weight}")
+    model = model.load_from_checkpoint(weight, strict=False)
     # Logistic regressor weights
     if model.name == "logistic_regression":
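As a usage sketch of the new loading path (the checkpoint path is a placeholder and TinyModel is the illustrative module from the sketch near the top of this page), assuming the checkpoint was written by a module that called save_hyperparameters():

# Hyperparameters travel inside the .ckpt, so no criterion keyword arguments
# are needed; strict=False tolerates state_dict keys that differ between the
# checkpoint and the current class definition.
model = TinyModel.load_from_checkpoint("path/to/checkpoint.ckpt", strict=False)
model.eval()
with torch.no_grad():
    scores = torch.sigmoid(model(torch.randn(4, 16)))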
@@ -349,7 +349,7 @@ def train(
     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
         positive_weights = get_positive_weights(use_dataset)
-        model.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
+        model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
         logger.warning("Weighted criterion not supported")
@@ -372,7 +372,7 @@ def train(
         or criterion_valid is None
     ):
         positive_weights = get_positive_weights(validation_dataset)
-        model.criterion_valid = BCEWithLogitsLoss(
+        model.hparams.criterion_valid = BCEWithLogitsLoss(
             pos_weight=positive_weights
         )
     else:
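For reference, the weighted-criterion override above follows the standard BCEWithLogitsLoss(pos_weight=...) recipe. A small sketch with the weight derived directly from binary labels; everything except BCEWithLogitsLoss is illustrative, and get_positive_weights is the repository's own helper, not reproduced here:

import torch
from torch.nn import BCEWithLogitsLoss

# Illustrative: pos_weight = (#negatives / #positives), one entry per output unit.
labels = torch.tensor([0.0, 0.0, 0.0, 1.0])
positives = labels.sum()
negatives = labels.numel() - positives
pos_weight = (negatives / positives).unsqueeze(0)  # shape (1,) for a single logit

weighted_criterion = BCEWithLogitsLoss(pos_weight=pos_weight)
# The train command above then swaps the stored loss for the weighted one:
# model.hparams.criterion = weighted_criterion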