From 51efda2652dfd6ec42ed47fa9b85c8f4c7d62958 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 12 Apr 2023 16:49:10 +0200 Subject: [PATCH] Removed checkpointer tests as we are using the lightning callback --- tests/test_checkpointer.py | 85 -------------------------------------- 1 file changed, 85 deletions(-) delete mode 100644 tests/test_checkpointer.py diff --git a/tests/test_checkpointer.py b/tests/test_checkpointer.py deleted file mode 100644 index aca95248..00000000 --- a/tests/test_checkpointer.py +++ /dev/null @@ -1,85 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import os -import unittest - -from collections import OrderedDict -from tempfile import TemporaryDirectory - -import torch - -from ptbench.utils.checkpointer import Checkpointer - - -class TestCheckpointer(unittest.TestCase): - def create_model(self): - return torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1)) - - def create_complex_model(self): - m = torch.nn.Module() - m.block1 = torch.nn.Module() - m.block1.layer1 = torch.nn.Linear(2, 3) - m.layer2 = torch.nn.Linear(3, 2) - m.res = torch.nn.Module() - m.res.layer2 = torch.nn.Linear(3, 2) - - state_dict = OrderedDict() - state_dict["layer1.weight"] = torch.rand(3, 2) - state_dict["layer1.bias"] = torch.rand(3) - state_dict["layer2.weight"] = torch.rand(2, 3) - state_dict["layer2.bias"] = torch.rand(2) - state_dict["res.layer2.weight"] = torch.rand(2, 3) - state_dict["res.layer2.bias"] = torch.rand(2) - - return m, state_dict - - def test_from_last_checkpoint_model(self): - # test that loading works even if they differ by a prefix - trained_model = self.create_model() - fresh_model = self.create_model() - with TemporaryDirectory() as f: - checkpointer = Checkpointer(trained_model, path=f) - checkpointer.save("checkpoint_file") - - # in the same folder - fresh_checkpointer = Checkpointer(fresh_model, path=f) - assert fresh_checkpointer.has_checkpoint() - assert fresh_checkpointer.last_checkpoint() == os.path.realpath( - os.path.join(f, "checkpoint_file.pth") - ) - _ = fresh_checkpointer.load() - - for trained_p, loaded_p in zip( - trained_model.parameters(), fresh_model.parameters() - ): - # different tensor references - assert id(trained_p) != id(loaded_p) - # same content - assert trained_p.equal(loaded_p) - - def test_from_name_file_model(self): - # test that loading works even if they differ by a prefix - trained_model = self.create_model() - fresh_model = self.create_model() - with TemporaryDirectory() as f: - checkpointer = Checkpointer(trained_model, path=f) - checkpointer.save("checkpoint_file") - - # on different folders - with TemporaryDirectory() as g: - fresh_checkpointer = Checkpointer(fresh_model, path=g) - assert not fresh_checkpointer.has_checkpoint() - assert fresh_checkpointer.last_checkpoint() is None - _ = fresh_checkpointer.load( - os.path.join(f, "checkpoint_file.pth") - ) - - for trained_p, loaded_p in zip( - trained_model.parameters(), fresh_model.parameters() - ): - # different tensor references - assert id(trained_p) != id(loaded_p) - # same content - assert trained_p.equal(loaded_p) -- GitLab