Commit 4fb98320 authored by Guillaume HEUSCH
[architectures] added the original DR-GAN architecture

parent 1814b429
#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
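# Illustrative usage note (not part of the original commit): this initializer is
# typically applied to a freshly built network with nn.Module.apply, e.g.
#   encoder = DRGANOriginal_encoder((3, 96, 96), 320)
#   encoder.apply(weights_init)
# so that every Conv* and BatchNorm* layer receives the initialization defined above.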
class DRGANOriginal_encoder(nn.Module):
"""
Class defining the encoder in the DR-GAN architecture
**Parameters**
image_size: tuple
The dimension of the image (CxHxW)
latent_dim: int
The dimension of the encoded ID
"""
def __init__(self, image_size, latent_dim):
# conv2d(in_channels, out_channels (i.e. number of feature maps), kernel size, stride, padding)
super(DRGANOriginal_encoder, self).__init__()
self.ngpu = 1
self.main = nn.Sequential(
# input is 3x96x96, output is 32x96x96
nn.Conv2d(image_size[0], 32, 3, 1, 1, bias=False),
nn.BatchNorm2d(32),
nn.ELU(inplace=True),
# input is 32x96x96, output is 64x96x96
nn.Conv2d(32, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 64x96x96, output is 64x48x48
nn.Conv2d(64, 64, 3, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ELU(inplace=True),
# input is 64x48x48, output is 64x48x48
nn.Conv2d(64, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ELU(inplace=True),
# input is 64x48x48 , output is 128x48x48
nn.Conv2d(64, 128, 3, 1, 1, bias=False),
nn.BatchNorm2d(128),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 128x48x48, output is 128x24x24
nn.Conv2d(128, 128, 3, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ELU(inplace=True),
# input is 128x24x24, output is 96x24x24
nn.Conv2d(128, 96, 3, 1, 1, bias=False),
nn.BatchNorm2d(96),
nn.ELU(inplace=True),
# input is 96x24x24, output is 192x24x24
nn.Conv2d(96, 192, 3, 1, 1, bias=False),
nn.BatchNorm2d(192),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 192x24x24, output is 192x12x12
nn.Conv2d(192, 192, 3, 2, 1, bias=False),
nn.BatchNorm2d(192),
nn.ELU(inplace=True),
# input is 192x12x12, output is 128x12x12
nn.Conv2d(192, 128, 3, 1, 1, bias=False),
nn.BatchNorm2d(128),
nn.ELU(inplace=True),
# input is 128x12x12, output is 256x12x12
nn.Conv2d(128, 256, 3, 1, 1, bias=False),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 256x12x12, output is 256x6x6
nn.Conv2d(256, 256, 3, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ELU(inplace=True),
# input is 256x6x6, output is 160x6x6
nn.Conv2d(256, 160, 3, 1, 1, bias=False),
nn.BatchNorm2d(160),
nn.ELU(inplace=True),
# input is 160x6x6, output is (latent_dim)x6x6
nn.Conv2d(160, latent_dim, 3, 1, 1, bias=False),
nn.BatchNorm2d(latent_dim),
# ------------------------------------------
# average pool
nn.AvgPool2d(6, stride=1)
# dropout ?
)
def forward(self, x):
"""
Forward function for the encoder.
**Parameters**
x: pyTorch Variable
The minibatch of images to encode.
"""
if isinstance(x.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, x, range(self.ngpu))
else:
output = self.main(x)
print "Encoder output: {output}".format(output)
return output
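# Illustrative sketch (assumed shapes, not from the original commit): for a
# 3x96x96 input and latent_dim=320, the encoder maps a minibatch of images to a
# batch_size x 320 x 1 x 1 identity feature, e.g.
#   f = encoder(Variable(torch.randn(4, 3, 96, 96)))  # -> 4 x 320 x 1 x 1
# which is then fed to the decoder together with the noise and the pose condition.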
class DRGANOriginal_decoder(nn.Module):
"""
Class defining the decoder in the DR-GAN architecture
**Parameters**
image_size: tuple
The dimension of the image (CxHxW)
noise_dim: int
The dimension of the noise
latent_dim: int
The dimension of the encoded ID
conditional_dim: int
The dimension of the conditioning variable
"""
def __init__(self, image_size, noise_dim, latent_dim, conditional_dim):
super(DRGANOriginal_decoder, self).__init__()
self.ngpu = 1 # usually, we don't have more than one GPU
self.main = nn.Sequential(
# input is Z+ID+C , going into a convolution, output is 320x6x6
nn.ConvTranspose2d((noise_dim + latent_dim + conditional_dim), 320, 6, 1, 0, bias=False),
# dropout ?
nn.BatchNorm2d(320),
nn.ELU(inplace=True),
# input is 320x6x6, output is 160x6x6
nn.ConvTranspose2d(320, 160, 3, 1, 1, bias=False),
nn.BatchNorm2d(160),
nn.ELU(inplace=True),
# input is 160x6x6, output is 256x6x6
nn.ConvTranspose2d(160, 256, 3, 1, 1, bias=False),
nn.BatchNorm2d(256),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 256x6x6, output is 256x12x12
nn.ConvTranspose2d(256, 256, 3, 2, 1, output_padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ELU(inplace=True),
# input is 256x12x12, output is 128x12x12
nn.ConvTranspose2d(256, 128, 3, 1, 1, bias=False),
nn.BatchNorm2d(128),
nn.ELU(inplace=True),
# input is 128x12x12, output is 192x12x12
nn.ConvTranspose2d(128, 192, 3, 1, 1, bias=False),
nn.BatchNorm2d(192),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 192x12x12, output is 192x24x24
nn.ConvTranspose2d(192, 192, 3, 2, 1, output_padding=1, bias=False),
nn.BatchNorm2d(192),
nn.ELU(inplace=True),
# input is 192x24x24, output is 96x24x24
nn.ConvTranspose2d(192, 96, 3, 1, 1, bias=False),
nn.BatchNorm2d(96),
nn.ELU(inplace=True),
# input is 96x24x24, output is 128x24x24
nn.ConvTranspose2d(96, 128, 3, 1, 1, bias=False),
nn.BatchNorm2d(128),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 128x24x24, output is 128x48x48
nn.ConvTranspose2d(128, 128, 3, 2, 1, output_padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ELU(inplace=True),
# input is 128x48x48, output is 64x48x48
nn.ConvTranspose2d(128, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ELU(inplace=True),
# input is 64x48x48, output is 64x48x48
nn.ConvTranspose2d(64, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 64x48x48, output is 64x96x96
nn.ConvTranspose2d(64, 64, 3, 2, 1, output_padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ELU(inplace=True),
# input is 64x96x96, output is 32x96x96
nn.ConvTranspose2d(64, 32, 3, 1, 1, bias=False),
nn.BatchNorm2d(32),
nn.ELU(inplace=True),
# input is 32x96x96, output is 3x96x96
nn.ConvTranspose2d(32, image_size[0], 3, 1, 1, bias=False),
nn.Tanh(),
# ------------------------------------------
)
def forward(self, z, y, f):
"""
Forward function for the decoder.
**Parameters**
z: pyTorch Variable
The minibatch of noise.
y: pyTorch Variable
The conditional one hot encoded vector for the minibatch.
f: pyTorch Variable
The encoded ID for the minibatch
"""
decoder_input = torch.cat((z, y, f), 1)
if isinstance(decoder_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, decoder_input, range(self.ngpu))
else:
output = self.main(decoder_input)
return output
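# Illustrative sketch (assumed dimensions, not from the original commit): the three
# inputs are all 4-D tensors so they can be concatenated along the channel axis,
# e.g. with noise_dim=50, conditional_dim=13 and latent_dim=320:
#   z = Variable(torch.randn(4, 50, 1, 1))   # random noise
#   y = Variable(torch.zeros(4, 13, 1, 1))   # one-hot pose code (set one channel to 1)
#   f = encoder(images)                      # 4 x 320 x 1 x 1 identity feature
#   generated = decoder(z, y, f)             # -> 4 x 3 x 96 x 96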
class DRGANOriginal_discriminator(nn.Module):
"""
Class defining the discriminator in the DR-GAN architecture
**Parameters**
image_size: tuple
The dimension of the image (CxHxW)
number_of_ids: int
The number of identities in the DB
conditional_dim: int
The dimension of the conditioning variable
"""
def __init__(self, image_size, number_of_ids, conditional_dim):
super(DRGANOriginal_discriminator, self).__init__()
self.number_of_ids = number_of_ids
self.ngpu = 1
self.main = nn.Sequential(
# input is 3x96x96, output is 32x96x96
nn.Conv2d(image_size[0], 32, 3, 1, 1, bias=False),
nn.BatchNorm2d(32),
nn.ELU(inplace=True),
# input is 32x96x96, output is 64x96x96
nn.Conv2d(32, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 64x96x96, output is 64x48x48
nn.Conv2d(64, 64, 3, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ELU(inplace=True),
# input is 64x48x48, output is 64x48x48
nn.Conv2d(64, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ELU(inplace=True),
# input is 64x48x48 , output is 128x48x48
nn.Conv2d(64, 128, 3, 1, 1, bias=False),
nn.BatchNorm2d(128),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 128x48x48, output is 128x24x24
nn.Conv2d(128, 128, 3, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ELU(inplace=True),
# input is 128x24x24, output is 96x24x24
nn.Conv2d(128, 96, 3, 1, 1, bias=False),
nn.BatchNorm2d(96),
nn.ELU(inplace=True),
# input is 96x24x24, output is 192x24x24
nn.Conv2d(96, 192, 3, 1, 1, bias=False),
nn.BatchNorm2d(192),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 192x24x24, output is 192x12x12
nn.Conv2d(192, 192, 3, 2, 1, bias=False),
nn.BatchNorm2d(192),
nn.ELU(inplace=True),
# input is 192x12x12, output is 128x12x12
nn.Conv2d(192, 128, 3, 1, 1, bias=False),
nn.BatchNorm2d(128),
nn.ELU(inplace=True),
# input is 128x12x12, output is 256x12x12
nn.Conv2d(128, 256, 3, 1, 1, bias=False),
nn.ELU(inplace=True),
# ------------------------------------------
# input is 256x12x12, output is 256x6x6
nn.Conv2d(256, 256, 3, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ELU(inplace=True),
# input is 256x6x6, output is 160x6x6
nn.Conv2d(256, 160, 3, 1, 1, bias=False),
nn.BatchNorm2d(160),
nn.ELU(inplace=True),
# input is 160x6x6, output is 320x6x6
nn.Conv2d(160, 320, 3, 1, 1, bias=False),
nn.BatchNorm2d(320),
# ------------------------------------------
# --- average pool, output is 320x1x1
nn.AvgPool2d(6, stride=1)
)
# --- fully connected layer, applied in forward() after flattening the pooled features
self.fc = nn.Linear(320, (number_of_ids + conditional_dim + 1))
def forward(self, x):
"""
Forward function for the discriminator.
**Parameters**
x: pyTorch Variable
The minibatch of images to process.
"""
if isinstance(x.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, x, range(self.ngpu))
else:
output = self.main(x)
# flatten the pooled 320x1x1 features and apply the fully connected layer
output = output.view(output.size(0), -1)
return self.fc(output)
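# Minimal smoke test (illustrative sketch, not part of the original commit): the
# dimensions below are assumptions matching the 3x96x96 shapes used in the layer
# comments above; adjust them to your own setup.
if __name__ == '__main__':
    from torch.autograd import Variable

    image_size = (3, 96, 96)
    latent_dim = 320        # dimension of the encoded ID
    noise_dim = 50          # assumed noise dimension
    conditional_dim = 13    # assumed number of pose codes
    number_of_ids = 100     # assumed number of identities
    batch_size = 4

    encoder = DRGANOriginal_encoder(image_size, latent_dim)
    decoder = DRGANOriginal_decoder(image_size, noise_dim, latent_dim, conditional_dim)
    discriminator = DRGANOriginal_discriminator(image_size, number_of_ids, conditional_dim)
    for net in (encoder, decoder, discriminator):
        net.apply(weights_init)

    images = Variable(torch.randn(batch_size, *image_size))
    f = encoder(images)                                     # batch_size x latent_dim x 1 x 1
    z = Variable(torch.randn(batch_size, noise_dim, 1, 1))  # random noise
    y = torch.zeros(batch_size, conditional_dim, 1, 1)
    y[:, 0] = 1                                             # dummy one-hot pose code
    y = Variable(y)
    generated = decoder(z, y, f)                            # batch_size x 3 x 96 x 96
    scores = discriminator(generated)                       # batch_size x (number_of_ids + conditional_dim + 1)
    print("generated images: {}".format(generated.size()))
    print("discriminator output: {}".format(scores.size()))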