diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index ab7a1f71fbcdd8a0c17384f44ee390d1a5d1ba1a..4e5b34c09331ed0140eb49afaf812a4c84e59b56 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -114,7 +114,7 @@ class Densenet(pl.LightningModule):
     def configure_optimizers(self):
         # Dynamically instantiates the optimizer given the configs
         optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_params
+            self.parameters(), **self.hparams.optimizer_configs
         )
 
         return optimizer
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index d5657218b0f80ef1d8c6c78a31427f74ce4999f3..8c9705e61c474b0f029dd0418a067e4ded593820 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -205,7 +205,7 @@ class PASA(pl.LightningModule):
     def configure_optimizers(self):
         # Dynamically instantiates the optimizer given the configs
         optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_params
+            self.parameters(), **self.hparams.optimizer_configs
         )
 
         return optimizer