Commit 1c7a49d2 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[tests] removed MLP tests, moved Autoencoder architecture test

parent 106dcfc0
......@@ -43,8 +43,6 @@ def test_architectures():
assert emdedding.shape == torch.Size([1, 256])
# LightCNN29
a = numpy.random.rand(1, 1, 128, 128).astype("float32")
t = torch.from_numpy(a)
from ..architectures import LightCNN29
net = LightCNN29()
output, emdedding = net.forward(t)
......@@ -52,8 +50,6 @@ def test_architectures():
assert emdedding.shape == torch.Size([1, 256])
# LightCNN29v2
a = numpy.random.rand(1, 1, 128, 128).astype("float32")
t = torch.from_numpy(a)
from ..architectures import LightCNN29v2
net = LightCNN29v2()
output, emdedding = net.forward(t)
......@@ -69,8 +65,6 @@ def test_architectures():
assert output.shape == torch.Size([1, 1])
# MCCNNv2
a = numpy.random.rand(1, 4, 128, 128).astype("float32")
t = torch.from_numpy(a)
from ..architectures import MCCNNv2
net = MCCNNv2(num_channels=4)
output = net.forward(t)
......@@ -117,7 +111,6 @@ def test_architectures():
assert output.shape == torch.Size([1, 3, 64, 64])
# Conditional GAN
d = numpy.random.rand(1, 3, 64, 64).astype("float32")
t = torch.from_numpy(d)
cfm = numpy.zeros((1, 13, 64, 64), dtype="float32")
cfm[:, 0, :, :] = 1
......@@ -127,7 +120,6 @@ def test_architectures():
output = discriminator.forward(t, cfmt)
assert output.shape == torch.Size([1])
g = numpy.random.rand(1, 100, 1, 1).astype("float32")
t = torch.from_numpy(g)
oh = numpy.zeros((1, 13, 1, 1), dtype="float32")
oh[0] = 1
......@@ -136,6 +128,16 @@ def test_architectures():
discriminator = ConditionalGAN_generator(100, 13)
output = discriminator.forward(t, oht)
assert output.shape == torch.Size([1, 3, 64, 64])
# Convolutional Autoencoder
from bob.learn.pytorch.architectures import ConvAutoencoder
batch = torch.randn(1, 3, 64, 64)
model = ConvAutoencoder()
output = model(batch)
assert batch.shape == output.shape
model_embeddings = ConvAutoencoder(return_latent_embedding = True)
embedding = model_embeddings(batch)
assert list(embedding.shape) == [1, 16, 5, 5]
def test_transforms():
......@@ -186,15 +188,11 @@ def test_map_labels():
assert '0' in new_labels, "new_labels = {}".format(new_labels)
assert '1' in new_labels, "new_labels = {}".format(new_labels)
assert '2' in new_labels, "new_labels = {}".format(new_labels)
#new_labels = sorted(new_labels)
#assert new_labels == ['0', '1', '2']
new_labels = map_labels(labels, start_index = 5)
#new_labels = sorted(new_labels)
assert '5' in new_labels, "new_labels = {}".format(new_labels)
assert '6' in new_labels, "new_labels = {}".format(new_labels)
assert '7' in new_labels, "new_labels = {}".format(new_labels)
#assert new_labels == ['5', '6', '7']
from torch.utils.data import Dataset
......@@ -390,22 +388,6 @@ def test_ConditionalGANTrainer():
os.remove('netG_epoch_0.pth')
def test_conv_autoencoder():
"""
Test the ConvAutoencoder class.
"""
from bob.learn.pytorch.architectures import ConvAutoencoder
batch = torch.randn(1, 3, 64, 64)
model = ConvAutoencoder()
output = model(batch)
assert batch.shape == output.shape
model_embeddings = ConvAutoencoder(return_latent_embedding = True)
embedding = model_embeddings(batch)
assert list(embedding.shape) == [1, 16, 5, 5]
def test_extractors():
# lightCNN9
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment