Commit 181b2fff authored by Olegs NIKISINS's avatar Olegs NIKISINS

Added the unit test for ConvAutoencoder model

parent eca9f56b
Pipeline #26248 passed with stage
in 7 minutes and 39 seconds
......@@ -13,10 +13,10 @@ def test_architectures():
a = numpy.random.rand(1, 3, 128, 128).astype("float32")
t = torch.from_numpy(a)
number_of_classes = 20
output_dimension = number_of_classes
# CASIANet
from ..architectures import CASIANet
net = CASIANet(number_of_classes)
......@@ -24,7 +24,7 @@ def test_architectures():
output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 20])
assert emdedding.shape == torch.Size([1, 320])
# CNN8
from ..architectures import CNN8
net = CNN8(number_of_classes)
......@@ -74,18 +74,18 @@ def test_transforms():
image = numpy.random.rand(3, 128, 128).astype("uint8")
from ..datasets import RollChannels
from ..datasets import RollChannels
sample = {'image': image}
rc = RollChannels()
rc(sample)
assert sample['image'].shape == (128, 128, 3)
from ..datasets import ToTensor
from ..datasets import ToTensor
tt = ToTensor()
tt(sample)
assert isinstance(sample['image'], torch.Tensor)
from ..datasets import Normalize
from ..datasets import Normalize
image_copy = torch.Tensor(sample['image'])
norm = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
norm(sample)
......@@ -106,7 +106,7 @@ def test_map_labels():
new_labels = map_labels(labels, start_index = 5)
new_labels = sorted(new_labels)
assert new_labels == ['5', '6', '7']
from torch.utils.data import Dataset
class DummyDataSet(Dataset):
......@@ -118,7 +118,7 @@ class DummyDataSet(Dataset):
data = numpy.random.rand(3, 128, 128).astype("float32")
label = numpy.random.randint(20)
sample = {'image': torch.from_numpy(data), 'label': label}
return sample
return sample
def test_CNNtrainer():
......@@ -127,7 +127,7 @@ def test_CNNtrainer():
net = CNN8(20)
dataloader = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
from ..trainers import CNNTrainer
trainer = CNNTrainer(net, verbosity_level=3)
trainer.train(dataloader, n_epochs=1, output_dir='.')
......@@ -146,7 +146,7 @@ class DummyDataSetGAN(Dataset):
def __getitem__(self, idx):
data = numpy.random.rand(3, 64, 64).astype("float32")
sample = {'image': torch.from_numpy(data)}
return sample
return sample
def test_DCGANtrainer():
......@@ -156,7 +156,7 @@ def test_DCGANtrainer():
d = DCGAN_discriminator(1)
dataloader = torch.utils.data.DataLoader(DummyDataSetGAN(), batch_size=32, shuffle=True)
from ..trainers import DCGANTrainer
trainer = DCGANTrainer(g, d, batch_size=32, noise_dim=100, use_gpu=False, verbosity_level=2)
trainer.train(dataloader, n_epochs=1, output_dir='.')
......@@ -188,11 +188,11 @@ def test_ConditionalGANTrainer():
d = ConditionalGAN_discriminator(13)
dataloader = torch.utils.data.DataLoader(DummyDataSetConditionalGAN(), batch_size=32, shuffle=True)
from ..trainers import ConditionalGANTrainer
trainer = ConditionalGANTrainer(g, d, [3, 64, 64], batch_size=32, noise_dim=100, conditional_dim=13)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('fake_samples_epoch_000.png')
assert os.path.isfile('netD_epoch_0.pth')
......@@ -200,3 +200,20 @@ def test_ConditionalGANTrainer():
os.remove('fake_samples_epoch_000.png')
os.remove('netD_epoch_0.pth')
os.remove('netG_epoch_0.pth')
def test_conv_autoencoder():
"""
Test the ConvAutoencoder class.
"""
from bob.learn.pytorch.architectures import ConvAutoencoder
batch = torch.randn(1, 3, 64, 64)
model = ConvAutoencoder()
output = model(batch)
assert batch.shape == output.shape
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment