Skip to content
Snippets Groups Projects
Commit 51efda26 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Removed checkpointer tests as we are using the lightning callback

parent bf25f806
No related branches found
No related tags found
1 merge request!4Moved code to lightning
Pipeline #72400 failed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
import unittest
from collections import OrderedDict
from tempfile import TemporaryDirectory
import torch
from ptbench.utils.checkpointer import Checkpointer
class TestCheckpointer(unittest.TestCase):
def create_model(self):
return torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1))
def create_complex_model(self):
m = torch.nn.Module()
m.block1 = torch.nn.Module()
m.block1.layer1 = torch.nn.Linear(2, 3)
m.layer2 = torch.nn.Linear(3, 2)
m.res = torch.nn.Module()
m.res.layer2 = torch.nn.Linear(3, 2)
state_dict = OrderedDict()
state_dict["layer1.weight"] = torch.rand(3, 2)
state_dict["layer1.bias"] = torch.rand(3)
state_dict["layer2.weight"] = torch.rand(2, 3)
state_dict["layer2.bias"] = torch.rand(2)
state_dict["res.layer2.weight"] = torch.rand(2, 3)
state_dict["res.layer2.bias"] = torch.rand(2)
return m, state_dict
def test_from_last_checkpoint_model(self):
# test that loading works even if they differ by a prefix
trained_model = self.create_model()
fresh_model = self.create_model()
with TemporaryDirectory() as f:
checkpointer = Checkpointer(trained_model, path=f)
checkpointer.save("checkpoint_file")
# in the same folder
fresh_checkpointer = Checkpointer(fresh_model, path=f)
assert fresh_checkpointer.has_checkpoint()
assert fresh_checkpointer.last_checkpoint() == os.path.realpath(
os.path.join(f, "checkpoint_file.pth")
)
_ = fresh_checkpointer.load()
for trained_p, loaded_p in zip(
trained_model.parameters(), fresh_model.parameters()
):
# different tensor references
assert id(trained_p) != id(loaded_p)
# same content
assert trained_p.equal(loaded_p)
def test_from_name_file_model(self):
# test that loading works even if they differ by a prefix
trained_model = self.create_model()
fresh_model = self.create_model()
with TemporaryDirectory() as f:
checkpointer = Checkpointer(trained_model, path=f)
checkpointer.save("checkpoint_file")
# on different folders
with TemporaryDirectory() as g:
fresh_checkpointer = Checkpointer(fresh_model, path=g)
assert not fresh_checkpointer.has_checkpoint()
assert fresh_checkpointer.last_checkpoint() is None
_ = fresh_checkpointer.load(
os.path.join(f, "checkpoint_file.pth")
)
for trained_p, loaded_p in zip(
trained_model.parameters(), fresh_model.parameters()
):
# different tensor references
assert id(trained_p) != id(loaded_p)
# same content
assert trained_p.equal(loaded_p)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment