Skip to content
Snippets Groups Projects
Commit b9d1a978 authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Added an option to return a latent embedding to the ConvAutoencoder class + unit test for this case

parent 8bd13a75
No related branches found
No related tags found
1 merge request!8Added an option to return a latent embedding to the ConvAutoencoder class + unit test for this case
Pipeline #26407 passed
...@@ -12,9 +12,22 @@ from torch import nn ...@@ -12,9 +12,22 @@ from torch import nn
# Define the network: # Define the network:
class ConvAutoencoder(nn.Module): class ConvAutoencoder(nn.Module):
"""
A class defining a simple convolutional autoencoder.
def __init__(self): Attributes
----------
return_latent_embedding : bool
If set to ``True`` forward() method returns a latent
emebedding (encoder output), otherwise a reconstructed
image is returned. Default: ``False``
"""
def __init__(self, return_latent_embedding = False):
super(ConvAutoencoder, self).__init__() super(ConvAutoencoder, self).__init__()
self.return_latent_embedding = return_latent_embedding
self.encoder = nn.Sequential(nn.Conv2d(3, 16, 5, padding=2), self.encoder = nn.Sequential(nn.Conv2d(3, 16, 5, padding=2),
nn.ReLU(True), nn.ReLU(True),
nn.MaxPool2d(2), nn.MaxPool2d(2),
...@@ -45,5 +58,9 @@ class ConvAutoencoder(nn.Module): ...@@ -45,5 +58,9 @@ class ConvAutoencoder(nn.Module):
""" """
x = self.encoder(x) x = self.encoder(x)
x = self.decoder(x) x = self.decoder(x)
if self.return_latent_embedding:
return self.encoder(x)
return x return x
...@@ -217,3 +217,10 @@ def test_conv_autoencoder(): ...@@ -217,3 +217,10 @@ def test_conv_autoencoder():
assert batch.shape == output.shape assert batch.shape == output.shape
model_embeddings = ConvAutoencoder(return_latent_embedding = True)
embedding = model_embeddings(batch)
assert list(embedding.shape) == [1, 16, 5, 5]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment