Skip to content
Snippets Groups Projects

Added an option to return a latent embedding to the ConvAutoencoder class + unit test for this case

Merged Olegs NIKISINS requested to merge ae_update into master
2 files
+ 25
1
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -12,9 +12,22 @@ from torch import nn
@@ -12,9 +12,22 @@ from torch import nn
# Define the network:
# Define the network:
class ConvAutoencoder(nn.Module):
class ConvAutoencoder(nn.Module):
 
"""
 
A class defining a simple convolutional autoencoder.
def __init__(self):
Attributes
 
----------
 
return_latent_embedding : bool
 
If set to ``True`` forward() method returns a latent
 
emebedding (encoder output), otherwise a reconstructed
 
image is returned. Default: ``False``
 
"""
 
 
def __init__(self, return_latent_embedding = False):
super(ConvAutoencoder, self).__init__()
super(ConvAutoencoder, self).__init__()
 
 
self.return_latent_embedding = return_latent_embedding
 
self.encoder = nn.Sequential(nn.Conv2d(3, 16, 5, padding=2),
self.encoder = nn.Sequential(nn.Conv2d(3, 16, 5, padding=2),
nn.ReLU(True),
nn.ReLU(True),
nn.MaxPool2d(2),
nn.MaxPool2d(2),
@@ -45,5 +58,9 @@ class ConvAutoencoder(nn.Module):
@@ -45,5 +58,9 @@ class ConvAutoencoder(nn.Module):
"""
"""
x = self.encoder(x)
x = self.encoder(x)
x = self.decoder(x)
x = self.decoder(x)
 
 
if self.return_latent_embedding:
 
return self.encoder(x)
 
return x
return x
Loading