Skip to content
Snippets Groups Projects
test_checkpointer.py 3.05 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import os
import unittest

from collections import OrderedDict
from tempfile import TemporaryDirectory

import torch

from ptbench.utils.checkpointer import Checkpointer


class TestCheckpointer(unittest.TestCase):
    """Round-trip save/load tests for ``Checkpointer``."""

    def create_model(self):
        """Return a small two-layer ``Sequential`` model (2 -> 3 -> 1).

        Each call builds new ``Linear`` layers with torch's default random
        initialisation, so two calls give models with the same structure but
        (almost surely) different weights.
        """
        return torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1))

    def create_complex_model(self):
        """Return a nested module together with a hand-built state dict.

        The module holds ``block1.layer1``, ``layer2`` and ``res.layer2``.
        The state dict's tensor shapes match those layers, but its
        ``layer1.*`` keys lack the ``block1.`` prefix the module uses --
        presumably intended for prefix-tolerant loading tests (not exercised
        by the tests in this class).

        Returns:
            A ``(module, state_dict)`` tuple.
        """
        m = torch.nn.Module()
        m.block1 = torch.nn.Module()
        m.block1.layer1 = torch.nn.Linear(2, 3)
        m.layer2 = torch.nn.Linear(3, 2)
        m.res = torch.nn.Module()
        m.res.layer2 = torch.nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)

        return m, state_dict

    def _assert_some_parameter_differs(self, first, second):
        """Assert two models do not already share identical weights.

        Guards against a vacuous pass: if both models started out equal,
        the post-load equality check would succeed even if nothing loaded.
        """
        self.assertTrue(
            any(
                not p1.equal(p2)
                for p1, p2 in zip(first.parameters(), second.parameters())
            ),
            "models must start with different weights",
        )

    def _assert_same_parameters(self, expected, actual):
        """Assert *actual* holds copies of *expected*'s parameters.

        Corresponding parameters must be distinct tensor objects (copied,
        not aliased) with identical contents.
        """
        expected_params = list(expected.parameters())
        actual_params = list(actual.parameters())
        # zip() alone would silently ignore any extra parameters
        self.assertEqual(len(expected_params), len(actual_params))
        for expected_p, actual_p in zip(expected_params, actual_params):
            # different tensor references
            self.assertIsNot(expected_p, actual_p)
            # same content
            self.assertTrue(expected_p.equal(actual_p))

    def test_from_last_checkpoint_model(self):
        """Restore through checkpoint discovery in the same folder.

        A second ``Checkpointer`` on the directory the first one saved to
        must report the saved file as its last checkpoint and load it
        without being given an explicit path.
        """
        trained_model = self.create_model()
        fresh_model = self.create_model()
        self._assert_some_parameter_differs(trained_model, fresh_model)

        with TemporaryDirectory() as tmpdir:
            checkpointer = Checkpointer(trained_model, path=tmpdir)
            checkpointer.save("checkpoint_file")

            # in the same folder: the checkpoint must be discoverable
            fresh_checkpointer = Checkpointer(fresh_model, path=tmpdir)
            self.assertTrue(fresh_checkpointer.has_checkpoint())
            self.assertEqual(
                fresh_checkpointer.last_checkpoint(),
                os.path.realpath(os.path.join(tmpdir, "checkpoint_file.pth")),
            )
            _ = fresh_checkpointer.load()

        self._assert_same_parameters(trained_model, fresh_model)

    def test_from_name_file_model(self):
        """Restore from an explicit checkpoint path in a different folder.

        A ``Checkpointer`` on an empty directory must report no checkpoint,
        yet still load one when handed its full path.
        """
        trained_model = self.create_model()
        fresh_model = self.create_model()
        self._assert_some_parameter_differs(trained_model, fresh_model)

        with TemporaryDirectory() as save_dir:
            checkpointer = Checkpointer(trained_model, path=save_dir)
            checkpointer.save("checkpoint_file")

            # on different folders: nothing to discover, load by full path
            with TemporaryDirectory() as other_dir:
                fresh_checkpointer = Checkpointer(fresh_model, path=other_dir)
                self.assertFalse(fresh_checkpointer.has_checkpoint())
                self.assertIsNone(fresh_checkpointer.last_checkpoint())
                _ = fresh_checkpointer.load(
                    os.path.join(save_dir, "checkpoint_file.pth")
                )

        self._assert_same_parameters(trained_model, fresh_model)