Skip to content
Snippets Groups Projects

Added an option to return a latent embedding to the ConvAutoencoder class + unit test for this case

Merged Olegs NIKISINS requested to merge ae_update into master
2 files
+ 25
1
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -12,9 +12,22 @@ from torch import nn
# Define the network:
class ConvAutoencoder(nn.Module):
"""
A class defining a simple convolutional autoencoder.
def __init__(self):
Attributes
----------
return_latent_embedding : bool
If set to ``True`` forward() method returns a latent
emebedding (encoder output), otherwise a reconstructed
image is returned. Default: ``False``
"""
def __init__(self, return_latent_embedding = False):
super(ConvAutoencoder, self).__init__()
self.return_latent_embedding = return_latent_embedding
self.encoder = nn.Sequential(nn.Conv2d(3, 16, 5, padding=2),
nn.ReLU(True),
nn.MaxPool2d(2),
@@ -45,5 +58,9 @@ class ConvAutoencoder(nn.Module):
"""
x = self.encoder(x)
x = self.decoder(x)
if self.return_latent_embedding:
return self.encoder(x)
return x
Loading