From 03838caeb58fb5c228ec63dd64c6fbb18c40f988 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 18 May 2020 20:23:57 +0200
Subject: [PATCH] [utils.checkpointer] Remove custom serialization
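
The code in bob/ip/binseg/utils/model_serialization.py was copied from
maskrcnn-benchmark to align pre-trained weights whose keys only match
the model's keys up to a prefix.  The checkpoints handled by our
Checkpointer are written and re-read by the same model class, so keys
should match exactly and PyTorch's built-in
torch.nn.Module.load_state_dict() (strict by default) suffices.  The
loading path becomes (sketch of the new Checkpointer.load() body):

    checkpoint = torch.load(f, map_location=torch.device("cpu"))
    self.model.load_state_dict(checkpoint.pop("model"))

One case the removed code handled is not covered by strict loading:
checkpoints saved from a model wrapped in DataParallel or
DistributedDataParallel carry a "module." prefix on every key.  Should
such checkpoints ever need to be re-read, the prefix can be stripped
before loading, along these lines:

    # Hypothetical helper, not part of this change: removes the
    # "module." prefix left by (Distributed)DataParallel wrappers.
    def strip_module_prefix(state):
        prefix = "module."
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state.items()
        }

    self.model.load_state_dict(strip_module_prefix(checkpoint.pop("model")))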

---
 bob/ip/binseg/utils/checkpointer.py        |  4 +-
 bob/ip/binseg/utils/model_serialization.py | 92 ----------------------
 doc/api.rst                                |  1 -
 3 files changed, 1 insertion(+), 96 deletions(-)
 delete mode 100644 bob/ip/binseg/utils/model_serialization.py

diff --git a/bob/ip/binseg/utils/checkpointer.py b/bob/ip/binseg/utils/checkpointer.py
index 19090db7..c0ad2a61 100644
--- a/bob/ip/binseg/utils/checkpointer.py
+++ b/bob/ip/binseg/utils/checkpointer.py
@@ -5,8 +5,6 @@ import os
 
 import torch
 
-from .model_serialization import load_state_dict
-
 import logging
 
 logger = logging.getLogger(__name__)
@@ -84,7 +82,7 @@ class Checkpointer:
         checkpoint = torch.load(f, map_location=torch.device("cpu"))
 
         # converts model entry to model parameters
-        load_state_dict(self.model, checkpoint.pop("model"))
+        self.model.load_state_dict(checkpoint.pop("model"))
 
         if self.optimizer is not None:
             self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
diff --git a/bob/ip/binseg/utils/model_serialization.py b/bob/ip/binseg/utils/model_serialization.py
deleted file mode 100644
index d629eae1..00000000
--- a/bob/ip/binseg/utils/model_serialization.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-# https://github.com/facebookresearch/maskrcnn-benchmark
-
-from collections import OrderedDict
-
-import logging
-
-logger = logging.getLogger(__name__)
-
-import torch
-
-
-def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
-    """
-
-    Strategy: suppose that the models that we will create will have prefixes
-    appended to each of its keys, for example due to an extra level of nesting
-    that the original pre-trained weights from ImageNet won't contain. For
-    example, model.state_dict() might return
-    backbone[0].body.res2.conv1.weight, while the pre-trained model contains
-    res2.conv1.weight. We thus want to match both parameters together.  For
-    that, we look for each model weight, look among all loaded keys if there is
-    one that is a suffix of the current weight name, and use it if that's the
-    case.  If multiple matches exist, take the one with longest size of the
-    corresponding name. For example, for the same model as before, the
-    pretrained weight file can contain both res2.conv1.weight, as well as
-    conv1.weight. In this case, we want to match backbone[0].body.conv1.weight
-    to conv1.weight, and backbone[0].body.res2.conv1.weight to
-    res2.conv1.weight.
-    """
-
-    current_keys = sorted(model_state_dict.keys())
-    loaded_keys = sorted(loaded_state_dict.keys())
-    # build a matrix of string matches, where each (i, j) entry corresponds
-    # to the length of the loaded key, if it is a suffix of the model key
-    match_matrix = [
-        len(j) if i.endswith(j) else 0
-        for i in current_keys
-        for j in loaded_keys
-    ]
-    match_matrix = torch.as_tensor(match_matrix).view(
-        len(current_keys), len(loaded_keys)
-    )
-    max_match_size, idxs = match_matrix.max(1)
-    # remove indices that correspond to no-match
-    idxs[max_match_size == 0] = -1
-
-    # used for logging
-    max_size = max([len(key) for key in current_keys]) if current_keys else 1
-    max_size_loaded = (
-        max([len(key) for key in loaded_keys]) if loaded_keys else 1
-    )
-    log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
-    for idx_new, idx_old in enumerate(idxs.tolist()):
-        if idx_old == -1:
-            continue
-        key = current_keys[idx_new]
-        key_old = loaded_keys[idx_old]
-        model_state_dict[key] = loaded_state_dict[key_old]
-        logger.debug(
-            log_str_template.format(
-                key,
-                max_size,
-                key_old,
-                max_size_loaded,
-                tuple(loaded_state_dict[key_old].shape),
-            )
-        )
-
-
-def strip_prefix_if_present(state_dict, prefix):
-    keys = sorted(state_dict.keys())
-    if not all(key.startswith(prefix) for key in keys):
-        return state_dict
-    stripped_state_dict = OrderedDict()
-    for key, value in state_dict.items():
-        stripped_state_dict[key[len(prefix):]] = value
-    return stripped_state_dict
-
-
-def load_state_dict(model, loaded_state_dict):
-    model_state_dict = model.state_dict()
-    # if the state_dict comes from a model that was wrapped in a
-    # DataParallel or DistributedDataParallel during serialization,
-    # remove the "module" prefix before performing the matching
-    loaded_state_dict = strip_prefix_if_present(
-        loaded_state_dict, prefix="module."
-    )
-    align_and_update_state_dicts(model_state_dict, loaded_state_dict)
-
-    # use strict loading
-    model.load_state_dict(model_state_dict)
diff --git a/doc/api.rst b/doc/api.rst
index b6a288f0..e64e7d28 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -86,7 +86,6 @@ Toolbox
    bob.ip.binseg.utils
    bob.ip.binseg.utils.checkpointer
    bob.ip.binseg.utils.measure
-   bob.ip.binseg.utils.model_serialization
    bob.ip.binseg.utils.plot
    bob.ip.binseg.utils.table
    bob.ip.binseg.utils.summary
-- 
GitLab