Skip to content
Snippets Groups Projects
Commit c8943f5d authored by Anjith GEORGE's avatar Anjith GEORGE
Browse files

Adds more unit tests for trainers with CV and fuixed a typo

parent 90453ad1
No related branches found
No related tags found
1 merge request!22Cross validation
Pipeline #27603 passed
...@@ -8,7 +8,7 @@ class FASNet(nn.Module): ...@@ -8,7 +8,7 @@ class FASNet(nn.Module):
"""PyTorch Reimplementation of Lucena, Oeslle, et al. "Transfer learning using """PyTorch Reimplementation of Lucena, Oeslle, et al. "Transfer learning using
convolutional neural networks for face anti-spoofing." convolutional neural networks for face anti-spoofing."
International Conference Image Analysis and Recognition. Springer, Cham, 2017. International Conference Image Analysis and Recognition. Springer, Cham, 2017.
eferenced from keras implementation: https://github.com/OeslleLucena/FASNet Referenced from keras implementation: https://github.com/OeslleLucena/FASNet
Attributes: Attributes:
pretrained: bool pretrained: bool
......
...@@ -239,6 +239,27 @@ def test_MCCNNtrainer(): ...@@ -239,6 +239,27 @@ def test_MCCNNtrainer():
os.remove('model_1_0.pth') os.remove('model_1_0.pth')
def test_MCCNNtrainer_cv():
from ..architectures import MCCNN
net = MCCNN(num_channels=4)
dataloader={}
dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetMCCNN(), batch_size=32, shuffle=True)
dataloader['val'] = torch.utils.data.DataLoader(DummyDataSetMCCNN(), batch_size=32, shuffle=True)
from ..trainers import MCCNNTrainer
trainer = MCCNNTrainer(net, verbosity_level=3, do_crossvalidation=True)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
assert os.path.isfile('model_100_0.pth') # the best model
os.remove('model_1_0.pth')
os.remove('model_100_0.pth')
class DummyDataSetFASNet(Dataset): class DummyDataSetFASNet(Dataset):
def __init__(self): def __init__(self):
pass pass
...@@ -268,6 +289,26 @@ def test_FASNettrainer(): ...@@ -268,6 +289,26 @@ def test_FASNettrainer():
os.remove('model_1_0.pth') os.remove('model_1_0.pth')
def test_FASNettrainer_cv():
from ..architectures import FASNet
net = FASNet()
dataloader={}
dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetFASNet(), batch_size=32, shuffle=True)
dataloader['val'] = torch.utils.data.DataLoader(DummyDataSetFASNet(), batch_size=32, shuffle=True)
from ..trainers import FASNetTrainer
trainer = FASNetTrainer(net, verbosity_level=3,do_crossvalidation=True)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
assert os.path.isfile('model_100_0.pth')
os.remove('model_1_0.pth')
os.remove('model_100_0.pth')
class DummyDataSetGAN(Dataset): class DummyDataSetGAN(Dataset):
def __init__(self): def __init__(self):
pass pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment