Skip to content
Snippets Groups Projects
Commit 1494e623 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

Merge branch 'ae_update' into 'master'

Added an option to return a latent embedding to the ConvAutoencoder class + unit test for this case

See merge request !8
parents 8bd13a75 b9d1a978
No related branches found
No related tags found
1 merge request!8Added an option to return a latent embedding to the ConvAutoencoder class + unit test for this case
Pipeline #26410 passed
......@@ -12,9 +12,22 @@ from torch import nn
# Define the network:
class ConvAutoencoder(nn.Module):
"""
A class defining a simple convolutional autoencoder.
def __init__(self):
Attributes
----------
return_latent_embedding : bool
If set to ``True`` forward() method returns a latent
emebedding (encoder output), otherwise a reconstructed
image is returned. Default: ``False``
"""
def __init__(self, return_latent_embedding = False):
super(ConvAutoencoder, self).__init__()
self.return_latent_embedding = return_latent_embedding
self.encoder = nn.Sequential(nn.Conv2d(3, 16, 5, padding=2),
nn.ReLU(True),
nn.MaxPool2d(2),
......@@ -45,5 +58,9 @@ class ConvAutoencoder(nn.Module):
"""
x = self.encoder(x)
x = self.decoder(x)
if self.return_latent_embedding:
return self.encoder(x)
return x
......@@ -217,3 +217,10 @@ def test_conv_autoencoder():
assert batch.shape == output.shape
model_embeddings = ConvAutoencoder(return_latent_embedding = True)
embedding = model_embeddings(batch)
assert list(embedding.shape) == [1, 16, 5, 5]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment