Commit 00bc5356 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[trainers] fixes stuff in the trainer for WGAN-GP (iterator for data, learning rate of the critic, ...)
parent 86e2888b
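The main functional change is how training batches are drawn: rather than calling .next() on the DataLoader itself, the trainer now keeps an explicit iterator over it, counts the batches it consumes, and rebuilds the iterator before it is exhausted. A minimal sketch of that pattern, with illustrative names (dataset, the batch sizes, the iteration count) that are not part of this trainer's API:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(100, 3, 64, 64))   # stand-in dataset
    dataloader = DataLoader(dataset, batch_size=10)

    data_iterator = iter(dataloader)
    n_consumed_batches = 0
    for iteration in range(1000):
        batch = next(data_iterator)
        n_consumed_batches += 1
        # ... use the batch ...
        if n_consumed_batches >= len(dataloader):
            # restart before the iterator raises StopIteration
            data_iterator = iter(dataloader)
            n_consumed_batches = 0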
@@ -8,12 +8,16 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch.autograd import Variable
+from torch.autograd import grad
 import torchvision.utils as vutils

 import bob.core
 logger = bob.core.log.setup("bob.learn.pytorch")

-class IWCGAN(object):
+from matplotlib import pyplot
+
+class IWCGANTrainer(object):
   """
   Class to train a Conditional GAN, using the Improved Wasserstein Training method
@@ -50,7 +54,7 @@ class IWCGAN(object):
       The level of verbosity output to stdout
   """
   def __init__(self, netG, netD, image_size, batch_size=64, noise_dim=100, conditional_dim=13,
-               n_critic_update=5, Lambda=10, use_gpu=False, verbosity_level=2):
+               n_critic_update=3, Lambda=10, use_gpu=False, verbosity_level=2):

     bob.core.log.set_verbosity_level(logger, verbosity_level)
@@ -83,7 +87,7 @@ class IWCGAN(object):
       self.netG.cuda()
       self.criterion.cuda()

-  def calc_gradient_penalty(real_data, fake_data, one_hot, batch_size):
+  def calc_gradient_penalty(self, real_data, fake_data, one_hot, batch_size):
     """
     Computes the gradient penalty term.
@@ -103,22 +107,30 @@ class IWCGAN(object):
     """
     alpha = torch.rand(batch_size, 1)
+    alpha_zero = alpha[0].numpy()
     alpha = alpha.expand(batch_size, real_data.nelement()/batch_size).contiguous().view(batch_size, self.image_size[0], self.image_size[1], self.image_size[2])
     alpha = alpha.cuda() if self.use_gpu else alpha

     interpolates = alpha * real_data + ((1 - alpha) * fake_data)

-    if use_gpu:
+    #first_image = interpolates[0].numpy()
+    #pyplot.title(str(alpha_zero))
+    #pyplot.imshow(numpy.rollaxis(numpy.rollaxis(first_image, 2),2))
+    #pyplot.show()
+
+    if self.use_gpu:
       interpolates = interpolates.cuda()
-    interpolates = autograd.Variable(interpolates, requires_grad=True)
+    interpolates = Variable(interpolates, requires_grad=True)

     disc_interpolates = self.netD(interpolates, one_hot)

-    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
-                              grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_gpu else torch.ones(disc_interpolates.size()),
-                              create_graph=True, retain_graph=True, only_inputs=True)[0]
+    gradients = grad(outputs=disc_interpolates, inputs=interpolates,
+                     grad_outputs=torch.ones(disc_interpolates.size()).cuda() if self.use_gpu else torch.ones(disc_interpolates.size()),
+                     create_graph=True, retain_graph=True, only_inputs=True)[0]

     # TODO: check the gradient norm along all dimensions - Guillaume HEUSCH, 22-11-2017
     gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.Lambda
+    #print "gradient penalty = {}".format(gradient_penalty.data[0])

     return gradient_penalty
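For reference, this method implements the gradient penalty of Gulrajani et al. (2017), "Improved Training of Wasserstein GANs":

    L_GP = Lambda * E[ (||grad_xhat D(xhat, y)||_2 - 1)^2 ],  with  xhat = alpha * x + (1 - alpha) * G(z)

The TODO above is worth acting on: with a 4-D gradients tensor, norm(2, dim=1) takes the norm per channel rather than over each whole sample. A sketch of the common fix, flattening the gradients before the norm; this is a standalone version with illustrative names, using the post-0.4 tensor API (requires_grad_, ones_like) instead of Variable:

    import torch
    from torch.autograd import grad

    def gradient_penalty(netD, real_data, fake_data, one_hot, Lambda=10.0):
        batch_size = real_data.size(0)
        # one interpolation coefficient per sample, broadcast over C x H x W
        alpha = torch.rand(batch_size, 1, 1, 1).expand_as(real_data)
        interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
        disc_interpolates = netD(interpolates, one_hot)
        gradients = grad(outputs=disc_interpolates, inputs=interpolates,
                         grad_outputs=torch.ones_like(disc_interpolates),
                         create_graph=True, retain_graph=True, only_inputs=True)[0]
        # flatten so the 2-norm runs over all pixel dimensions, not only dim 1
        gradients = gradients.view(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * Lambda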
@@ -145,14 +157,22 @@ class IWCGAN(object):
     """
     # setup optimizer
-    optimizerD = optim.Adam(self.netD.parameters(), lr=1e-4, betas=(beta1, 0.9))
+    optimizerD = optim.Adam(self.netD.parameters(), lr=1e-5, betas=(beta1, 0.9))
     optimizerG = optim.Adam(self.netG.parameters(), lr=1e-4, betas=(beta1, 0.9))

     one = torch.FloatTensor([1])
+    zero = torch.FloatTensor([0])
     mone = one * -1
-    if use_gpu:
+    if self.use_gpu:
       one = one.cuda()
       mone = mone.cuda()
+      zero = zero.cuda()
+
+    data_iterator = iter(dataloader)
+    #import sys
+    #sys.exit()
+    n_consumed_batches = 0

     # let's go
     for iteration in range(n_iterations):
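Two details in this hunk deserve a note. The critic learning rate drops to 1e-5 while the generator stays at 1e-4, one of the fixes named in the commit message. And the one/mone tensors are gradient seeds: for a tensor y, y.backward(mone) accumulates the gradient of -y, so a descent step on the parameters actually maximizes y. A minimal demonstration:

    import torch

    x = torch.ones(1, requires_grad=True)
    y = 3 * x
    y.backward(torch.tensor([-1.0]))  # seed -1: accumulates d(-y)/dx
    print(x.grad)                     # tensor([-3.]), same as (-y).backward()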
@@ -165,12 +185,15 @@ class IWCGAN(object):
       for p in self.netD.parameters():
         p.requires_grad = True

-      for k in range(n_critic_update):
+      for k in range(self.n_critic_update):

         # reset gradients
         self.netD.zero_grad()

         # get the data and pose labels
-        data = dataloader.next()
+        data = data_iterator.next()
+        n_consumed_batches += 1
         real_images = data['image']
         poses = data['pose']
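Note that data_iterator.next() is the Python 2 iterator protocol, consistent with the print-style comments elsewhere in this file; under Python 3 the equivalent call would be:

    data = next(data_iterator)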
@@ -200,7 +223,9 @@ class IWCGAN(object):
         output_real = self.netD(imagev, one_hot_fmv)
         output_real = output_real.mean()
+        #print "output real = {}".format(output_real.data[0])
         output_real.backward(mone)
+        #output_real.backward()

         # === FAKE DATA ===
@@ -211,51 +236,64 @@ class IWCGAN(object):
         output_fake = self.netD(input_fakev, one_hot_fmv)
         output_fake = output_fake.mean()
+        #print "output fake = {}".format(output_fake.data[0])
         output_fake.backward(one)
+        #output_fake.backward()

-        gradient_penalty = calc_gradient_penalty(imagev.data, fake.data, one_hot_feature_maps)
+        gradient_penalty = self.calc_gradient_penalty(imagev.data, fake.data, one_hot_fmv, batch_size)
         gradient_penalty.backward()

-        D_cost = D_fake - D_real + gradient_penalty
+        D_cost = output_fake - output_real + gradient_penalty
         optimizerD.step()
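Taken together, the three backward calls implement the WGAN-GP critic objective:

    L_D = E_z[ D(G(z), y) ] - E_x[ D(x, y) ] + L_GP

output_real.backward(mone) accumulates the gradient of -E[D(x, y)], output_fake.backward(one) that of +E[D(G(z), y)], and the penalty is added last; D_cost itself is recomputed only for logging, since the gradients were already accumulated piecewise.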
       # =========
       # GENERATOR
       # =========
-      for p in netD.parameters():
+      for p in self.netD.parameters():
         p.requires_grad = False

       self.netG.zero_grad()

       noise = torch.FloatTensor(batch_size, self.noise_dim, 1, 1).normal_(0, 1)
-      data = dataloader.next()
+      data = data_iterator.next()
+      n_consumed_batches += 1
       poses = data['pose']
       one_hot_feature_maps = torch.FloatTensor(batch_size, self.conditional_dim, self.image_size[1], self.image_size[2]).zero_()
+      one_hot_vector = torch.FloatTensor(batch_size, self.conditional_dim, 1, 1).zero_()
       for k in range(batch_size):
         one_hot_feature_maps[k, poses[k], :, :] = 1
+        one_hot_vector[k, poses[k]] = 1

       if self.use_gpu:
         noise = noise.cuda()
         one_hot_feature_maps = one_hot_feature_maps.cuda()
+        one_hot_vector = one_hot_vector.cuda()

       noisev = Variable(noise)
       one_hot_fmv = Variable(one_hot_feature_maps)
-      fake = netG(noisev)
-      G = netD(fake, one_hot_fmv)
+      one_hot_vv = Variable(one_hot_vector)
+      fake = self.netG(noisev, one_hot_vv)
+      G = self.netD(fake, one_hot_fmv)
       G = G.mean()
       G.backward(mone)
       G_cost = -G
       optimizerG.step()
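The generator update is the mirror image: G.backward(mone) ascends on the critic's score of generated samples, i.e. it minimizes

    L_G = -E_z[ D(G(z, y), y) ]

Note also that the conditional label now reaches both networks in the shape each expects: a one-hot vector for netG and one-hot feature maps for netD.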
       end = time.time()
-      logger.info("[{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(iteration, n_iterations, D_cost.data, G_cost.data, (end-start)))
+      logger.info("[{}/{}] => Loss D = {} -- Loss G = {} (time spent: {}, number of consumed batch = {})".format(iteration, n_iterations, D_cost.data[0], G_cost.data[0], (end-start), n_consumed_batches))

       # save sample every 100 iterations
       if iteration % 100 == 99:
         fake_examples = self.netG(self.fixed_noise, self.fixed_one_hot)
-        vutils.save_image(fake_examples.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True)
+        vutils.save_image(fake_examples.data, '%s/fake_samples_iteration_%03d.png' % (output_dir, iteration), normalize=True)

       # save model every 1000 iterations
       if iteration % 1000 == 999:
-        torch.save(self.netG.state_dict(), '%s/netG_epoch_%d.pth' % (output_dir, epoch))
-        torch.save(self.netD.state_dict(), '%s/netD_epoch_%d.pth' % (output_dir, epoch))
+        torch.save(self.netG.state_dict(), '%s/netG_epoch_%d.pth' % (output_dir, iteration))
+        torch.save(self.netD.state_dict(), '%s/netD_epoch_%d.pth' % (output_dir, iteration))
+
+      if n_consumed_batches > (len(dataloader) - (self.n_critic_update + 1)):
+        logger.info("Batch generator should be restarted ({}/{})".format(n_consumed_batches, len(dataloader)))
+        n_consumed_batches = 0
+        data_iterator = iter(dataloader)
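The restart threshold leaves exactly enough headroom for one more iteration: each pass through the loop consumes self.n_critic_update batches for the critic plus one for the generator. With, say, len(dataloader) = 100 and n_critic_update = 3, the iterator is rebuilt once n_consumed_batches exceeds 96, so the next iteration's four next() calls can never raise StopIteration.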