diff --git a/bob/ip/binseg/utils/checkpointer.py b/bob/ip/binseg/utils/checkpointer.py index 19090db7e9759e875e363b395b2bcb93163b21a7..c0ad2a6199be61858dd0fa67e3b8a96626966085 100644 --- a/bob/ip/binseg/utils/checkpointer.py +++ b/bob/ip/binseg/utils/checkpointer.py @@ -5,8 +5,6 @@ import os import torch -from .model_serialization import load_state_dict - import logging logger = logging.getLogger(__name__) @@ -84,7 +82,7 @@ class Checkpointer: checkpoint = torch.load(f, map_location=torch.device("cpu")) # converts model entry to model parameters - load_state_dict(self.model, checkpoint.pop("model")) + self.model.load_state_dict(checkpoint.pop("model")) if self.optimizer is not None: self.optimizer.load_state_dict(checkpoint.pop("optimizer")) diff --git a/bob/ip/binseg/utils/model_serialization.py b/bob/ip/binseg/utils/model_serialization.py deleted file mode 100644 index d629eae14f884e770f9a7c45cd67f9aa8092e706..0000000000000000000000000000000000000000 --- a/bob/ip/binseg/utils/model_serialization.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. -# https://github.com/facebookresearch/maskrcnn-benchmark - -from collections import OrderedDict - -import logging - -logger = logging.getLogger(__name__) - -import torch - - -def align_and_update_state_dicts(model_state_dict, loaded_state_dict): - """ - - Strategy: suppose that the models that we will create will have prefixes - appended to each of its keys, for example due to an extra level of nesting - that the original pre-trained weights from ImageNet won't contain. For - example, model.state_dict() might return - backbone[0].body.res2.conv1.weight, while the pre-trained model contains - res2.conv1.weight. We thus want to match both parameters together. For - that, we look for each model weight, look among all loaded keys if there is - one that is a suffix of the current weight name, and use it if that's the - case. If multiple matches exist, take the one with longest size of the - corresponding name. For example, for the same model as before, the - pretrained weight file can contain both res2.conv1.weight, as well as - conv1.weight. In this case, we want to match backbone[0].body.conv1.weight - to conv1.weight, and backbone[0].body.res2.conv1.weight to - res2.conv1.weight. - """ - - current_keys = sorted(list(model_state_dict.keys())) - loaded_keys = sorted(list(loaded_state_dict.keys())) - # get a matrix of string matches, where each (i, j) entry correspond to the size of the - # loaded_key string, if it matches - match_matrix = [ - len(j) if i.endswith(j) else 0 - for i in current_keys - for j in loaded_keys - ] - match_matrix = torch.as_tensor(match_matrix).view( - len(current_keys), len(loaded_keys) - ) - max_match_size, idxs = match_matrix.max(1) - # remove indices that correspond to no-match - idxs[max_match_size == 0] = -1 - - # used for logging - max_size = max([len(key) for key in current_keys]) if current_keys else 1 - max_size_loaded = ( - max([len(key) for key in loaded_keys]) if loaded_keys else 1 - ) - log_str_template = "{: <{}} loaded from {: <{}} of shape {}" - for idx_new, idx_old in enumerate(idxs.tolist()): - if idx_old == -1: - continue - key = current_keys[idx_new] - key_old = loaded_keys[idx_old] - model_state_dict[key] = loaded_state_dict[key_old] - logger.debug( - log_str_template.format( - key, - max_size, - key_old, - max_size_loaded, - tuple(loaded_state_dict[key_old].shape), - ) - ) - - -def strip_prefix_if_present(state_dict, prefix): - keys = sorted(state_dict.keys()) - if not all(key.startswith(prefix) for key in keys): - return state_dict - stripped_state_dict = OrderedDict() - for key, value in state_dict.items(): - stripped_state_dict[key.replace(prefix, "")] = value - return stripped_state_dict - - -def load_state_dict(model, loaded_state_dict): - model_state_dict = model.state_dict() - # if the state_dict comes from a model that was wrapped in a - # DataParallel or DistributedDataParallel during serialization, - # remove the "module" prefix before performing the matching - loaded_state_dict = strip_prefix_if_present( - loaded_state_dict, prefix="module." - ) - align_and_update_state_dicts(model_state_dict, loaded_state_dict) - - # use strict loading - model.load_state_dict(model_state_dict) diff --git a/doc/api.rst b/doc/api.rst index b6a288f09afdfd7a0ce7b4554cd36c89b621bb0b..e64e7d28b0b6a9fda600d5089c7249b074b9c2bd 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -86,7 +86,6 @@ Toolbox bob.ip.binseg.utils bob.ip.binseg.utils.checkpointer bob.ip.binseg.utils.measure - bob.ip.binseg.utils.model_serialization bob.ip.binseg.utils.plot bob.ip.binseg.utils.table bob.ip.binseg.utils.summary