Skip to content
Snippets Groups Projects
Commit 181b2fff authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Added the unit test for ConvAutoencoder model

parent eca9f56b
No related branches found
No related tags found
1 merge request!6autoencoders pretraining using RGB faces
Pipeline #26248 passed
......@@ -13,10 +13,10 @@ def test_architectures():
a = numpy.random.rand(1, 3, 128, 128).astype("float32")
t = torch.from_numpy(a)
number_of_classes = 20
output_dimension = number_of_classes
# CASIANet
from ..architectures import CASIANet
net = CASIANet(number_of_classes)
......@@ -24,7 +24,7 @@ def test_architectures():
output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 20])
assert emdedding.shape == torch.Size([1, 320])
# CNN8
from ..architectures import CNN8
net = CNN8(number_of_classes)
......@@ -74,18 +74,18 @@ def test_transforms():
image = numpy.random.rand(3, 128, 128).astype("uint8")
from ..datasets import RollChannels
from ..datasets import RollChannels
sample = {'image': image}
rc = RollChannels()
rc(sample)
assert sample['image'].shape == (128, 128, 3)
from ..datasets import ToTensor
from ..datasets import ToTensor
tt = ToTensor()
tt(sample)
assert isinstance(sample['image'], torch.Tensor)
from ..datasets import Normalize
from ..datasets import Normalize
image_copy = torch.Tensor(sample['image'])
norm = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
norm(sample)
......@@ -106,7 +106,7 @@ def test_map_labels():
new_labels = map_labels(labels, start_index = 5)
new_labels = sorted(new_labels)
assert new_labels == ['5', '6', '7']
from torch.utils.data import Dataset
class DummyDataSet(Dataset):
......@@ -118,7 +118,7 @@ class DummyDataSet(Dataset):
data = numpy.random.rand(3, 128, 128).astype("float32")
label = numpy.random.randint(20)
sample = {'image': torch.from_numpy(data), 'label': label}
return sample
return sample
def test_CNNtrainer():
......@@ -127,7 +127,7 @@ def test_CNNtrainer():
net = CNN8(20)
dataloader = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
from ..trainers import CNNTrainer
trainer = CNNTrainer(net, verbosity_level=3)
trainer.train(dataloader, n_epochs=1, output_dir='.')
......@@ -146,7 +146,7 @@ class DummyDataSetGAN(Dataset):
def __getitem__(self, idx):
data = numpy.random.rand(3, 64, 64).astype("float32")
sample = {'image': torch.from_numpy(data)}
return sample
return sample
def test_DCGANtrainer():
......@@ -156,7 +156,7 @@ def test_DCGANtrainer():
d = DCGAN_discriminator(1)
dataloader = torch.utils.data.DataLoader(DummyDataSetGAN(), batch_size=32, shuffle=True)
from ..trainers import DCGANTrainer
trainer = DCGANTrainer(g, d, batch_size=32, noise_dim=100, use_gpu=False, verbosity_level=2)
trainer.train(dataloader, n_epochs=1, output_dir='.')
......@@ -188,11 +188,11 @@ def test_ConditionalGANTrainer():
d = ConditionalGAN_discriminator(13)
dataloader = torch.utils.data.DataLoader(DummyDataSetConditionalGAN(), batch_size=32, shuffle=True)
from ..trainers import ConditionalGANTrainer
trainer = ConditionalGANTrainer(g, d, [3, 64, 64], batch_size=32, noise_dim=100, conditional_dim=13)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('fake_samples_epoch_000.png')
assert os.path.isfile('netD_epoch_0.pth')
......@@ -200,3 +200,20 @@ def test_ConditionalGANTrainer():
os.remove('fake_samples_epoch_000.png')
os.remove('netD_epoch_0.pth')
os.remove('netG_epoch_0.pth')
def test_conv_autoencoder():
"""
Test the ConvAutoencoder class.
"""
from bob.learn.pytorch.architectures import ConvAutoencoder
batch = torch.randn(1, 3, 64, 64)
model = ConvAutoencoder()
output = model(batch)
assert batch.shape == output.shape
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment