diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py index 3ee0b92164b5531b65049b94e71b01b07e2ad27e..c02d650dcaa03bd4c9aef47ca2d34b858ecd3e73 100644 --- a/src/ptbench/configs/models/pasa.py +++ b/src/ptbench/configs/models/pasa.py @@ -28,3 +28,5 @@ criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) # model model = PASA(criterion, criterion_valid, optimizer, optimizer_configs) +model.criterion = criterion +model.criterion_valid = criterion_valid diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index 0080676a2ba8b5ac63649e6a30f39593e11cfc12..b1d86c8f41bedb77389da4343756e99cf96b6d3b 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -70,7 +70,7 @@ class LoggingCallback(Callback): assert self.resource_monitor.q.empty() for metric_name, metric_value in self.resource_monitor.data: - self.log(metric_name, metric_value) + self.log(metric_name, float(metric_value)) self.resource_monitor.data = None diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 59acba158565723ab28a483e53b87b1395de742f..8b24227753e704281e5c00e979dfbf2192760675 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -28,9 +28,6 @@ class Alexnet(pl.LightningModule): self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) - self.criterion = criterion - self.criterion_valid = criterion_valid - self.name = "AlexNet" # Load pretrained model @@ -77,7 +74,7 @@ class Alexnet(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.double()) return {"loss": training_loss} @@ -92,7 +89,7 @@ class Alexnet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.double()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index b44dac93f46447ef9fa8f3fbf0ca8c9f13163812..f5c58ad68affdabe6c005616599979afb994834c 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -31,9 +31,6 @@ class Densenet(pl.LightningModule): self.name = "Densenet" - self.criterion = criterion - self.criterion_valid = criterion_valid - self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels) # Load pretrained model @@ -78,7 +75,7 @@ class Densenet(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.double()) return {"loss": training_loss} @@ -93,7 +90,7 @@ class Densenet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.double()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py index 997516a02bcdb5f2b7fdbe04e10fd48077d51092..97cb9bdace46772cd39a829b864eab2116aa1429 100644 --- a/src/ptbench/models/densenet_rs.py +++ b/src/ptbench/models/densenet_rs.py @@ -26,9 +26,6 @@ class DensenetRS(pl.LightningModule): self.name = "DensenetRS" - self.criterion = criterion - self.criterion_valid = criterion_valid - self.normalizer = TorchVisionNormalizer() # Load pretrained model @@ -72,7 +69,7 @@ class DensenetRS(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.double()) return {"loss": training_loss} @@ -87,7 +84,7 @@ class DensenetRS(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.double()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index ad56cb80530b3721e7aae20f6d3ebf03e6c19250..deda25aaac71289aa210cb88c4817d5344f57f38 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -22,12 +22,9 @@ class LogisticRegression(pl.LightningModule): self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) - self.criterion = criterion - self.criterion_valid = criterion_valid - self.name = "logistic_regression" - self.linear = nn.Linear(input_size, 1) + self.linear = nn.Linear(self.hparams.input_size, 1) def forward(self, x): """ @@ -60,7 +57,7 @@ class LogisticRegression(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.double()) return {"loss": training_loss} @@ -75,7 +72,7 @@ class LogisticRegression(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.double()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index af47d9e3d96afecde176ea193c9e0d449f341ee5..155aa7d89775868710bfbdc3f0d7a7f9f49df699 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -38,13 +38,10 @@ class PASA(pl.LightningModule): ): super().__init__() - self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) + self.save_hyperparameters() self.name = "pasa" - self.criterion = criterion - self.criterion_valid = criterion_valid - self.normalizer = TorchVisionNormalizer(nb_channels=1) # First convolution block @@ -169,7 +166,7 @@ class PASA(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.double()) return {"loss": training_loss} @@ -184,7 +181,7 @@ class PASA(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.double()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py index 0169a1b8fa008786829a1f301260efe3d695df7e..9db39f7a967fd1bfa9c04d158120e7b866ed5a88 100644 --- a/src/ptbench/models/signs_to_tb.py +++ b/src/ptbench/models/signs_to_tb.py @@ -20,18 +20,15 @@ class SignsToTB(pl.LightningModule): ): super().__init__() - self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) + self.save_hyperparameters() self.name = "signs_to_tb" - self.criterion = criterion - self.criterion_valid = criterion_valid - - self.input_size = input_size - self.hidden_size = hidden_size - self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size) + self.fc1 = torch.nn.Linear( + self.hparams.input_size, self.hparams.hidden_size + ) self.relu = torch.nn.ReLU() - self.fc2 = torch.nn.Linear(self.hidden_size, 1) + self.fc2 = torch.nn.Linear(self.hparams.hidden_size, 1) def forward(self, x): """ @@ -67,7 +64,7 @@ class SignsToTB(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.double()) return {"loss": training_loss} @@ -82,7 +79,7 @@ class SignsToTB(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.double()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 860d95b293f22895e62ebc825a03550d24018806..65336ac138bd58cf30f9879c562c32a5614591e1 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -73,7 +73,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--weight", "-w", - help="Path or URL to pretrained model file (.pth extension)", + help="Path or URL to pretrained model file (.ckpt extension)", required=True, cls=ResourceOption, ) @@ -122,9 +122,8 @@ def predict( dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) - model = model.load_from_checkpoint( - weight, criterion=model.criterion, criterion_valid=model.criterion_valid - ) + logger.info(f"Loading checkpoint from {weight}") + model = model.load_from_checkpoint(weight, strict=False) # Logistic regressor weights if model.name == "logistic_regression": diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index f5bd3a7afa0b19e8ad8650cb325c7fc5ba79a166..6d117c5fea978c54fd659dac8de36a2fb0841233 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -349,7 +349,7 @@ def train( # Redefine a weighted criterion if possible if isinstance(criterion, torch.nn.BCEWithLogitsLoss): positive_weights = get_positive_weights(use_dataset) - model.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) + model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) else: logger.warning("Weighted criterion not supported") @@ -372,7 +372,7 @@ def train( or criterion_valid is None ): positive_weights = get_positive_weights(validation_dataset) - model.criterion_valid = BCEWithLogitsLoss( + model.hparams.criterion_valid = BCEWithLogitsLoss( pos_weight=positive_weights ) else: