From c7068292787301f9b7b5935b1a6cc10465b8cabd Mon Sep 17 00:00:00 2001
From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch>
Date: Sat, 18 Mar 2023 16:06:26 +0100
Subject: [PATCH] removed strict parameter for checkpointing

---
 src/ptbench/scripts/predict.py    |  2 +-
 src/ptbench/utils/checkpointer.py | 16 ++++++----------
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 21a3ae66..51275fc4 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -131,7 +131,7 @@ def predict(
     weight_fullpath = os.path.abspath(weight)
 
     checkpointer = Checkpointer(model)
-    checkpointer.load(weight_fullpath, strict=False)
+    checkpointer.load(weight_fullpath)
 
     # Logistic regressor weights
     if model.name == "logistic_regression":
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
index 36d10e5a..3e839b0e 100644
--- a/src/ptbench/utils/checkpointer.py
+++ b/src/ptbench/utils/checkpointer.py
@@ -51,7 +51,7 @@ class Checkpointer:
         with open(self._last_checkpoint_filename, "w") as f:
             f.write(name)
 
-    def load(self, f=None, strict=True):
+    def load(self, f=None):
         """Loads model, optimizer and scheduler from file.
 
         Parameters
@@ -62,9 +62,6 @@ class Checkpointer:
            contains the checkpoint data to load into the model, and
            optionally into the optimizer and the scheduler. If not
            specified, loads data from current path.
-
-        partial : :py:class:`bool`, Optional
-            If True, loading is not strict and only the model is loaded
         """
         if f is None:
            f = self.last_checkpoint()
@@ -79,13 +76,12 @@ class Checkpointer:
         checkpoint = torch.load(f, map_location=torch.device("cpu"))
 
         # converts model entry to model parameters
-        self.model.load_state_dict(checkpoint.pop("model"), strict=strict)
+        self.model.load_state_dict(checkpoint.pop("model"))
 
-        if strict:
-            if self.optimizer is not None:
-                self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
-            if self.scheduler is not None:
-                self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
+        if self.optimizer is not None:
+            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
+        if self.scheduler is not None:
+            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
 
         return checkpoint
--
GitLab
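
For context, a minimal self-contained sketch (not part of the committed patch) of the checkpoint round-trip that the simplified Checkpointer.load() performs after this change. The toy nn.Linear model, the SGD/StepLR objects, and the "ckpt.pth" path are illustrative assumptions; only the checkpoint layout with "model", "optimizer" and "scheduler" entries is taken from the diff above.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the real ptbench model/optimizer/scheduler.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Save a checkpoint in the same layout Checkpointer uses: a single dict
# with "model", "optimizer" and "scheduler" state entries.
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    },
    "ckpt.pth",
)

# Load it back the way the patched load() does: map everything to CPU,
# restore the model weights first, then the optimizer and scheduler.
checkpoint = torch.load("ckpt.pth", map_location=torch.device("cpu"))
model.load_state_dict(checkpoint.pop("model"))       # strict=True by default
optimizer.load_state_dict(checkpoint.pop("optimizer"))
scheduler.load_state_dict(checkpoint.pop("scheduler"))

One behavioral consequence worth noting: predict.py previously called load() with strict=False, whereas after this patch the "model" entry is always loaded with PyTorch's default strict=True, so checkpoints used for prediction must now match the model architecture exactly; conversely, optimizer and scheduler state is restored whenever those objects are attached, rather than only when strict loading was requested.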