diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index a01a7e19c50e9400ff5019da9c7deb7e662713f5..0080676a2ba8b5ac63649e6a30f39593e11cfc12 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -46,7 +46,7 @@ class LoggingCallback(Callback): self.log("total_time", current_time) self.log("eta", eta_seconds) self.log("loss", numpy.average(self.training_loss)) - self.log("learning_rate", pl_module.hparams["optimizer_params"]["lr"]) + self.log("learning_rate", pl_module.hparams["optimizer_configs"]["lr"]) self.log("validation_loss", numpy.average(self.validation_loss)) queue_retries = 0 diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 17373b79fe0a37c5ee677cb02c147c4b9c9f3e80..ab7a1f71fbcdd8a0c17384f44ee390d1a5d1ba1a 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -21,7 +21,7 @@ class Densenet(pl.LightningModule): criterion, criterion_valid, optimizer, - optimizer_params, + optimizer_configs, pretrained=False, nb_channels=3, ): diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index b31fa21d9602a72fd9e39c015a82679dde26efc8..d5657218b0f80ef1d8c6c78a31427f74ce4999f3 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -33,7 +33,9 @@ class PASA(pl.LightningModule): Based on paper by [PASA-2019]_. """ - def __init__(self, criterion, criterion_valid, optimizer, optimizer_params): + def __init__( + self, criterion, criterion_valid, optimizer, optimizer_configs + ): super().__init__() self.save_hyperparameters()