Skip to content
Snippets Groups Projects
Commit f58f449f authored by Guillaume HEUSCH
Browse files

[trainers] added some comments, and plotting/saving generated images during training

parent 626a68a9
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,9 @@ import torchvision.utils as vutils
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
import bob.io.base
import bob.io.image
class DRGANTrainer(object):
"""
Class to train a DR-GAN
......@@ -88,7 +91,7 @@ class DRGANTrainer(object):
self.criterion_id.cuda()
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out'):
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out', plot=False):
"""
Function that performs the training.
......@@ -108,7 +111,11 @@ class DRGANTrainer(object):
output_dir: path
The directory where you would like to output images and models
plot: boolean
If you want to plot some images during the training process (debug)
"""
# labels for real/fake
real_label = 1
fake_label = 0
......@@ -117,72 +124,89 @@ class DRGANTrainer(object):
optimizerD = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(generator_params, lr=learning_rate, betas=(beta1, 0.999))
# get fixed image, fixed noise and conditional pose for sampling
print "number of images = {}".format(len(dataloader.dataset))
# be sure to have a fixed frontal image to sample from
pose = 0
counter = 0
while pose != 6:
pose = dataloader.dataset[counter]['pose']
fixed_index = counter
counter += 1
fixed_image = dataloader.dataset[counter]['image']
vutils.save_image(fixed_image, '%s/fixed_id.png' % (output_dir), normalize=True)
#fixed_id = dataloader.dataset[counter]['id']
#fixed_pose = dataloader.dataset[counter]['pose']
#from matplotlib import pyplot
#pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fixed_image.numpy(), 2),2))
#pyplot.show()
# plot the image if asked for
if plot:
fixed_id = dataloader.dataset[counter]['id']
fixed_pose = dataloader.dataset[counter]['pose']
from matplotlib import pyplot
pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fixed_image.numpy(), 2),2))
pyplot.show()
# expand the fixed image so as to have a batch covering all possible poses
fixed_image = fixed_image.expand(self.conditional_dim, self.image_size[0], self.image_size[1], self.image_size[2])
fixed_image = Variable(fixed_image)
# the noise to sample from
fixed_noise = torch.FloatTensor(self.conditional_dim, self.noise_dim, 1, 1).normal_(0, 1)
fixed_noise = Variable(fixed_noise)
# build the set of one-hot encoded pose to sample from
fixed_one_hot = torch.FloatTensor(self.conditional_dim, self.conditional_dim, 1, 1).zero_()
for k in range(self.conditional_dim):
fixed_one_hot[k, k] = 1
fixed_one_hot = Variable(fixed_one_hot)
# number of ids in the database
number_of_ids = self.discriminator.number_of_ids
# save minibatch of generated fake images every X iterations
save_generated_minibatch = 1
# ================
# === LET'S GO ===
# ================
for epoch in range(n_epochs):
for i, data in enumerate(dataloader, 0):
start = time.time()
# get the data and pose labels
# get the data, pose and id labels
real_images = data['image']
poses = data['pose']
ids = data['id']
if max(ids) >= number_of_ids:
logger.error("Something is wrong here: I have an ID with index {}, and the number of IDs is {}".format(max(ids), number_of_ids))
import sys
sys.exit()
# WARNING: the last batch could be smaller than the provided size
# WARNING: the last batch could be smaller than the provided size
# (you could avoid that by setting the drop_last flag to True in dataloader constructor)
batch_size = len(real_images)
# create the Tensors with the right batch size
# get a minibatch of noise
noise = torch.FloatTensor(batch_size, self.noise_dim, 1, 1).normal_(0, 1)
label_gan = torch.FloatTensor(batch_size)
# create the one hot conditional vector on pose (decoder)
one_hot_vector = torch.FloatTensor(batch_size, self.conditional_dim, 1, 1).zero_()
for k in range(batch_size):
one_hot_vector[k, poses[k]] = 1
# label for fake/real
label_gan = torch.FloatTensor(batch_size)
# move stuff to GPU if needed
if self.use_gpu:
# inputs
real_images = real_images.cuda()
noise = noise.cuda()
one_hot_vector = one_hot_vector.cuda()
#labels
label_gan = label_gan.cuda()
poses = poses.cuda()
ids = ids.cuda()
noise = noise.cuda()
one_hot_vector = one_hot_vector.cuda()
# =============
# DISCRIMINATOR
......@@ -212,19 +236,8 @@ class DRGANTrainer(object):
# === FAKE DATA ===
noisev = Variable(noise)
one_hot_vv = Variable(one_hot_vector)
# encode the identity
encoded_ids = self.encoder(imagev)
fake = self.decoder(noisev, one_hot_vv, encoded_ids)
if (i % 10) == 0:
vutils.save_image(fake.data, '%s/generated_images_epoch_%03d_minibatch_%03d.png' % (output_dir, epoch, i), normalize=True)
#from matplotlib import pyplot
#for k in range(batch_size):
# pyplot.title("ID -> {}, pose {}".format(ids[k], poses[k]))
# pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fake[k].data.numpy(), 2),2))
# pyplot.show()
label_gan_v = Variable(label_gan.fill_(fake_label))
output_fake = self.discriminator(fake)
......@@ -243,6 +256,40 @@ class DRGANTrainer(object):
# perform optimization (i.e. update discriminator parameters)
errD = errD_real_id + errD_real_pose + (errD_real_gan + errD_fake_gan)
optimizerD.step()
# +++++ Save generated images during training +++++
if ((i % save_generated_minibatch) == 0) and (i > 0):
# get a random example in this minibatch
index = numpy.random.randint(0, batch_size)
logger.info("Saving example {} in this batch (epoch {} - iteration {})".format(index, epoch, i))
# move stuff back to CPU (needed to use numpy())
real_images = real_images.cpu()
fake = fake.cpu()
real_example = (real_images[index].numpy() + 1)/2.
generated_example = (fake[index].data.numpy() + 1)/2.
id_example = ids[index]
pose_example = poses[index]
# create a figure with both the real example and the generated one
if plot:
from matplotlib import pyplot
fig, axarr = pyplot.subplots(1,2)
fig.suptitle("ID = {}, pose = {}".format(id_example, pose_example))
axarr[0].imshow(numpy.rollaxis(numpy.rollaxis(real_example, 2),2))
axarr[1].imshow(numpy.rollaxis(numpy.rollaxis(generated_example, 2),2))
pyplot.show()
image_to_be_saved = numpy.ones((self.image_size[0], self.image_size[1], self.image_size[2]*2))
image_to_be_saved[:, :self.image_size[1], :self.image_size[2]] = real_example
image_to_be_saved[:, :self.image_size[1], self.image_size[2]:self.image_size[2]*2] = generated_example
bob.io.base.save((image_to_be_saved*255.).astype('uint8'), output_dir + '/generated_sample_{}_{}.png'.format(epoch, i))
if self.use_gpu:
fake = fake.cuda()
#vutils.save_image(fake.data, '%s/generated_images_epoch_%03d_minibatch_%03d.png' % (output_dir, epoch, i), normalize=True)
# =========
# GENERATOR
......@@ -280,8 +327,7 @@ class DRGANTrainer(object):
self.encoder = self.encoder.cpu()
self.decoder = self.decoder.cpu()
fixed_imagev = Variable(fixed_image)
fixed_encoded_id = self.encoder(fixed_imagev)
fixed_encoded_id = self.encoder(fixed_image)
fake_examples = self.decoder(fixed_noise, fixed_one_hot, fixed_encoded_id)
vutils.save_image(fake_examples.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment