Skip to content
Snippets Groups Projects

Cross validation

Merged Anjith GEORGE requested to merge cross_validation into master
1 unresolved thread
2 files
+ 119
27
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -220,6 +220,7 @@ class DummyDataSetMCCNN(Dataset):
sample = data, label
return sample
def test_MCCNNtrainer():
from ..architectures import MCCNN
@@ -229,13 +230,44 @@ def test_MCCNNtrainer():
dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetMCCNN(), batch_size=32, shuffle=True)
from ..trainers import MCCNNTrainer
trainer = MCCNNTrainer(net, verbosity_level=3)
trainer = MCCNNTrainer(net, verbosity_level=3, do_crossvalidation=False)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
os.remove('model_1_0.pth')
class DummyDataSetFASNet(Dataset):
def __init__(self):
pass
def __len__(self):
return 100
def __getitem__(self, idx):
data = numpy.random.rand(3, 224,224).astype("float32")
label = numpy.random.randint(2)
sample = data, label
return sample
def test_FASNettrainer():
from ..architectures import FASNet
net = FASNet()
dataloader={}
dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetFASNet(), batch_size=32, shuffle=True)
from ..trainers import FASNetTrainer
trainer = FASNetTrainer(net, verbosity_level=3,do_crossvalidation=False)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
os.remove('model_1_0.pth')
class DummyDataSetGAN(Dataset):
def __init__(self):
pass
@@ -357,6 +389,14 @@ def test_extractors():
data = numpy.random.rand(4, 128, 128).astype("float32")
output = extractor(data)
assert output.shape[0] == 1
# FASNet
from ..extractor.image import FASNetExtractor
extractor = FASNetExtractor(num_channels_used=4)
# this architecture expects RGB images of size 3x224x224 channel images
data = numpy.random.rand(3, 224, 224).astype("float32")
output = extractor(data)
assert output.shape[0] == 1
def test_two_layer_mlp():
"""
Loading