Skip to content
Snippets Groups Projects
Commit 482bfddb authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Added a unit test for the TwoLayerMLP class

parent 8881e90f
No related branches found
No related tags found
1 merge request!14MLP class and config to train it
...@@ -41,7 +41,7 @@ def test_architectures(): ...@@ -41,7 +41,7 @@ def test_architectures():
output, emdedding = net.forward(t) output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 79077]) assert output.shape == torch.Size([1, 79077])
assert emdedding.shape == torch.Size([1, 256]) assert emdedding.shape == torch.Size([1, 256])
# LightCNN29 # LightCNN29
a = numpy.random.rand(1, 1, 128, 128).astype("float32") a = numpy.random.rand(1, 1, 128, 128).astype("float32")
t = torch.from_numpy(a) t = torch.from_numpy(a)
...@@ -119,7 +119,7 @@ def test_transforms(): ...@@ -119,7 +119,7 @@ def test_transforms():
tt = ToTensor() tt = ToTensor()
tt(sample) tt(sample)
assert isinstance(sample['image'], torch.Tensor) assert isinstance(sample['image'], torch.Tensor)
# grayscale # grayscale
image_gray = numpy.random.rand(128, 128).astype("uint8") image_gray = numpy.random.rand(128, 128).astype("uint8")
sample_gray = {'image': image_gray} sample_gray = {'image': image_gray}
tt(sample_gray) tt(sample_gray)
...@@ -253,7 +253,7 @@ def test_conv_autoencoder(): ...@@ -253,7 +253,7 @@ def test_conv_autoencoder():
Test the ConvAutoencoder class. Test the ConvAutoencoder class.
""" """
from bob.learn.pytorch.architectures import ConvAutoencoder from bob.learn.pytorch.architectures import ConvAutoencoder
batch = torch.randn(1, 3, 64, 64) batch = torch.randn(1, 3, 64, 64)
model = ConvAutoencoder() model = ConvAutoencoder()
output = model(batch) output = model(batch)
...@@ -290,3 +290,23 @@ def test_extractors(): ...@@ -290,3 +290,23 @@ def test_extractors():
output = extractor(data) output = extractor(data)
assert output.shape[0] == 256 assert output.shape[0] == 256
def test_two_layer_mlp():
"""
Test the TwoLayerMLP class.
"""
from bob.learn.pytorch.architectures import TwoLayerMLP
batch = torch.randn(10, 1, 100)
model = TwoLayerMLP(in_features = 100,
n_hidden_relu = 10,
apply_sigmoid = True)
output = model(batch)
assert list(output.shape) == [10, 1]
model.apply_sigmoid = False
output = model(batch)
assert list(output.shape) == [10, 1]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment