Resolve "Add GANs"

Merged Guillaume HEUSCH requested to merge 4-add-gans into master
3 files changed: 492 additions, 0 deletions
#!/usr/bin/env python
# encoding: utf-8

import torch
import torch.nn as nn


def weights_init(m):
    """
    Weights initialization

    **Parameters**

    m:
      The model
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
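
# Usage sketch (illustrative): ``weights_init`` implements the DCGAN
# initialization scheme (normal(0, 0.02) for convolutions, normal(1, 0.02) for
# batch-norm scales) and is meant to be applied recursively to every sub-module
# of a freshly built network via ``Module.apply``, e.g.:
#
#   netG = ConditionalGAN_generator(noise_dim=100, conditional_dim=13)
#   netG.apply(weights_init)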
 
 
 
class ConditionalGAN_generator(nn.Module):
    """
    Class defining the Conditional GAN generator.

    **Parameters**

    noise_dim: int
      The dimension of the noise.

    conditional_dim: int
      The dimension of the conditioning variable.

    channels: int
      The number of channels in the input image (default: 3).

    ngpu: int
      The number of GPUs to use (default: 1).
    """
    def __init__(self, noise_dim, conditional_dim, channels=3, ngpu=1):
        super(ConditionalGAN_generator, self).__init__()
        self.ngpu = ngpu
        self.conditional_dim = conditional_dim

        # base number of feature maps in the generator
        ngf = 64

        self.main = nn.Sequential(
            # input is Z concatenated with the condition, going into a convolution
            nn.ConvTranspose2d((noise_dim + conditional_dim), ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (channels) x 64 x 64
        )

    def forward(self, z, y):
        """
        Forward function for the generator.

        **Parameters**

        z: pyTorch Variable
          The minibatch of noise.

        y: pyTorch Variable
          The conditional one hot encoded vector for the minibatch.
        """
        generator_input = torch.cat((z, y), 1)
        if isinstance(generator_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, generator_input, range(self.ngpu))
        else:
            output = self.main(generator_input)
        return output
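
# Note on expected shapes (illustrative): the generator's forward expects ``z``
# and ``y`` as 4-D inputs of shape (N, noise_dim, 1, 1) and
# (N, conditional_dim, 1, 1); they are concatenated along the channel dimension
# before the first transposed convolution. See the example at the end of the file.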
 
 
 
class ConditionalGAN_discriminator(nn.Module):
    """
    Class defining the Conditional GAN discriminator.

    **Parameters**

    conditional_dim: int
      The dimension of the conditioning variable.

    channels: int
      The number of channels in the input image (default: 3).

    ngpu: int
      The number of GPUs to use (default: 1).
    """
    def __init__(self, conditional_dim, channels=3, ngpu=1):
        super(ConditionalGAN_discriminator, self).__init__()
        self.conditional_dim = conditional_dim
        self.ngpu = ngpu

        # base number of feature maps in the discriminator
        ndf = 64

        self.main = nn.Sequential(
            # input is (channels + conditional_dim) x 64 x 64
            nn.Conv2d((channels + conditional_dim), ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, images, y):
        """
        Forward function for the discriminator.

        **Parameters**

        images: pyTorch Variable
          The minibatch of input images.

        y: pyTorch Variable
          The corresponding conditional feature maps.
        """
        input_discriminator = torch.cat((images, y), 1)
        if isinstance(input_discriminator.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input_discriminator, range(self.ngpu))
        else:
            output = self.main(input_discriminator)
        return output.view(-1, 1).squeeze(1)
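
if __name__ == '__main__':
    # Minimal smoke test, added as an illustrative sketch: all dimensions below
    # are arbitrary choices, not values fixed by this merge request. With the
    # older, Variable-based PyTorch API this file targets, the tensors may need
    # to be wrapped in torch.autograd.Variable before the forward calls.
    noise_dim = 100
    conditional_dim = 13
    batch_size = 4

    netG = ConditionalGAN_generator(noise_dim, conditional_dim)
    netD = ConditionalGAN_discriminator(conditional_dim)
    netG.apply(weights_init)
    netD.apply(weights_init)

    # noise and one-hot condition, shaped (N, dim, 1, 1) for the generator
    z = torch.randn(batch_size, noise_dim, 1, 1)
    y = torch.zeros(batch_size, conditional_dim, 1, 1)
    y[:, 0, 0, 0] = 1

    fake = netG(z, y)  # (N, 3, 64, 64) images in [-1, 1] (Tanh output)

    # the discriminator expects the condition tiled into 64 x 64 feature maps
    y_maps = y.expand(batch_size, conditional_dim, 64, 64)
    prob = netD(fake, y_maps)  # (N,) values in [0, 1] (Sigmoid output)

    print(fake.shape, prob.shape)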