Skip to content
Snippets Groups Projects
Commit 3352906d authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

Merge branch '3-add-unit-tests' into 'master'

Resolve "Add unit tests"

Closes #3

See merge request !3
parents 6294000c b67e80f5
No related branches found
No related tags found
1 merge request!3Resolve "Add unit tests"
Pipeline #
...@@ -10,7 +10,6 @@ parts ...@@ -10,7 +10,6 @@ parts
src src
develop-eggs develop-eggs
sphinx sphinx
test*
submit* submit*
log* log*
results* results*
......
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
#!/usr/bin/env python
# encoding: utf-8
""" Unit tests
"""
import numpy
import torch
def test_architectures():
a = numpy.random.rand(1, 3, 128, 128).astype("float32")
t = torch.from_numpy(a)
number_of_classes = 20
output_dimension = number_of_classes
# CASIANet
from ..architectures import CASIANet
net = CASIANet(number_of_classes)
embedding_dimension = 320
output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 20])
assert emdedding.shape == torch.Size([1, 320])
# CNN8
from ..architectures import CNN8
net = CNN8(number_of_classes)
embedding_dimension = 512
output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 20])
assert emdedding.shape == torch.Size([1, 512])
def test_transforms():
image = numpy.random.rand(3, 128, 128).astype("uint8")
from ..datasets import RollChannels
sample = {'image': image}
rc = RollChannels()
rc(sample)
assert sample['image'].shape == (128, 128, 3)
from ..datasets import ToTensor
tt = ToTensor()
tt(sample)
assert isinstance(sample['image'], torch.Tensor)
from ..datasets import Normalize
image_copy = torch.Tensor(sample['image'])
norm = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
norm(sample)
for c in range(3):
for h in range(sample['image'].shape[0]):
for w in range(sample['image'].shape[1]):
assert (abs(sample['image'][c, h, w]) - abs((image_copy[c, h, w] - 0.5) / 0.5)) < 1e-10
def test_map_labels():
labels = ['1', '4', '7']
from ..datasets import map_labels
new_labels = map_labels(labels)
new_labels = sorted(new_labels)
assert new_labels == ['0', '1', '2']
new_labels = map_labels(labels, start_index = 5)
new_labels = sorted(new_labels)
assert new_labels == ['5', '6', '7']
from torch.utils.data import Dataset
class DummyDataSet(Dataset):
def __init__(self):
pass
def __len__(self):
return 100
def __getitem__(self, idx):
data = numpy.random.rand(3, 128, 128).astype("float32")
label = numpy.random.randint(20)
sample = {'image': torch.from_numpy(data), 'label': label}
return sample
def test_trainer():
from ..architectures import CNN8
net = CNN8(20)
dataloader = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
from ..trainers import CNNTrainer
trainer = CNNTrainer(net, verbosity_level=3)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
os.remove('model_1_0.pth')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment