Commit 8b4ac689 authored by Guillaume HEUSCH

[trainer] added the initial implementation of the improved Wasserstein trainer

parent 1ac5e8fc
#!/usr/bin/env python
# encoding: utf-8

import numpy
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch import autograd
from torch.autograd import Variable
import torchvision.utils as vutils

import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")


class IWCGAN(object):
    """
    Class to train a Conditional GAN using the Improved Wasserstein training method.

    **Parameters**

    generator: pytorch nn.Module
        The generator network.
    discriminator: pytorch nn.Module
        The discriminator network.
    image_size: list
        The size of the images in this format: [channels, height, width].
    batch_size: int
        The size of your minibatch.
    noise_dim: int
        The dimension of the noise (input to the generator).
    conditional_dim: int
        The dimension of the conditioning variable.
    n_critic_update: int
        The number of critic (discriminator) iterations per generator iteration.
    Lambda: int
        The regularization weight of the gradient penalty.
    use_gpu: boolean
        If you would like to use the GPU.
    verbosity_level: int
        The level of verbosity output to stdout.
    """
    def __init__(self, netG, netD, image_size, batch_size=64, noise_dim=100, conditional_dim=13,
                 n_critic_update=5, Lambda=10, use_gpu=False, verbosity_level=2):

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        self.netG = netG
        self.netD = netD
        self.image_size = image_size
        self.batch_size = batch_size
        self.noise_dim = noise_dim
        self.conditional_dim = conditional_dim
        self.n_critic_update = n_critic_update
        self.Lambda = Lambda
        self.use_gpu = use_gpu

        # fixed conditional noise - used to generate samples (one for each value of the conditional variable)
        self.fixed_noise = torch.FloatTensor(self.conditional_dim, noise_dim, 1, 1).normal_(0, 1)
        self.fixed_one_hot = torch.FloatTensor(self.conditional_dim, self.conditional_dim, 1, 1).zero_()
        for k in range(self.conditional_dim):
            self.fixed_one_hot[k, k] = 1

        # TODO: figuring out the CPU/GPU thing - Guillaume HEUSCH, 17-11-2017
        self.fixed_noise = Variable(self.fixed_noise)
        self.fixed_one_hot = Variable(self.fixed_one_hot)

        # binary cross-entropy loss
        self.criterion = nn.BCELoss()

        # move stuff to GPU if needed
        if self.use_gpu:
            self.netD.cuda()
            self.netG.cuda()
            self.criterion.cuda()

    def calc_gradient_penalty(self, real_data, fake_data, one_hot, batch_size):
        """
        Computes the gradient penalty term.

        **Parameters**

        real_data:
            The batch of real images.
        fake_data:
            The batch of generated (fake) images.
        one_hot:
            The batch of one-hot feature maps to append to the discriminator input.
        batch_size: int
            The size of the minibatch.
        """
        alpha = torch.rand(batch_size, 1)
        alpha = alpha.expand(batch_size, real_data.nelement() // batch_size).contiguous()
        alpha = alpha.view(batch_size, self.image_size[0], self.image_size[1], self.image_size[2])
        alpha = alpha.cuda() if self.use_gpu else alpha

        # random interpolation between a real and a fake sample
        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        if self.use_gpu:
            interpolates = interpolates.cuda()
        interpolates = autograd.Variable(interpolates, requires_grad=True)

        disc_interpolates = self.netD(interpolates, one_hot)

        # gradient of the critic output w.r.t. the interpolated samples
        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones(disc_interpolates.size()).cuda() if self.use_gpu else torch.ones(disc_interpolates.size()),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]

        # flatten so that the norm is computed over all pixels of each sample
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.Lambda
        return gradient_penalty

    def train(self, dataloader, n_iterations=100000, learning_rate=0.0001, beta1=0.5, output_dir='out'):
        """
        Function that performs the training.

        **Parameters**

        dataloader: pytorch DataLoader
            The dataloader for your data.
        n_iterations: int
            The number of iterations you would like to train for.
        learning_rate: float
            The learning rate for the Adam optimizer.
        beta1: float
            The beta1 for the Adam optimizer.
        output_dir: path
            The directory where you would like to output images and models.
        """
        # setup optimizers
        optimizerD = optim.Adam(self.netD.parameters(), lr=learning_rate, betas=(beta1, 0.9))
        optimizerG = optim.Adam(self.netG.parameters(), lr=learning_rate, betas=(beta1, 0.9))

        one = torch.FloatTensor([1])
        mone = one * -1
        if self.use_gpu:
            one = one.cuda()
            mone = mone.cuda()
        # the dataloader is consumed through an explicit iterator, re-created when exhausted
        data_iter = iter(dataloader)

        # let's go
        for iteration in range(n_iterations):

            start = time.time()

            # =============
            # DISCRIMINATOR
            # =============
            for p in self.netD.parameters():
                p.requires_grad = True
            for k in range(self.n_critic_update):

                self.netD.zero_grad()

                # get the data and pose labels (re-create the iterator if the epoch is over)
                try:
                    data = next(data_iter)
                except StopIteration:
                    data_iter = iter(dataloader)
                    data = next(data_iter)
                real_images = data['image']
                poses = data['pose']

                # WARNING: the last batch could be smaller than the provided size
                batch_size = len(real_images)

                # create the Tensors with the right batch size
                noise = torch.FloatTensor(batch_size, self.noise_dim, 1, 1).normal_(0, 1)

                # create the one-hot conditional vector (generator) and feature maps (discriminator)
                one_hot_feature_maps = torch.FloatTensor(batch_size, self.conditional_dim, self.image_size[1], self.image_size[2]).zero_()
                one_hot_vector = torch.FloatTensor(batch_size, self.conditional_dim, 1, 1).zero_()
                for k in range(batch_size):
                    one_hot_feature_maps[k, poses[k], :, :] = 1
                    one_hot_vector[k, poses[k]] = 1

                # move stuff to GPU if needed
                if self.use_gpu:
                    real_images = real_images.cuda()
                    noise = noise.cuda()
                    one_hot_feature_maps = one_hot_feature_maps.cuda()
                    one_hot_vector = one_hot_vector.cuda()

                # === REAL DATA ===
                imagev = Variable(real_images)
                one_hot_fmv = Variable(one_hot_feature_maps)
                output_real = self.netD(imagev, one_hot_fmv)
                output_real = output_real.mean()
                output_real.backward(mone)

                # === FAKE DATA ===
                noisev = Variable(noise, volatile=True)
                one_hot_vv = Variable(one_hot_vector)
                fake = Variable(self.netG(noisev, one_hot_vv).data)
                input_fakev = fake
                output_fake = self.netD(input_fakev, one_hot_fmv)
                output_fake = output_fake.mean()
                output_fake.backward(one)

                # === GRADIENT PENALTY ===
                gradient_penalty = self.calc_gradient_penalty(imagev.data, fake.data, one_hot_fmv, batch_size)
                gradient_penalty.backward()

                # discriminator cost (for logging)
                D_cost = output_fake - output_real + gradient_penalty

                optimizerD.step()
            # =========
            # GENERATOR
            # =========
            for p in self.netD.parameters():
                p.requires_grad = False

            self.netG.zero_grad()

            # get new pose labels for the generator update (re-create the iterator if needed)
            try:
                data = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                data = next(data_iter)
            poses = data['pose']
            batch_size = len(poses)

            noise = torch.FloatTensor(batch_size, self.noise_dim, 1, 1).normal_(0, 1)
            one_hot_feature_maps = torch.FloatTensor(batch_size, self.conditional_dim, self.image_size[1], self.image_size[2]).zero_()
            one_hot_vector = torch.FloatTensor(batch_size, self.conditional_dim, 1, 1).zero_()
            for k in range(batch_size):
                one_hot_feature_maps[k, poses[k], :, :] = 1
                one_hot_vector[k, poses[k]] = 1

            if self.use_gpu:
                noise = noise.cuda()
                one_hot_feature_maps = one_hot_feature_maps.cuda()
                one_hot_vector = one_hot_vector.cuda()

            noisev = Variable(noise)
            one_hot_fmv = Variable(one_hot_feature_maps)
            one_hot_vv = Variable(one_hot_vector)

            # generate fake images conditioned on the pose, and maximize the critic's score
            fake = self.netG(noisev, one_hot_vv)
            G = self.netD(fake, one_hot_fmv)
            G = G.mean()
            G.backward(mone)
            G_cost = -G
            optimizerG.step()

            end = time.time()
            logger.info("[{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(iteration, n_iterations, D_cost.data, G_cost.data, (end - start)))
            # save generated samples every 100 iterations
            if iteration % 100 == 99:
                fake_examples = self.netG(self.fixed_noise, self.fixed_one_hot)
                vutils.save_image(fake_examples.data, '%s/fake_samples_iteration_%03d.png' % (output_dir, iteration), normalize=True)

            # save models every 1000 iterations
            if iteration % 1000 == 999:
                torch.save(self.netG.state_dict(), '%s/netG_iteration_%d.pth' % (output_dir, iteration))
                torch.save(self.netD.state_dict(), '%s/netD_iteration_%d.pth' % (output_dir, iteration))
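
# -----------------------------------------------------------------------------
# Hedged usage sketch (illustration only, not part of the original commit): how
# this trainer is expected to be driven. `MyGenerator`, `MyDiscriminator` and
# `my_dataset` are hypothetical placeholders for the conditional networks and
# dataset you provide; the DataLoader must yield dicts with 'image' and 'pose'
# keys, as assumed by train() above.
#
#   from torch.utils.data import DataLoader
#
#   netG = MyGenerator(noise_dim=100, conditional_dim=13)
#   netD = MyDiscriminator(conditional_dim=13)
#   dataloader = DataLoader(my_dataset, batch_size=64, shuffle=True)
#
#   trainer = IWCGAN(netG, netD, image_size=[3, 64, 64], batch_size=64,
#                    noise_dim=100, conditional_dim=13, n_critic_update=5,
#                    Lambda=10, use_gpu=True)
#   trainer.train(dataloader, n_iterations=100000, learning_rate=0.0001,
#                 beta1=0.5, output_dir='out')
# -----------------------------------------------------------------------------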