Skip to content
Snippets Groups Projects
Commit 0b9ae46c authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Merge branch 'checkpointing-cleanup' into 'main'

Checkpointing cleanup

See merge request biosignal/software/ptbench!2
parents 0209ebe1 e360a861
Branches
Tags
1 merge request!2Checkpointing cleanup
Pipeline #71477 passed
...@@ -131,7 +131,7 @@ def predict( ...@@ -131,7 +131,7 @@ def predict(
weight_fullpath = os.path.abspath(weight) weight_fullpath = os.path.abspath(weight)
checkpointer = Checkpointer(model) checkpointer = Checkpointer(model)
checkpointer.load(weight_fullpath, strict=False) checkpointer.load(weight_fullpath)
# Logistic regressor weights # Logistic regressor weights
if model.name == "logistic_regression": if model.name == "logistic_regression":
......
...@@ -51,7 +51,7 @@ class Checkpointer: ...@@ -51,7 +51,7 @@ class Checkpointer:
with open(self._last_checkpoint_filename, "w") as f: with open(self._last_checkpoint_filename, "w") as f:
f.write(name) f.write(name)
def load(self, f=None, strict=True): def load(self, f=None):
"""Loads model, optimizer and scheduler from file. """Loads model, optimizer and scheduler from file.
Parameters Parameters
...@@ -62,9 +62,6 @@ class Checkpointer: ...@@ -62,9 +62,6 @@ class Checkpointer:
contains the checkpoint data to load into the model, and optionally contains the checkpoint data to load into the model, and optionally
into the optimizer and the scheduler. If not specified, loads data into the optimizer and the scheduler. If not specified, loads data
from current path. from current path.
partial : :py:class:`bool`, Optional
If True, loading is not strict and only the model is loaded
""" """
if f is None: if f is None:
f = self.last_checkpoint() f = self.last_checkpoint()
...@@ -79,13 +76,12 @@ class Checkpointer: ...@@ -79,13 +76,12 @@ class Checkpointer:
checkpoint = torch.load(f, map_location=torch.device("cpu")) checkpoint = torch.load(f, map_location=torch.device("cpu"))
# converts model entry to model parameters # converts model entry to model parameters
self.model.load_state_dict(checkpoint.pop("model"), strict=strict) self.model.load_state_dict(checkpoint.pop("model"))
if strict: if self.optimizer is not None:
if self.optimizer is not None: self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
self.optimizer.load_state_dict(checkpoint.pop("optimizer")) if self.scheduler is not None:
if self.scheduler is not None: self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
return checkpoint return checkpoint
......
Subproject commit 64c25ecf20b6f6ac2f250772fcb5338c1196a950 Subproject commit 69185f0d9ea67893722c5a840e2caa59946b3b83
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment