Commit b9d1a978 authored by Olegs NIKISINS's avatar Olegs NIKISINS

Added an option to return a latent embedding to the ConvAutoencoder class + unit test for this case

parent 8bd13a75
Pipeline #26407 passed with stage
in 9 minutes and 28 seconds
......@@ -12,9 +12,22 @@ from torch import nn
# Define the network:
class ConvAutoencoder(nn.Module):
"""
A class defining a simple convolutional autoencoder.
def __init__(self):
Attributes
----------
return_latent_embedding : bool
If set to ``True`` forward() method returns a latent
emebedding (encoder output), otherwise a reconstructed
image is returned. Default: ``False``
"""
def __init__(self, return_latent_embedding = False):
super(ConvAutoencoder, self).__init__()
self.return_latent_embedding = return_latent_embedding
self.encoder = nn.Sequential(nn.Conv2d(3, 16, 5, padding=2),
nn.ReLU(True),
nn.MaxPool2d(2),
......@@ -45,5 +58,9 @@ class ConvAutoencoder(nn.Module):
"""
x = self.encoder(x)
x = self.decoder(x)
if self.return_latent_embedding:
return self.encoder(x)
return x
......@@ -217,3 +217,10 @@ def test_conv_autoencoder():
assert batch.shape == output.shape
model_embeddings = ConvAutoencoder(return_latent_embedding = True)
embedding = model_embeddings(batch)
assert list(embedding.shape) == [1, 16, 5, 5]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment