Skip to content
Snippets Groups Projects
Commit 5bf3fb9f authored by Guillaume HEUSCH
Browse files

[trainers] added the DR-GAN trainer (WIP)

parent a2026c5f
Branches
Tags
No related merge requests found
#!/usr/bin/env python
# encoding: utf-8
import numpy
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
class DRGANTrainer(object):
    """
    Class to train a DR-GAN

    The generator is the composition of an encoder (image -> identity code)
    and a decoder (identity code + noise + pose code -> image). The
    discriminator output is read column-wise by :py:meth:`train` as
    ``[identity scores | pose scores | real/fake probability]``.

    **Parameters**

      encoder: pytorch nn.Module
        The encoder network

      decoder: pytorch nn.Module
        The decoder network

      discriminator: pytorch nn.Module
        The discriminator network. It must expose a ``number_of_ids``
        attribute (read in :py:meth:`train`).

      image_size: list
        The size of the images in this format: [channels,height, width]

      batch_size: int
        The size of your minibatch

      noise_dim: int
        The dimension of the noise (input to the generator)

      conditional_dim: int
        The dimension of the conditioning variable

      latent_dim: int
        The dimension of the encoded ID

      use_gpu: boolean
        If you would like to use the gpu

      verbosity_level: int
        The level of verbosity output to stdout

    """

    def __init__(self, encoder, decoder, discriminator, image_size, batch_size=64, noise_dim=100, conditional_dim=13, latent_dim=320, use_gpu=False, verbosity_level=2):

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator
        self.image_size = image_size
        self.batch_size = batch_size
        self.noise_dim = noise_dim
        self.conditional_dim = conditional_dim
        self.latent_dim = latent_dim
        self.use_gpu = use_gpu

        # fixed conditional noise - used to generate samples (one for each
        # value of the conditional variable). Both are kept on the CPU here;
        # train() moves copies to the GPU when it needs them.
        fixed_noise = torch.FloatTensor(self.conditional_dim, noise_dim, 1, 1).normal_(0, 1)
        all_poses = torch.LongTensor(list(range(self.conditional_dim)))
        self.fixed_noise = Variable(fixed_noise)
        self.fixed_one_hot = Variable(self._one_hot(all_poses, self.conditional_dim))

        # binary cross-entropy loss for real/fake, cross-entropy for pose and id
        self.criterion_gan = nn.BCELoss()
        self.criterion_pose = nn.CrossEntropyLoss()  # index is expected as target (and not one-hot)
        self.criterion_id = nn.CrossEntropyLoss()

        # move stuff to GPU if needed
        if self.use_gpu:
            self.encoder.cuda()
            self.decoder.cuda()
            self.discriminator.cuda()
            self.criterion_gan.cuda()
            self.criterion_pose.cuda()
            self.criterion_id.cuda()

    @staticmethod
    def _one_hot(indices, n_classes):
        """
        Builds a one-hot float tensor of shape (len(indices), n_classes, 1, 1).

        **Parameters**

          indices: torch tensor
            1D tensor of class indices (converted to long)

          n_classes: int
            The number of classes (size of the second dimension)

        """
        indices = indices.long().view(-1, 1, 1, 1)
        encoded = torch.FloatTensor(indices.size(0), n_classes, 1, 1).zero_()
        return encoded.scatter_(1, indices, 1)

    @staticmethod
    def _to_float(loss):
        """
        Extracts the python number held by a single-element loss.

        Uses ``item()`` when the tensor provides it and falls back to
        indexing otherwise, so it does not depend on the torch version.
        """
        data = loss.data
        if hasattr(data, 'item'):
            return data.item()
        return data[0]

    def _generate(self, images, noise, one_hot):
        """
        Runs the generator: encodes the images and decodes them again.

        NOTE(review): the call conventions ``encoder(images)`` and
        ``decoder(encoded_id, noise, one_hot)`` are assumptions (the encoder
        and decoder architectures are not defined in this file) - confirm
        against the actual networks. This is the only place to adapt.
        """
        encoded_id = self.encoder(images)
        return self.decoder(encoded_id, noise, one_hot)

    def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out'):
        """
        Function that performs the training.

        For each minibatch, the discriminator is updated on real images
        (identity, pose and real/fake losses) and on generated images
        (real/fake loss), then the generator (encoder + decoder) is updated.
        At the end of each epoch, samples generated from the first example
        of the dataset (one per pose) and the models are saved in output_dir.

        **Parameters**

          dataloader: pytorch DataLoader
            The dataloader for your data. Each example is accessed through
            the keys 'image', 'pose' and 'id'.

          n_epochs: int
            The number of epochs you would like to train for

          learning_rate: float
            The learning rate for Adam optimizer

          beta1: float
            The beta1 for Adam optimizer

          output_dir: path
            The directory where you would like to output images and models

        """
        real_label = 1
        fake_label = 0

        # setup optimizer
        generator_params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        optimizerD = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
        optimizerG = optim.Adam(generator_params, lr=learning_rate, betas=(beta1, 0.999))

        # get a fixed example for sampling (fetched once, so that image, id
        # and pose are guaranteed to come from the same access)
        fixed_example = dataloader.dataset[0]
        logger.info("Fixed example for sampling: ID -> {}, pose {}".format(fixed_example['id'], fixed_example['pose']))

        # the same image is repeated once for each value of the pose
        fixed_images = torch.stack([fixed_example['image']] * self.conditional_dim)
        sample_noise = self.fixed_noise
        sample_one_hot = self.fixed_one_hot
        if self.use_gpu:
            fixed_images = fixed_images.cuda()
            sample_noise = sample_noise.cuda()
            sample_one_hot = sample_one_hot.cuda()
        fixed_images_v = Variable(fixed_images)

        number_of_ids = self.discriminator.number_of_ids
        pose_end = number_of_ids + self.conditional_dim

        for epoch in range(n_epochs):
            for i, data in enumerate(dataloader, 0):

                start = time.time()

                # get the data, pose and id labels
                real_images = data['image']
                poses = data['pose']
                ids = data['id']

                # WARNING: the last batch could be smaller than the provided size
                batch_size = len(real_images)

                # create the Tensors with the right batch size
                noise = torch.FloatTensor(batch_size, self.noise_dim, 1, 1).normal_(0, 1)
                real_targets = torch.FloatTensor(batch_size).fill_(real_label)
                fake_targets = torch.FloatTensor(batch_size).fill_(fake_label)

                # create the one hot conditional vector on pose (decoder)
                one_hot_vector = self._one_hot(poses, self.conditional_dim)

                # move stuff to GPU if needed
                if self.use_gpu:
                    real_images = real_images.cuda()
                    real_targets = real_targets.cuda()
                    fake_targets = fake_targets.cuda()
                    poses = poses.cuda()
                    ids = ids.cuda()
                    noise = noise.cuda()
                    one_hot_vector = one_hot_vector.cuda()

                imagev = Variable(real_images)
                real_targets_v = Variable(real_targets)
                fake_targets_v = Variable(fake_targets)
                label_pose_v = Variable(poses)
                label_id_v = Variable(ids)
                noisev = Variable(noise)
                one_hot_vv = Variable(one_hot_vector)

                # =============
                # DISCRIMINATOR
                # =============
                self.discriminator.zero_grad()

                # === REAL DATA ===
                # NOTE(review): the last output goes through BCELoss, so it is
                # presumably already a probability (sigmoid) - confirm.
                output_real = self.discriminator(imagev)
                errD_id = self.criterion_id(output_real[:, :number_of_ids], label_id_v)
                errD_pose = self.criterion_pose(output_real[:, number_of_ids:pose_end], label_pose_v)
                errD_gan = self.criterion_gan(output_real[:, -1], real_targets_v)
                logger.debug("errD_id = {} -- errD_pose = {} -- errD_gan = {}".format(self._to_float(errD_id), self._to_float(errD_pose), self._to_float(errD_gan)))
                errD_real = errD_id + errD_pose + errD_gan
                errD_real.backward()

                # === FAKE DATA ===
                fake = self._generate(imagev, noisev, one_hot_vv)
                # detach: the generator must not receive gradients here
                output_fake = self.discriminator(fake.detach())
                errD_fake = self.criterion_gan(output_fake[:, -1], fake_targets_v)
                errD_fake.backward()

                # perform optimization (i.e. update discriminator parameters)
                errD = errD_real + errD_fake
                optimizerD.step()

                # =========
                # GENERATOR
                # =========
                self.encoder.zero_grad()
                self.decoder.zero_grad()
                output_generated = self.discriminator(fake)
                # fake labels are real for generator cost
                errG_gan = self.criterion_gan(output_generated[:, -1], real_targets_v)
                # NOTE(review): the id and pose terms follow the DR-GAN
                # formulation (generated image should keep the input identity
                # and show the conditioned pose) - confirm this is intended.
                errG_id = self.criterion_id(output_generated[:, :number_of_ids], label_id_v)
                errG_pose = self.criterion_pose(output_generated[:, number_of_ids:pose_end], label_pose_v)
                errG = errG_gan + errG_id + errG_pose
                errG.backward()
                optimizerG.step()

                end = time.time()
                logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), self._to_float(errD), self._to_float(errG), (end - start)))

            # save generated images at every epoch
            fake_examples = self._generate(fixed_images_v, sample_noise, sample_one_hot)
            vutils.save_image(fake_examples.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True)

            # do checkpointing
            torch.save(self.encoder.state_dict(), '%s/encoder_epoch_%d.pth' % (output_dir, epoch))
            torch.save(self.decoder.state_dict(), '%s/decoder_epoch_%d.pth' % (output_dir, epoch))
            torch.save(self.discriminator.state_dict(), '%s/discriminator_epoch_%d.pth' % (output_dir, epoch))
from .DCGANTrainer import DCGANTrainer
from .ConditionalGANTrainer import ConditionalGANTrainer
from .ImprovedWassersteinCGANTrainer import IWCGANTrainer
from .DRGANTrainer import DRGANTrainer

# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment