Commit c1d31df1 authored by Guillaume HEUSCH

[architectures, script] added the architecture of the DRGAN, and the script to train it (trainer not done yet though)
parent 1b9684dc
#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
class DRGAN_encoder(nn.Module):
    """
    Class defining the encoder in the DR-GAN architecture

    **Parameters**

    image_size: tuple
        The dimension of the image (CxHxW)

    latent_dim: int
        The dimension of the encoded ID
    """
    def __init__(self, image_size, latent_dim):
        # Conv2d(in_channels, out_channels (i.e. number of feature maps), kernel size, stride, padding)
        super(DRGAN_encoder, self).__init__()
        self.ngpu = 1
        n_fm_first = 64  # number of feature maps at the first layer
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64, output is (n_fm_first) x 32 x 32
            nn.Conv2d(image_size[0], n_fm_first, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_fm_first) x 32 x 32
            nn.Conv2d(n_fm_first, n_fm_first * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_fm_first * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_fm_first*2) x 16 x 16
            nn.Conv2d(n_fm_first * 2, n_fm_first * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_fm_first * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_fm_first*4) x 8 x 8
            nn.Conv2d(n_fm_first * 4, n_fm_first * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_fm_first * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_fm_first*8) x 4 x 4
            # this last convolution collapses the 4x4 feature maps to a
            # (latent_dim) x 1 x 1 encoding (it plays the role of DR-GAN's
            # average pooling over the spatial dimensions)
            nn.Conv2d(n_fm_first * 8, latent_dim, 4, 1, 0, bias=False)
        )
    def forward(self, x):
        """
        Forward function for the encoder.

        **Parameters**

        x: pyTorch Variable
            The minibatch of images to encode.
        """
        if isinstance(x.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, x, range(self.ngpu))
        else:
            output = self.main(x)
        return output
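
# A minimal usage sketch (an illustration added here, not part of the original
# file): with 3x64x64 inputs and latent_dim=320, the encoder maps a minibatch
# to its identity representation of shape (batch, latent_dim, 1, 1):
#
#   encoder = DRGAN_encoder((3, 64, 64), 320)
#   x = torch.autograd.Variable(torch.randn(16, 3, 64, 64))
#   f = encoder(x)  # f.size() == (16, 320, 1, 1)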
class DRGAN_decoder(nn.Module):
    """
    Class defining the decoder in the DR-GAN architecture

    **Parameters**

    image_size: tuple
        The dimension of the image (CxHxW)

    noise_dim: int
        The dimension of the noise

    latent_dim: int
        The dimension of the encoded ID

    conditional_dim: int
        The dimension of the conditioning variable
    """
    def __init__(self, image_size, noise_dim, latent_dim, conditional_dim):
        super(DRGAN_decoder, self).__init__()
        self.ngpu = 1  # usually, we don't have more than one GPU
        n_fm_first = 64  # number of feature maps at the first layer
        self.main = nn.Sequential(
            # input is Z+ID+C, going into a convolution
            nn.ConvTranspose2d((noise_dim + latent_dim + conditional_dim), n_fm_first * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(n_fm_first * 8),
            nn.ReLU(True),
            # state size. (n_fm_first*8) x 4 x 4
            nn.ConvTranspose2d(n_fm_first * 8, n_fm_first * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_fm_first * 4),
            nn.ReLU(True),
            # state size. (n_fm_first*4) x 8 x 8
            nn.ConvTranspose2d(n_fm_first * 4, n_fm_first * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_fm_first * 2),
            nn.ReLU(True),
            # state size. (n_fm_first*2) x 16 x 16
            nn.ConvTranspose2d(n_fm_first * 2, n_fm_first, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_fm_first),
            nn.ReLU(True),
            # state size. (n_fm_first) x 32 x 32
            nn.ConvTranspose2d(n_fm_first, image_size[0], 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
    def forward(self, z, y, f):
        """
        Forward function for the decoder.

        **Parameters**

        z: pyTorch Variable
            The minibatch of noise.

        y: pyTorch Variable
            The conditional one hot encoded vector for the minibatch.

        f: pyTorch Variable
            The encoded ID for the minibatch.
        """
        decoder_input = torch.cat((z, y, f), 1)
        if isinstance(decoder_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, decoder_input, range(self.ngpu))
        else:
            output = self.main(decoder_input)
        return output
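
# A minimal usage sketch (an assumption for illustration, not part of the
# original file): the first ConvTranspose2d expects 4-d input, so the noise z,
# the one-hot condition y and the encoded ID f must all have shape
# (batch, dim, 1, 1) before they are concatenated along dimension 1:
#
#   decoder = DRGAN_decoder((3, 64, 64), 100, 320, 13)
#   z = torch.autograd.Variable(torch.randn(16, 100, 1, 1))
#   y = torch.autograd.Variable(torch.zeros(16, 13, 1, 1))  # hypothetical pose code
#   x_hat = decoder(z, y, f)  # (16, 3, 64, 64), in [-1, 1] because of the Tanh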
class DRGAN_discriminator(nn.Module):
    """
    Class defining the discriminator in the DR-GAN architecture

    **Parameters**

    image_size: tuple
        The dimension of the image (CxHxW)

    number_of_ids: int
        The number of identities in the DB

    conditional_dim: int
        The dimension of the conditioning variable
    """
    def __init__(self, image_size, number_of_ids, conditional_dim):
        super(DRGAN_discriminator, self).__init__()
        self.ngpu = 1
        n_fm_first = 64
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(image_size[0], n_fm_first, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_fm_first) x 32 x 32
            nn.Conv2d(n_fm_first, n_fm_first * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_fm_first * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_fm_first*2) x 16 x 16
            nn.Conv2d(n_fm_first * 2, n_fm_first * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_fm_first * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_fm_first*4) x 8 x 8
            nn.Conv2d(n_fm_first * 4, n_fm_first * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_fm_first * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_fm_first*8) x 4 x 4
            nn.Conv2d(n_fm_first * 8, (number_of_ids + conditional_dim + 1), 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    def forward(self, x):
        """
        Forward function for the discriminator.

        **Parameters**

        x: pyTorch Variable
            The minibatch of images to process.
        """
        if isinstance(x.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, x, range(self.ngpu))
        else:
            output = self.main(x)
        return output
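
# A minimal sketch (an assumption based on the layer sizes above, not confirmed
# by this file): the final convolution outputs (number_of_ids + conditional_dim + 1)
# units per image, which DR-GAN interprets as identity scores, pose scores and a
# real/fake score; a caller could split them along these lines:
#
#   out = discriminator(x).view(x.size(0), -1)
#   id_scores = out[:, :number_of_ids]
#   pose_scores = out[:, number_of_ids:number_of_ids + conditional_dim]
#   real_fake = out[:, -1]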
@@ -7,6 +7,10 @@ from .ConditionalGAN import ConditionalGAN_discriminator
from .WassersteinCGAN import WCGAN_generator
from .WassersteinCGAN import WCGAN_discriminator
from .DRGAN import DRGAN_encoder
from .DRGAN import DRGAN_decoder
from .DRGAN import DRGAN_discriminator
from .DCGAN import weights_init
# gets sphinx autodoc done right - don't remove it
#!/usr/bin/env python
# encoding: utf-8
""" Train a DR-GAN
Usage:
%(prog)s [--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>]
[--batch-size=<int>] [--epochs=<int>] [--sample=<int>]
[--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...]
Options:
-h, --help Show this screen.
-V, --version Show version.
-l, --latent-dim=<int> the dimension of the encoded ID [default: 320]
-n, --noise-dim=<int> the dimension of the noise [default: 100]
-c, --conditional-dim=<int> the dimension of the noise [default: 100]
-b, --batch-size=<int> The size of your mini-batch [default: 64]
-e, --epochs=<int> The number of training epochs [default: 100]
-s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 100000000000]
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./drgan-light-mpie-casia/]
-g, --use-gpu Use the GPU
-S, --seed=<int> The random seed [default: 3]
-v, --verbose Increase the verbosity (may appear multiple times).
Example:
To run the training process
$ %(prog)s --batch-size 64 --epochs 25 --output-dir drgan
See '%(prog)s --help' for more information.
"""
import os, sys
import pkg_resources
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
from docopt import docopt
version = pkg_resources.require('bob.learn.pytorch')[0].version
import numpy
import bob.io.base
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets import MultiPIEDataset
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.architectures import weights_init
from bob.learn.pytorch.architectures import DRGAN_encoder
from bob.learn.pytorch.architectures import DRGAN_decoder
from bob.learn.pytorch.architectures import DRGAN_discriminator
def main(user_input=None):

    # parse the command-line arguments
    if user_input is not None:
        arguments = user_input
    else:
        arguments = sys.argv[1:]

    prog = os.path.basename(sys.argv[0])
    completions = dict(prog=prog, version=version,)
    args = docopt(__doc__ % completions, argv=arguments, version='Train DR-GAN (%s)' % version,)
    # verbosity
    verbosity_level = args['--verbose']
    bob.core.log.set_verbosity_level(logger, verbosity_level)

    # get the arguments
    noise_dim = int(args['--noise-dim'])
    latent_dim = int(args['--latent-dim'])
    conditional_dim = int(args['--conditional-dim'])
    batch_size = int(args['--batch-size'])
    epochs = int(args['--epochs'])
    sample = int(args['--sample'])
    output_dir = str(args['--output-dir'])
    seed = int(args['--seed'])
    use_gpu = bool(args['--use-gpu'])

    images_dir = os.path.join(output_dir, 'samples')
    log_dir = os.path.join(output_dir, 'logs')
    model_dir = os.path.join(output_dir, 'models')

    # process the arguments / options
    torch.manual_seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)
    if torch.cuda.is_available() and not use_gpu:
        logger.warn("You have a CUDA device, so you should probably run with --use-gpu")

    bob.io.base.create_directories_safe(images_dir)
    bob.io.base.create_directories_safe(log_dir)
    bob.io.base.create_directories_safe(model_dir)
    # ============
    # === DATA ===
    # ============

    # WARNING: at some point the transforms may have to act on the labels too,
    # in which case I may have to write my own
    # Also, 'ToTensor' performs a reshape from HxWxC to CxHxW
    face_dataset = MultiPIEDataset(root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64',
                                   frontal_only=True,
                                   transform=transforms.Compose([
                                       RollChannels(),  # bob to skimage: CxHxW -> HxWxC
                                       ToTensor(),
                                       Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                   ])
                                   )
    dataloader = torch.utils.data.DataLoader(face_dataset, batch_size=batch_size, shuffle=True)
    logger.info("There are {} training images".format(len(face_dataset)))

    # get the image size
    image_size = face_dataset[0]['image'].numpy().shape

    # get the number of ids
    number_of_ids = numpy.max(face_dataset.id_labels)
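
    # Sketch of the assumed data layout (an assumption, not asserted by the
    # original code): after RollChannels/ToTensor, each sample['image'] is a
    # CxHxW FloatTensor, e.g. image_size == (3, 64, 64), and id_labels are
    # assumed to run from 1 to number_of_ids (hence the numpy.max above).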
    # ===============
    # === NETWORK ===
    # ===============
    ngpu = 1  # usually we don't have more than one GPU

    encoder = DRGAN_encoder(image_size, latent_dim)
    encoder.apply(weights_init)
    logger.info("Encoder architecture: {}".format(encoder))

    decoder = DRGAN_decoder(image_size, noise_dim, latent_dim, conditional_dim)
    decoder.apply(weights_init)
    logger.info("Generator architecture: {}".format(decoder))

    discriminator = DRGAN_discriminator(image_size, number_of_ids, conditional_dim)
    discriminator.apply(weights_init)
    logger.info("Discriminator architecture: {}".format(discriminator))

    # ===============
    # === TRAINER ===
    # ===============
    #trainer = DRGANTrainer(encoder, decoder, discriminator, batch_size=batch_size, latent_dim=latent_dim, noise_dim=noise_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
    #trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
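
    # A rough sketch of what one training iteration could look like once the
    # trainer exists (an assumption, not the author's implementation):
    #
    #   f = encoder(images)
    #   fake = decoder(noise, one_hot_pose, f)
    #   out_real = discriminator(images)
    #   out_fake = discriminator(fake.detach())
    #   ... compute the identity / pose / real-vs-fake losses and step the
    #   discriminator and encoder-decoder optimizers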
@@ -77,6 +77,7 @@ setup(
      'train_conditionalgan_multipie.py = bob.learn.pytorch.scripts.train_conditionalgan_multipie:main',
      'train_conditionalgan_casia.py = bob.learn.pytorch.scripts.train_conditionalgan_casia:main',
      'train_wcgan_multipie.py = bob.learn.pytorch.scripts.train_wcgan_multipie:main',
      'train_drgan_multipie.py = bob.learn.pytorch.scripts.train_drgan_multipie:main',
    ],