diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 054590e2631b462ec11e78ea73e16a6a7858cb89..91b66d725c83ed9802d7ff9f03c2e796d670e546 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -134,7 +134,7 @@ def predict( datamodule.setup(stage="predict") logger.info(f"Loading checkpoint from `{weight}`...") - model = model.load_from_checkpoint(weight, strict=False) + model = type(model).load_from_checkpoint(weight, strict=False) predictions = run(model, datamodule, DeviceManager(device))