# Skip to content
# Snippets Groups Projects
# ConditionalGAN.py 4.17 KiB
#!/usr/bin/env python
# encoding: utf-8


import torch
import torch.nn as nn

def weights_init(m):
  """
  Weights initialization

  Convolution-like modules (class name containing ``Conv``) get their
  weights drawn from N(0, 0.02). Batch-normalization modules (class name
  containing ``BatchNorm``) get their weights drawn from N(1, 0.02) and
  their biases set to zero. Any other module is left untouched, and so is
  any parameter that is absent (e.g. ``BatchNorm2d(..., affine=False)``,
  whose ``weight`` and ``bias`` are ``None``).

  The update is done in place; nothing is returned.

  **Parameters**

  m:
    The module to initialize (presumably each sub-module of a model, when
    this function is passed to ``model.apply`` -- confirm against callers).
  """
  classname = m.__class__.__name__
  # Parameters may be missing or None, so fetch them defensively
  weight = getattr(m, 'weight', None)
  bias = getattr(m, 'bias', None)

  if 'Conv' in classname:
    if weight is not None:
      weight.data.normal_(0.0, 0.02)
  elif 'BatchNorm' in classname:
    if weight is not None:
      weight.data.normal_(1.0, 0.02)
    if bias is not None:
      bias.data.fill_(0)


class ConditionalGAN_generator(nn.Module):
  """
  Class defining the Conditional GAN generator.

  The network is a stack of transposed convolutions. Fed with the
  channel-wise concatenation of a noise tensor and a conditioning tensor
  of spatial size 1 x 1, it produces an image of size
  ``channels`` x 64 x 64 with values in [-1, 1] (final ``Tanh``).

  **Parameters**

  noise_dim: int
    The dimension of the noise.

  conditional_dim: int
    The dimension of the conditioning variable.

  channels: int
    The number of channels in the generated image (default: 3).

  ngpu: int
    The number of GPU (default: 1)
  """
  def __init__(self, noise_dim, conditional_dim, channels=3, ngpu=1):

    super(ConditionalGAN_generator, self).__init__()
    self.ngpu = ngpu
    self.conditional_dim = conditional_dim

    # base number of feature maps; each layer uses a multiple of it
    ngf = 64

    self.main = nn.Sequential(
      # input is Z concatenated with the condition, going into a convolution
      nn.ConvTranspose2d((noise_dim + conditional_dim), ngf * 8, 4, 1, 0, bias=False),
      nn.BatchNorm2d(ngf * 8),
      nn.ReLU(True),
      # state size. (ngf*8) x 4 x 4
      nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
      nn.BatchNorm2d(ngf * 4),
      nn.ReLU(True),
      # state size. (ngf*4) x 8 x 8
      nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
      nn.BatchNorm2d(ngf * 2),
      nn.ReLU(True),
      # state size. (ngf*2) x 16 x 16
      nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
      nn.BatchNorm2d(ngf),
      nn.ReLU(True),
      # state size. (ngf) x 32 x 32
      nn.ConvTranspose2d(ngf, channels, 4, 2, 1, bias=False),
      nn.Tanh()
      # state size. (channels) x 64 x 64
    )

  def forward(self, z, y):
    """
      Forward function for the generator.

    **Parameters**

    z: pyTorch Variable
      The minibatch of noise.

    y: pyTorch Variable
      The conditional one hot encoded vector for the minibatch. It is
      concatenated to ``z`` along dimension 1, so every other dimension
      must match those of ``z``.

    **Returns**

    The minibatch of generated images.
    """
    generator_input = torch.cat((z, y), 1)
    # Spread the batch over several GPUs only when the data already lives
    # on a CUDA device (whatever its dtype) and more than one GPU was asked.
    if generator_input.is_cuda and self.ngpu > 1:
      return nn.parallel.data_parallel(self.main, generator_input, range(self.ngpu))
    return self.main(generator_input)


class ConditionalGAN_discriminator(nn.Module):
  """
  Class defining the Conditional GAN discriminator.

  The network is a stack of strided convolutions applied to the
  channel-wise concatenation of an image and its conditional feature
  maps. It ends with a ``Sigmoid``, so each output value lies in [0, 1].

  **Parameters**

  conditional_dim: int
    The dimension of the conditioning variable.

  channels: int
    The number of channels in the input image (default: 3).

  ngpu: int
    The number of GPU (default: 1)
  """
  def __init__(self, conditional_dim, channels=3, ngpu=1):

    super(ConditionalGAN_discriminator, self).__init__()
    self.conditional_dim = conditional_dim
    self.ngpu = ngpu

    # base number of feature maps; each layer uses a multiple of it
    ndf = 64

    self.main = nn.Sequential(
      # input is (channels + conditional_dim) x 64 x 64
      nn.Conv2d((channels + conditional_dim), ndf, 4, 2, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # state size. (ndf) x 32 x 32
      nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
      nn.BatchNorm2d(ndf * 2),
      nn.LeakyReLU(0.2, inplace=True),
      # state size. (ndf*2) x 16 x 16
      nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
      nn.BatchNorm2d(ndf * 4),
      nn.LeakyReLU(0.2, inplace=True),
      # state size. (ndf*4) x 8 x 8
      nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
      nn.BatchNorm2d(ndf * 8),
      nn.LeakyReLU(0.2, inplace=True),
      # state size. (ndf*8) x 4 x 4
      nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
      nn.Sigmoid()
      # state size. 1 x 1 x 1
    )


  def forward(self, images, y):
    """
      Forward function for the discriminator.

    **Parameters**

    images: pyTorch Variable
      The minibatch of input images.

    y: pyTorch Variable
      The corresponding conditional feature maps. They are concatenated
      to ``images`` along dimension 1, so every other dimension must
      match those of ``images``.

    **Returns**

    A one-dimensional tensor holding the discriminator outputs
    (one value per image for 64 x 64 inputs).
    """
    input_discriminator = torch.cat((images, y), 1)
    # Spread the batch over several GPUs only when the data already lives
    # on a CUDA device (whatever its dtype) and more than one GPU was asked.
    if input_discriminator.is_cuda and self.ngpu > 1:
      output = nn.parallel.data_parallel(self.main, input_discriminator, range(self.ngpu))
    else:
      output = self.main(input_discriminator)
    # flatten the (N, 1, 1, 1) scores into a vector of length N
    return output.view(-1, 1).squeeze(1)