Skip to content
Snippets Groups Projects
Commit 6c6e2fb1 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[architectures, scripts] added architecture for Wasserstein with GP and script to train

parent 8b4ac689
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
def weights_init(m):
"""
Weights initialization
**Parameters**
m:
The model
"""
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class WCGAN_generator(nn.Module):
"""
Class defining the Conditional GAN generator for the Improved Wasserstein training.
**Parameters**
noise_dim: int
The dimension of the noise.
conditional_dim: int
The dimension of the conditioning variable.
channels: int
The number of channels in the input image (default: 3).
ngpu: int
The number of GPU (default: 1)
"""
def __init__(self, noise_dim, conditional_dim, channels=3, ngpu=1):
super(WCGAN_generator, self).__init__()
self.ngpu = ngpu
self.conditional_dim = conditional_dim
# output dimension
ngf = 64
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d((noise_dim + conditional_dim), ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d(ngf, channels, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
)
def forward(self, z, y):
"""
Forward function for the generator.
**Parameters**
z: pyTorch Variable
The minibatch of noise.
y: pyTorch Variable
The conditional one hot encoded vector for the minibatch.
"""
generator_input = torch.cat((z, y), 1)
if isinstance(generator_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, generator_input, range(self.ngpu))
else:
output = self.main(generator_input)
return output
class WCGAN_discriminator(nn.Module):
"""
Class defining the Conditional GAN discriminator for the Improved Wasserstein training.
(no batch normalization)
**Parameters**
conditional_dim: int
The dimension of the conditioning variable.
channels: int
The number of channels in the input image (default: 3).
ngpu: int
The number of GPU (default: 1)
"""
def __init__(self, conditional_dim, channels=3, ngpu=1):
super(WCGAN_discriminator, self).__init__()
self.conditional_dim = conditional_dim
self.ngpu = ngpu
# input dimension
ndf = 64
self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d((channels + conditional_dim), ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, images, y):
"""
Forward function for the discriminator.
**Parameters**
images: pyTorch Variable
The minibatch of input images.
y: pyTorch Variable
The corresponding conditional feature maps.
"""
input_discriminator = torch.cat((images, y), 1)
if isinstance(input_discriminator.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input_discriminator, range(self.ngpu))
else:
output = self.main(input_discriminator)
return output.view(-1, 1).squeeze(1)
...@@ -4,6 +4,9 @@ from .DCGAN import DCGAN_discriminator ...@@ -4,6 +4,9 @@ from .DCGAN import DCGAN_discriminator
from .ConditionalGAN import ConditionalGAN_generator from .ConditionalGAN import ConditionalGAN_generator
from .ConditionalGAN import ConditionalGAN_discriminator from .ConditionalGAN import ConditionalGAN_discriminator
from .WassersteinCGAN import WCGAN_generator
from .WassersteinCGAN import WCGAN_discriminator
from .DCGAN import weights_init from .DCGAN import weights_init
# gets sphinx autodoc done right - don't remove it # gets sphinx autodoc done right - don't remove it
......
#!/usr/bin/env python
# encoding: utf-8
""" Train a Conditional GAN using Improved Wasserstein Training.
Usage:
%(prog)s [--noise-dim=<int>] [--conditional-dim=<int>]
[--batch-size=<int>] [--iterations=<int>]
[--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...]
Options:
-h, --help Show this screen.
-V, --version Show version.
-n, --noise-dim=<int> The dimension of the noise [default: 100]
-c, --conditional-dim=<int> The dimension of the conditional variable [default: 13]
-b, --batch-size=<int> The size of your mini-batch [default: 64]
-e, --iterations=<int> The number of training iterations [default: 100000]
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./drgan-light-mpie-casia/]
-g, --use-gpu Use the GPU
-S, --seed=<int> The random seed [default: 3]
-v, --verbose Increase the verbosity (may appear multiple times).
Example:
To run the training process
$ %(prog)s --batch-size 64 --epochs 25 --output-dir drgan
See '%(prog)s --help' for more information.
"""
import os, sys
import pkg_resources
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
from docopt import docopt
version = pkg_resources.require('bob.learn.pytorch')[0].version
import numpy
import bob.io.base
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets import MultiPIEDataset
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.architectures import weights_init
from bob.learn.pytorch.architectures import WCGAN_generator as cgenerator
from bob.learn.pytorch.architectures import WCGAN_discriminator as cdiscriminator
from bob.learn.pytorch.trainers import IWCGANTrainer as iwtrainer
def main(user_input=None):
# Parse the command-line arguments
if user_input is not None:
arguments = user_input
else:
arguments = sys.argv[1:]
prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train conditional GAN (%s)' % version,)
# verbosity
verbosity_level = args['--verbose']
bob.core.log.set_verbosity_level(logger, verbosity_level)
# get the arguments
noise_dim = int(args['--noise-dim'])
conditional_dim = int(args['--conditional-dim'])
batch_size = int(args['--batch-size'])
iterations = int(args['--iterations'])
output_dir = str(args['--output-dir'])
seed = int(args['--seed'])
use_gpu = bool(args['--use-gpu'])
images_dir = os.path.join(output_dir, 'samples')
log_dir = os.path.join(output_dir, 'logs')
model_dir = os.path.join(output_dir, 'models')
# process on the arguments / options
torch.manual_seed(seed)
if use_gpu:
torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available() and not use_gpu:
logger.warn("You have a CUDA device, so you should probably run with --use-gpu")
bob.io.base.create_directories_safe(images_dir)
bob.io.base.create_directories_safe(log_dir)
bob.io.base.create_directories_safe(images_dir)
# ============
# === DATA ===
# ============
# WARNING with the transforms ... act on labels too, at some point, I may have to write my own
# Also, in 'ToTensor', there is a reshape performed from: HxWxC to CxHxW
face_dataset = MultiPIEDataset(#root_dir='/Users/guillaumeheusch/work/idiap/data/multipie-cropped-64x64',
root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64',
frontal_only=False,
transform=transforms.Compose([
RollChannels(), # bob to skimage:
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
)
dataloader = torch.utils.data.DataLoader(face_dataset, batch_size=batch_size, shuffle=True)
logger.info("There are {} training images".format(len(face_dataset)))
# ===============
# === NETWORK ===
# ===============
ngpu = 1 # usually we don't have more than one GPU
generator = cgenerator(noise_dim, conditional_dim)
generator.apply(weights_init)
logger.info("Generator architecture: {}".format(generator))
discriminator = cdiscriminator(conditional_dim)
discriminator.apply(weights_init)
logger.info("Discriminator architecture: {}".format(discriminator))
# ===============
# === TRAINER ===
# ===============
trainer = iwtrainer(generator, discriminator, [3, 64, 64], batch_size=batch_size, noise_dim=noise_dim, conditional_dim=conditional_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_iterations=iterations, output_dir=output_dir)
from .DCGANTrainer import DCGANTrainer from .DCGANTrainer import DCGANTrainer
from .ConditionalGANTrainer import ConditionalGANTrainer from .ConditionalGANTrainer import ConditionalGANTrainer
from .ImprovedWassersteinCGANTrainer import IWCGANTrainer
# gets sphinx autodoc done right - don't remove it # gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')] __all__ = [_ for _ in dir() if not _.startswith('_')]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment