Commit c8943f5d authored by Anjith GEORGE's avatar Anjith GEORGE

Adds more unit tests for trainers with CV and fuixed a typo

parent 90453ad1
Pipeline #27603 passed with stage
in 29 minutes and 48 seconds
......@@ -8,7 +8,7 @@ class FASNet(nn.Module):
"""PyTorch Reimplementation of Lucena, Oeslle, et al. "Transfer learning using
convolutional neural networks for face anti-spoofing."
International Conference Image Analysis and Recognition. Springer, Cham, 2017.
eferenced from keras implementation: https://github.com/OeslleLucena/FASNet
Referenced from keras implementation: https://github.com/OeslleLucena/FASNet
Attributes:
pretrained: bool
......
......@@ -239,6 +239,27 @@ def test_MCCNNtrainer():
os.remove('model_1_0.pth')
def test_MCCNNtrainer_cv():
from ..architectures import MCCNN
net = MCCNN(num_channels=4)
dataloader={}
dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetMCCNN(), batch_size=32, shuffle=True)
dataloader['val'] = torch.utils.data.DataLoader(DummyDataSetMCCNN(), batch_size=32, shuffle=True)
from ..trainers import MCCNNTrainer
trainer = MCCNNTrainer(net, verbosity_level=3, do_crossvalidation=True)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
assert os.path.isfile('model_100_0.pth') # the best model
os.remove('model_1_0.pth')
os.remove('model_100_0.pth')
class DummyDataSetFASNet(Dataset):
def __init__(self):
pass
......@@ -268,6 +289,26 @@ def test_FASNettrainer():
os.remove('model_1_0.pth')
def test_FASNettrainer_cv():
from ..architectures import FASNet
net = FASNet()
dataloader={}
dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetFASNet(), batch_size=32, shuffle=True)
dataloader['val'] = torch.utils.data.DataLoader(DummyDataSetFASNet(), batch_size=32, shuffle=True)
from ..trainers import FASNetTrainer
trainer = FASNetTrainer(net, verbosity_level=3,do_crossvalidation=True)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
assert os.path.isfile('model_100_0.pth')
os.remove('model_1_0.pth')
os.remove('model_100_0.pth')
class DummyDataSetGAN(Dataset):
def __init__(self):
pass
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment