test.py 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#!/usr/bin/env python
# encoding: utf-8


"""Unit tests for the architectures, datasets, trainers and extractors.

"""

import numpy
import torch

def test_architectures():
  """Check the output shapes of every network architecture.

  Each network is fed a single random input (batch size of one) and the
  shapes of the tensors it returns are compared against the expected ones.
  """

  number_of_classes = 20

  # CASIANet and CNN8 are both fed a 3x128x128 input and return a
  # (class scores, embedding) pair
  a = numpy.random.rand(1, 3, 128, 128).astype("float32")
  t = torch.from_numpy(a)

  # CASIANet
  from ..architectures import CASIANet
  net = CASIANet(number_of_classes)
  embedding_dimension = 320
  output, embedding = net.forward(t)
  assert output.shape == torch.Size([1, number_of_classes])
  assert embedding.shape == torch.Size([1, embedding_dimension])

  # CNN8
  from ..architectures import CNN8
  net = CNN8(number_of_classes)
  embedding_dimension = 512
  output, embedding = net.forward(t)
  assert output.shape == torch.Size([1, number_of_classes])
  assert embedding.shape == torch.Size([1, embedding_dimension])

  # LightCNN9, LightCNN29 and LightCNN29v2: all are built with their default
  # arguments and fed a 1x128x128 input
  # NOTE(review): 79077 is presumably the default number of classes - confirm
  from ..architectures import LightCNN9
  from ..architectures import LightCNN29
  from ..architectures import LightCNN29v2
  embedding_dimension = 256
  for architecture in (LightCNN9, LightCNN29, LightCNN29v2):
    a = numpy.random.rand(1, 1, 128, 128).astype("float32")
    t = torch.from_numpy(a)
    net = architecture()
    output, embedding = net.forward(t)
    assert output.shape == torch.Size([1, 79077])
    assert embedding.shape == torch.Size([1, embedding_dimension])

  # DCGAN
  d = numpy.random.rand(1, 3, 64, 64).astype("float32")
  t = torch.from_numpy(d)
  from ..architectures import DCGAN_discriminator
  discriminator = DCGAN_discriminator(1)
  output = discriminator.forward(t)
  assert output.shape == torch.Size([1])

  g = numpy.random.rand(1, 100, 1, 1).astype("float32")
  t = torch.from_numpy(g)
  from ..architectures import DCGAN_generator
  generator = DCGAN_generator(1)
  output = generator.forward(t)
  assert output.shape == torch.Size([1, 3, 64, 64])

  # Conditional GAN
  d = numpy.random.rand(1, 3, 64, 64).astype("float32")
  t = torch.from_numpy(d)
  # conditional feature maps: only the first of the 13 maps is set
  cfm = numpy.zeros((1, 13, 64, 64), dtype="float32")
  cfm[:, 0, :, :] = 1
  cfmt = torch.from_numpy(cfm)
  from ..architectures import ConditionalGAN_discriminator
  discriminator = ConditionalGAN_discriminator(13)
  output = discriminator.forward(t, cfmt)
  assert output.shape == torch.Size([1])

  g = numpy.random.rand(1, 100, 1, 1).astype("float32")
  t = torch.from_numpy(g)
  # one-hot vector: only the first of the 13 entries is set
  # (``oh[0] = 1`` would set all of them)
  oh = numpy.zeros((1, 13, 1, 1), dtype="float32")
  oh[:, 0, :, :] = 1
  oht = torch.from_numpy(oh)
  from ..architectures import ConditionalGAN_generator
  generator = ConditionalGAN_generator(100, 13)
  output = generator.forward(t, oht)
  assert output.shape == torch.Size([1, 3, 64, 64])

99 100 101

def test_transforms():
  """Check the sample transforms: RollChannels, ToTensor and Normalize.

  Every transform is called on an ``{'image': ...}`` dictionary and the
  result is read back from that same dictionary.
  """

  # random 8-bit pixel values; casting the output of numpy.random.rand()
  # to uint8 would truncate every pixel to zero
  image = numpy.random.randint(0, 256, size=(3, 128, 128), dtype="uint8")

  from ..datasets import RollChannels
  sample = {'image': image}
  rc = RollChannels()
  rc(sample)
  assert sample['image'].shape == (128, 128, 3)

  from ..datasets import ToTensor
  tt = ToTensor()
  tt(sample)
  assert isinstance(sample['image'], torch.Tensor)

  # grayscale
  image_gray = numpy.random.randint(0, 256, size=(128, 128), dtype="uint8")
  sample_gray = {'image': image_gray}
  tt(sample_gray)
  assert isinstance(sample_gray['image'], torch.Tensor)

  from ..datasets import Normalize
  # clone() guarantees a copy that is independent of the normalized tensor
  # NOTE(review): torch.Tensor(tensor) may share storage with its argument
  image_copy = sample['image'].clone()
  norm = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  norm(sample)

  # compare every pixel of every channel at once; the tolerance is chosen
  # for single-precision arithmetic
  expected = (image_copy - 0.5) / 0.5
  assert sample['image'].shape == expected.shape
  assert (sample['image'] - expected).abs().max().item() < 1e-6


def test_map_labels():
  """Check that map_labels() renumbers labels consecutively.

  Three arbitrary labels must be mapped onto '0', '1', '2' by default and
  onto '5', '6', '7' when ``start_index=5`` is given.
  Only membership is checked: the order of the returned labels is not
  asserted.
  """

  labels = ['1', '4', '7']
  from ..datasets import map_labels

  new_labels = map_labels(labels)
  for expected in ('0', '1', '2'):
    assert expected in new_labels, "new_labels = {}".format(new_labels)

  new_labels = map_labels(labels, start_index=5)
  for expected in ('5', '6', '7'):
    assert expected in new_labels, "new_labels = {}".format(new_labels)
147

148 149 150 151 152 153 154 155

from torch.utils.data import Dataset
class DummyDataSet(Dataset):
  """Dataset of 100 random single-channel images with random labels.

  Each item is a dictionary with two entries:

  * 'image': a float32 tensor of shape (1, 128, 128)
  * 'label': an integer drawn from [0, 20)
  """

  def __len__(self):
    return 100

  def __getitem__(self, idx):
    # items are generated on the fly, ``idx`` is not used
    data = numpy.random.rand(1, 128, 128).astype("float32")
    label = numpy.random.randint(20)
    return {'image': torch.from_numpy(data), 'label': label}
160 161


162
def test_CNNtrainer():
  """Train a LightCNN9 for one epoch on dummy data and check the saved model.

  The model written by the trainer is removed afterwards, whether or not
  the test succeeds.
  """
  import os

  # 20 outputs, matching the label range of DummyDataSet
  from ..architectures import LightCNN9
  net = LightCNN9(20)

  dataloader = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)

  from ..trainers import CNNTrainer
  trainer = CNNTrainer(net, verbosity_level=3)

  model_file = 'model_1_0.pth'
  try:
    trainer.train(dataloader, n_epochs=1, output_dir='.')
    assert os.path.isfile(model_file)
  finally:
    # never leave the model behind in the working directory
    if os.path.isfile(model_file):
      os.remove(model_file)
177 178 179 180 181 182 183 184 185 186


class DummyDataSetGAN(Dataset):
  """Dataset of 100 random three-channel images, without labels.

  Each item is a dictionary with a single 'image' entry: a float32 tensor
  of shape (3, 64, 64).
  """

  def __len__(self):
    return 100

  def __getitem__(self, idx):
    # items are generated on the fly, ``idx`` is not used
    data = numpy.random.rand(3, 64, 64).astype("float32")
    return {'image': torch.from_numpy(data)}
188 189 190 191 192 193 194 195 196

def test_DCGANtrainer():
  """Train a DCGAN for one epoch on dummy data and check the saved files.

  The files written by the trainer are removed afterwards, whether or not
  the test succeeds.
  """
  import os

  from ..architectures import DCGAN_generator
  from ..architectures import DCGAN_discriminator
  g = DCGAN_generator(1)
  d = DCGAN_discriminator(1)

  dataloader = torch.utils.data.DataLoader(DummyDataSetGAN(), batch_size=32, shuffle=True)

  from ..trainers import DCGANTrainer
  trainer = DCGANTrainer(g, d, batch_size=32, noise_dim=100, use_gpu=False, verbosity_level=2)

  expected_files = [
    'fake_samples_epoch_000.png',
    'netD_epoch_0.pth',
    'netG_epoch_0.pth',
  ]
  try:
    trainer.train(dataloader, n_epochs=1, output_dir='.')
    for filename in expected_files:
      assert os.path.isfile(filename)
  finally:
    # clean up everything that was written, even after a failure
    for filename in expected_files:
      if os.path.isfile(filename):
        os.remove(filename)

211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
class DummyDataSetConditionalGAN(Dataset):
  """Dataset of 100 random three-channel images with a random pose index.

  Each item is a dictionary with two entries:

  * 'image': a float32 tensor of shape (3, 64, 64)
  * 'pose': an integer drawn from [0, 13)
  """

  def __init__(self):
    """Nothing to set up: items are generated on the fly."""

  def __len__(self):
    return 100

  def __getitem__(self, idx):
    # the image is drawn first, then the pose; ``idx`` is not used
    pixels = numpy.random.rand(3, 64, 64).astype("float32")
    item = {'image': torch.from_numpy(pixels)}
    item['pose'] = numpy.random.randint(0, 13)
    return item

def test_ConditionalGANTrainer():
  """Train a conditional GAN for one epoch on dummy data and check the saved files.

  The files written by the trainer are removed afterwards, whether or not
  the test succeeds.
  """
  import os

  # 13 conditions, matching the pose range of DummyDataSetConditionalGAN
  from ..architectures import ConditionalGAN_generator
  from ..architectures import ConditionalGAN_discriminator
  g = ConditionalGAN_generator(100, 13)
  d = ConditionalGAN_discriminator(13)

  dataloader = torch.utils.data.DataLoader(DummyDataSetConditionalGAN(), batch_size=32, shuffle=True)

  from ..trainers import ConditionalGANTrainer
  trainer = ConditionalGANTrainer(g, d, [3, 64, 64], batch_size=32, noise_dim=100, conditional_dim=13)

  expected_files = [
    'fake_samples_epoch_000.png',
    'netD_epoch_0.pth',
    'netG_epoch_0.pth',
  ]
  try:
    trainer.train(dataloader, n_epochs=1, output_dir='.')
    for filename in expected_files:
      assert os.path.isfile(filename)
  finally:
    # clean up everything that was written, even after a failure
    for filename in expected_files:
      if os.path.isfile(filename):
        os.remove(filename)
241 242 243 244 245 246 247


def test_conv_autoencoder():
    """
    Test the ConvAutoencoder class.

    The output of the default model must have the same shape as its input
    batch. With ``return_latent_embedding=True``, a 1x3x64x64 batch must
    give an output of shape [1, 16, 5, 5].
    """
    from bob.learn.pytorch.architectures import ConvAutoencoder

    batch = torch.randn(1, 3, 64, 64)
    model = ConvAutoencoder()
    output = model(batch)
    assert batch.shape == output.shape

    model_embeddings = ConvAutoencoder(return_latent_embedding=True)
    embedding = model_embeddings(batch)
    assert list(embedding.shape) == [1, 16, 5, 5]


259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
def test_extractors():
  """Run every LightCNN extractor on a random image and check the feature size.

  Each extractor is built with its default arguments, and the first
  dimension of its output must be 256.
  """
  from bob.learn.pytorch.extractor.image import LightCNN9Extractor
  from bob.learn.pytorch.extractor.image import LightCNN29Extractor
  from bob.learn.pytorch.extractor.image import LightCNN29v2Extractor

  # lightCNN9, lightCNN29 and lightCNN29v2, in that order
  extractor_classes = (
    LightCNN9Extractor,
    LightCNN29Extractor,
    LightCNN29v2Extractor,
  )

  for extractor_class in extractor_classes:
    extractor = extractor_class()
    # these architectures expect 128x128 grayscale images
    data = numpy.random.rand(128, 128).astype("float32")
    features = extractor(data)
    assert features.shape[0] == 256