test.py 2.47 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
#!/usr/bin/env python
# encoding: utf-8


""" Unit tests

"""

import numpy
import torch

def test_architectures():
  """Feed one random batch through each architecture and check the shapes.

  Every network is built for ``number_of_classes`` classes. Its ``forward``
  is expected to return an ``(output, embedding)`` pair, where ``output``
  has one column per class and ``embedding`` has the architecture-specific
  size listed below.
  """
  from ..architectures import CASIANet
  from ..architectures import CNN8

  a = numpy.random.rand(1, 3, 128, 128).astype("float32")
  t = torch.from_numpy(a)

  batch_size = t.shape[0]
  number_of_classes = 20

  # (architecture, size of the embedding it is expected to return)
  expected_embedding_dimensions = (
    (CASIANet, 320),
    (CNN8, 512),
  )

  for architecture, embedding_dimension in expected_embedding_dimensions:
    net = architecture(number_of_classes)
    output, embedding = net.forward(t)
    # the architecture name is attached so a failure says which network broke
    assert output.shape == torch.Size([batch_size, number_of_classes]), architecture.__name__
    assert embedding.shape == torch.Size([batch_size, embedding_dimension]), architecture.__name__
 

def test_transforms():
  """Check the RollChannels, ToTensor and Normalize dataset transforms.

  The three transforms are applied one after the other to the same sample
  dictionary, and its ``'image'`` entry is inspected after each step.
  """

  # rand() draws from [0, 1), so casting it straight to uint8 would give an
  # all-zero image; draw genuine 8-bit pixel values instead.
  image = numpy.random.randint(0, 256, size=(3, 128, 128)).astype("uint8")
  sample = {'image': image}

  from ..datasets import RollChannels
  rc = RollChannels()
  rc(sample)
  assert sample['image'].shape == (128, 128, 3)

  from ..datasets import ToTensor
  tt = ToTensor()
  tt(sample)
  assert isinstance(sample['image'], torch.Tensor)

  from ..datasets import Normalize
  # clone() guarantees the reference does not share memory with the tensor
  # handed to Normalize, whether or not the transform works in place.
  reference = sample['image'].clone().float()
  norm = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  norm(sample)

  # Mean and standard deviation are 0.5 for every channel, so the expected
  # result is the same element-wise formula regardless of the channel axis.
  expected = (reference - 0.5) / 0.5
  assert sample['image'].shape == expected.shape

  # Compare every element at once on the absolute difference.
  # The tolerance is sized for single-precision values.
  deviation = (sample['image'].float() - expected).abs().max().item()
  assert deviation < 1e-6


def test_map_labels():
  """Check the labels returned by map_labels, with and without a start index.

  The returned labels are sorted before being compared, so only their
  content is checked, not their order.
  """
  from ..datasets import map_labels

  original_labels = ['1', '4', '7']

  # without a start index
  remapped = map_labels(original_labels)
  assert sorted(remapped) == ['0', '1', '2']

  # with an explicit start index
  shifted = map_labels(original_labels, start_index = 5)
  assert sorted(shifted) == ['5', '6', '7']
 

from torch.utils.data import Dataset
class DummyDataSet(Dataset):
  """Dataset of 100 randomly generated samples.

  Each item is a dictionary with a random float32 ``'image'`` tensor of
  shape (3, 128, 128) and a random integer ``'label'`` in [0, 20). The
  index is ignored: a new sample is drawn on every access.
  """

  def __init__(self):
    # nothing to set up, samples are generated on demand
    pass

  def __len__(self):
    return 100

  def __getitem__(self, idx):
    # pixels are drawn before the label, which keeps seeded runs reproducible
    pixels = torch.from_numpy(numpy.random.rand(3, 128, 128).astype("float32"))
    target = numpy.random.randint(20)
    return {'image': pixels, 'label': target}


def test_trainer():
  """Train CNN8 for one epoch on dummy data and check that a model is saved.

  Training writes into the current directory; the test expects to find
  ``model_1_0.pth`` there afterwards.
  """
  import os

  from ..architectures import CNN8
  from ..trainers import CNNTrainer

  net = CNN8(20)
  dataloader = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
  trainer = CNNTrainer(net, verbosity_level=3)

  model_file = 'model_1_0.pth'
  try:
    trainer.train(dataloader, n_epochs=1, output_dir='.')
    assert os.path.isfile(model_file)
  finally:
    # Clean up even when training or the assertion fails, so a broken run
    # does not leave the model file behind in the working directory.
    if os.path.isfile(model_file):
      os.remove(model_file)