Skip to content
Snippets Groups Projects
Commit f6583e54 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainers] add the performance for minibatches in DR-GAN trainer, as well as...

[trainers] add the performance for minibatches in DR-GAN trainer, as well as the ability to save generated samples in HDF5
parent c6fef45a
Branches
Tags v2.1.3
No related merge requests found
...@@ -66,6 +66,8 @@ class DRGANTrainer(object): ...@@ -66,6 +66,8 @@ class DRGANTrainer(object):
self.latent_dim = latent_dim self.latent_dim = latent_dim
self.use_gpu = use_gpu self.use_gpu = use_gpu
self.number_of_ids = self.discriminator.number_of_ids
# fixed conditional noise - used to generate samples (one for each value of the conditional variable) # fixed conditional noise - used to generate samples (one for each value of the conditional variable)
self.fixed_noise = torch.FloatTensor(self.conditional_dim, noise_dim, 1, 1).normal_(0, 1) self.fixed_noise = torch.FloatTensor(self.conditional_dim, noise_dim, 1, 1).normal_(0, 1)
self.fixed_one_hot = torch.FloatTensor(self.conditional_dim, self.conditional_dim, 1, 1).zero_() self.fixed_one_hot = torch.FloatTensor(self.conditional_dim, self.conditional_dim, 1, 1).zero_()
...@@ -91,7 +93,7 @@ class DRGANTrainer(object): ...@@ -91,7 +93,7 @@ class DRGANTrainer(object):
self.criterion_id.cuda() self.criterion_id.cuda()
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out', plot=False): def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out', plot=False, pose_random=True):
""" """
Function that performs the training. Function that performs the training.
...@@ -131,7 +133,8 @@ class DRGANTrainer(object): ...@@ -131,7 +133,8 @@ class DRGANTrainer(object):
pose = dataloader.dataset[counter]['pose'] pose = dataloader.dataset[counter]['pose']
fixed_index = counter fixed_index = counter
counter += 1 counter += 1
fixed_image = dataloader.dataset[counter]['image'] fixed_image = torch.Tensor(torch.Size(self.image_size))
fixed_image.copy_(dataloader.dataset[counter]['image'])
vutils.save_image(fixed_image, '%s/fixed_id.png' % (output_dir), normalize=True) vutils.save_image(fixed_image, '%s/fixed_id.png' % (output_dir), normalize=True)
# plot the image if asked for # plot the image if asked for
...@@ -140,7 +143,7 @@ class DRGANTrainer(object): ...@@ -140,7 +143,7 @@ class DRGANTrainer(object):
fixed_pose = dataloader.dataset[counter]['pose'] fixed_pose = dataloader.dataset[counter]['pose']
from matplotlib import pyplot from matplotlib import pyplot
pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose)) pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fixed_image.numpy(), 2),2)) pyplot.imshow(numpy.rollaxis(numpy.rollaxis(((fixed_image.numpy()+1)/2.), 2),2))
pyplot.show() pyplot.show()
# expand the fixed image, so that to have a batch for all possible poses # expand the fixed image, so that to have a batch for all possible poses
...@@ -157,11 +160,8 @@ class DRGANTrainer(object): ...@@ -157,11 +160,8 @@ class DRGANTrainer(object):
fixed_one_hot[k, k] = 1 fixed_one_hot[k, k] = 1
fixed_one_hot = Variable(fixed_one_hot) fixed_one_hot = Variable(fixed_one_hot)
# number of ids in the database
number_of_ids = self.discriminator.number_of_ids
# save minibatch of generated fake images every X iterations # save minibatch of generated fake images every X iterations
save_generated_minibatch = 1 save_generated_minibatch = 100
# ================ # ================
# === LET'S GO === # === LET'S GO ===
...@@ -176,14 +176,19 @@ class DRGANTrainer(object): ...@@ -176,14 +176,19 @@ class DRGANTrainer(object):
poses = data['pose'] poses = data['pose']
ids = data['id'] ids = data['id']
if max(ids) >= number_of_ids: # sanity check
logger.error("Something is wrong here: I have an ID with index {}, and the number of IDs is {}".format(max(ids), number_of_ids)) if max(ids) >= self.number_of_ids:
logger.error("Something is wrong here: I have an ID with index {}, and the number of IDs is {}".format(max(ids), self.number_of_ids))
import sys import sys
sys.exit() sys.exit()
# WARNING: the last batch could be smaller than the provided size # WARNING: the last batch could be smaller than the provided size
# (you could avoid that by setting the drop_last flag to True in dataloader constructor) # (you could avoid that by setting the drop_last flag to True in dataloader constructor)
batch_size = len(real_images) batch_size = len(real_images)
# generate random pose index - to train to generate "unknown" poses !
random_poses = numpy.random.randint(0, self.conditional_dim, batch_size)
random_poses = torch.LongTensor(random_poses)
# get a minibatch of noise # get a minibatch of noise
noise = torch.FloatTensor(batch_size, self.noise_dim, 1, 1).normal_(0, 1) noise = torch.FloatTensor(batch_size, self.noise_dim, 1, 1).normal_(0, 1)
...@@ -191,14 +196,16 @@ class DRGANTrainer(object): ...@@ -191,14 +196,16 @@ class DRGANTrainer(object):
# create the one hot conditional vector on pose (decoder) # create the one hot conditional vector on pose (decoder)
one_hot_vector = torch.FloatTensor(batch_size, self.conditional_dim, 1, 1).zero_() one_hot_vector = torch.FloatTensor(batch_size, self.conditional_dim, 1, 1).zero_()
for k in range(batch_size): for k in range(batch_size):
one_hot_vector[k, poses[k]] = 1 if pose_random:
one_hot_vector[k, int(random_poses[k])] = 1
else:
one_hot_vector[k, poses[k]] = 1
# label for fake/real # label for fake/real
label_gan = torch.FloatTensor(batch_size) label_gan = torch.FloatTensor(batch_size)
# move stuff to GPU if needed # move stuff to GPU if needed
if self.use_gpu: if self.use_gpu:
# inputs # inputs
real_images = real_images.cuda() real_images = real_images.cuda()
noise = noise.cuda() noise = noise.cuda()
...@@ -206,6 +213,7 @@ class DRGANTrainer(object): ...@@ -206,6 +213,7 @@ class DRGANTrainer(object):
#labels #labels
label_gan = label_gan.cuda() label_gan = label_gan.cuda()
poses = poses.cuda() poses = poses.cuda()
random_poses = random_poses.cuda()
ids = ids.cuda() ids = ids.cuda()
# ============= # =============
...@@ -221,13 +229,9 @@ class DRGANTrainer(object): ...@@ -221,13 +229,9 @@ class DRGANTrainer(object):
label_id_v = Variable(ids) label_id_v = Variable(ids)
output_real = self.discriminator(imagev) output_real = self.discriminator(imagev)
errD_real_id = self.criterion_id(output_real[:, :number_of_ids], label_id_v) errD_real_id = self.criterion_id(output_real[:, :self.number_of_ids], label_id_v)
errD_real_pose = self.criterion_pose(output_real[:, number_of_ids:(number_of_ids + self.conditional_dim)], label_pose_v) errD_real_pose = self.criterion_pose(output_real[:, self.number_of_ids:(self.number_of_ids + self.conditional_dim)], label_pose_v)
errD_real_gan = self.criterion_gan(output_real[:, -1], label_gan_v) errD_real_gan = self.criterion_gan(output_real[:, -1], label_gan_v)
logger.debug("[REAL] error on ID = {}".format(errD_real_id.data[0]))
logger.debug("[REAL] error on pose = {}".format(errD_real_pose.data[0]))
logger.debug("[REAL] error on fake/real = {}".format(errD_real_gan.data[0]))
errD_real_id.backward(retain_graph=True) errD_real_id.backward(retain_graph=True)
errD_real_pose.backward(retain_graph=True) errD_real_pose.backward(retain_graph=True)
...@@ -239,15 +243,15 @@ class DRGANTrainer(object): ...@@ -239,15 +243,15 @@ class DRGANTrainer(object):
encoded_ids = self.encoder(imagev) encoded_ids = self.encoder(imagev)
fake = self.decoder(noisev, one_hot_vv, encoded_ids) fake = self.decoder(noisev, one_hot_vv, encoded_ids)
label_gan_v = Variable(label_gan.fill_(fake_label)) label_gan_v = Variable(label_gan.fill_(fake_label))
label_random_poses_v = Variable(random_poses)
output_fake = self.discriminator(fake) output_fake = self.discriminator(fake)
errD_fake_id = self.criterion_id(output_fake[:, :number_of_ids], label_id_v) errD_fake_id = self.criterion_id(output_fake[:, :self.number_of_ids], label_id_v)
errD_fake_pose = self.criterion_pose(output_fake[:, number_of_ids:(number_of_ids + self.conditional_dim)], label_pose_v) if pose_random:
errD_fake_pose = self.criterion_pose(output_fake[:, self.number_of_ids:(self.number_of_ids + self.conditional_dim)], label_random_poses_v)
else:
errD_fake_pose = self.criterion_pose(output_fake[:, self.number_of_ids:(self.number_of_ids + self.conditional_dim)], label_pose_v)
errD_fake_gan = self.criterion_gan(output_fake[:, -1], label_gan_v) errD_fake_gan = self.criterion_gan(output_fake[:, -1], label_gan_v)
logger.debug("[FAKE] error on ID = {}".format(errD_fake_id.data[0]))
logger.debug("[FAKE] error on pose = {}".format(errD_fake_pose.data[0]))
logger.debug("[FAKE] error on fake/fake = {}".format(errD_fake_gan.data[0]))
errD_fake_id.backward(retain_graph=True) errD_fake_id.backward(retain_graph=True)
errD_fake_pose.backward(retain_graph=True) errD_fake_pose.backward(retain_graph=True)
...@@ -271,13 +275,17 @@ class DRGANTrainer(object): ...@@ -271,13 +275,17 @@ class DRGANTrainer(object):
generated_example = (fake[index].data.numpy() + 1)/2. generated_example = (fake[index].data.numpy() + 1)/2.
id_example = ids[index] id_example = ids[index]
pose_example = poses[index] pose_example = poses[index]
if pose_random:
pose_example = random_poses[index]
# create a figure with both the real example and the generated one # create a figure with both the real example and the generated one
if plot: if plot:
from matplotlib import pyplot from matplotlib import pyplot
fig, axarr = pyplot.subplots(1,2) fig, axarr = pyplot.subplots(1,2)
fig.suptitle("ID = {}, pose = {}".format(id_example, pose_example)) fig.suptitle("ID = {}".format(id_example))
axarr[0].set_title("Real pose = {}".format(poses[index]))
axarr[0].imshow(numpy.rollaxis(numpy.rollaxis(real_example, 2),2)) axarr[0].imshow(numpy.rollaxis(numpy.rollaxis(real_example, 2),2))
axarr[1].set_title("Target pose = {}".format(pose_example))
axarr[1].imshow(numpy.rollaxis(numpy.rollaxis(generated_example, 2),2)) axarr[1].imshow(numpy.rollaxis(numpy.rollaxis(generated_example, 2),2))
pyplot.show() pyplot.show()
...@@ -289,7 +297,6 @@ class DRGANTrainer(object): ...@@ -289,7 +297,6 @@ class DRGANTrainer(object):
if self.use_gpu: if self.use_gpu:
fake = fake.cuda() fake = fake.cuda()
#vutils.save_image(fake.data, '%s/generated_images_epoch_%03d_minibatch_%03d.png' % (output_dir, epoch, i), normalize=True) #vutils.save_image(fake.data, '%s/generated_images_epoch_%03d_minibatch_%03d.png' % (output_dir, epoch, i), normalize=True)
# ========= # =========
# GENERATOR # GENERATOR
...@@ -299,14 +306,13 @@ class DRGANTrainer(object): ...@@ -299,14 +306,13 @@ class DRGANTrainer(object):
label_gan_v = Variable(label_gan.fill_(real_label)) # fake labels are real for generator cost label_gan_v = Variable(label_gan.fill_(real_label)) # fake labels are real for generator cost
output_generated = self.discriminator(fake) output_generated = self.discriminator(fake)
errG_id = self.criterion_id(output_generated[:, :number_of_ids], label_id_v) errG_id = self.criterion_id(output_generated[:, :self.number_of_ids], label_id_v)
errG_pose = self.criterion_pose(output_generated[:, number_of_ids:(number_of_ids + self.conditional_dim)], label_pose_v) if pose_random:
errG_pose = self.criterion_pose(output_generated[:, self.number_of_ids:(self.number_of_ids + self.conditional_dim)], label_random_poses_v)
else:
errG_pose = self.criterion_pose(output_generated[:, self.number_of_ids:(self.number_of_ids + self.conditional_dim)], label_pose_v)
errG_gan = self.criterion_gan(output_generated[:, -1], label_gan_v) errG_gan = self.criterion_gan(output_generated[:, -1], label_gan_v)
logger.debug("[GENERATOR] error on ID = {}".format(errG_id.data[0]))
logger.debug("[GENERATOR] error on pose = {}".format(errG_pose.data[0]))
logger.debug("[GENERATOR] error on fake/fake = {}".format(errG_gan.data[0]))
errG_id.backward(retain_graph=True) errG_id.backward(retain_graph=True)
errG_pose.backward(retain_graph=True) errG_pose.backward(retain_graph=True)
errG_gan.backward(retain_graph=True) errG_gan.backward(retain_graph=True)
...@@ -317,7 +323,39 @@ class DRGANTrainer(object): ...@@ -317,7 +323,39 @@ class DRGANTrainer(object):
end = time.time() end = time.time()
logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0], (end-start))) logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0], (end-start)))
self.check_batch_statistics(output_real, output_fake, output_generated, ids, poses, random_poses, batch_size)
# =====================
# SAVE GENERATED IMAGES
# =====================
if ((i % save_generated_minibatch) == 0) and (i > 0):
# get a random example in this minibatch
index = numpy.random.randint(0, batch_size)
logger.info("Saving example {} in this batch (epoch {} - iteration {})".format(index, epoch, i))
# move stuff back to CPU (needed to use numpy())
real_images = real_images.cpu()
fake = fake.cpu()
real_example = (real_images[index].numpy() + 1)/2.
generated_example = (fake[index].data.numpy() + 1)/2.
id_example = ids[index]
pose_example = poses[index]
if pose_random:
pose_example = random_poses[index]
# save hdf5 data:
filename = output_dir + '/generated_data_{}_{}.hdf5'.format(epoch, i)
f = bob.io.base.HDF5File(filename, 'w')
f.set('id', id_example)
f.set('real_example', real_example)
f.set('generated_example', generated_example)
f.set('real_pose', poses[index])
f.set('target_pose', pose_example)
del f
# save generated images at every epoch # save generated images at every epoch
# TODO: model moved to CPU and back and I don't really know why (expected CPU tensor error) # TODO: model moved to CPU and back and I don't really know why (expected CPU tensor error)
# To summarize: # To summarize:
...@@ -339,3 +377,109 @@ class DRGANTrainer(object): ...@@ -339,3 +377,109 @@ class DRGANTrainer(object):
torch.save(self.encoder.state_dict(), '%s/encoder_epoch_%d.pth' % (output_dir, epoch)) torch.save(self.encoder.state_dict(), '%s/encoder_epoch_%d.pth' % (output_dir, epoch))
torch.save(self.decoder.state_dict(), '%s/decoder_epoch_%d.pth' % (output_dir, epoch)) torch.save(self.decoder.state_dict(), '%s/decoder_epoch_%d.pth' % (output_dir, epoch))
torch.save(self.discriminator.state_dict(), '%s/discriminator_epoch_%d.pth' % (output_dir, epoch)) torch.save(self.discriminator.state_dict(), '%s/discriminator_epoch_%d.pth' % (output_dir, epoch))
def check_batch_statistics(self, output_real, output_fake, output_generator,
                           label_id, label_pose, label_random_pose, batch_size,
                           pose_random=False):
    """
    Compute some performance stats on the current mini-batch

    For each of the three discriminator outputs, the accuracy on the
    identity, on the pose and on the real/fake decision is computed and
    written to the module logger at DEBUG level. Nothing is returned.

    Each output row is read as: the first ``self.number_of_ids`` columns are
    the identity scores, the next ``self.conditional_dim`` columns are the
    pose scores and the last column is the probability of being real.

    **Parameters**

    output_real: pyTorch Tensor (batch_size, (#of ids + dim(condition) + 1))
      The discriminator output for real images

    output_fake: pyTorch Tensor (batch_size, (#of ids + dim(condition) + 1))
      The discriminator output for fake examples

    output_generator: pyTorch Tensor (batch_size, (#of ids + dim(condition) + 1))
      The discriminator output used to train the generator

    label_id: pyTorch Tensor (batch_size)
      The label for the identity

    label_pose: pyTorch Tensor (batch_size)
      The (real) label for the pose

    label_random_pose: pyTorch Tensor (batch_size)
      The random label for the pose

    batch_size: int
      The size of the current batch

    pose_random: bool
      If True, the pose accuracy of the fake and generator outputs is
      measured against ``label_random_pose`` (the pose the fake images were
      asked to show). If False (default, previous behaviour), it is measured
      against ``label_pose``. Real images are always scored against
      ``label_pose``.

    """
    if batch_size <= 0:
        # nothing to score, and the accuracies below would divide by zero
        return

    id_end = self.number_of_ids
    pose_end = id_end + self.conditional_dim

    def on_cpu(values):
        # unwrap autograd wrappers and bring everything to the CPU, so that
        # predictions and labels can be compared wherever they were computed
        return getattr(values, 'data', values)[:batch_size].cpu()

    def count_class_hits(scores, labels):
        # number of rows whose highest score is at the labelled index
        _, predicted = torch.max(on_cpu(scores), 1)
        hits = predicted.view(-1) == on_cpu(labels).long().view(-1)
        return int(hits.sum())

    def count_gan_hits(output, expect_real):
        # last column is the probability of being real, 0.5 is the threshold
        prob_real = on_cpu(output)[:, -1]
        hits = (prob_real > 0.5) if expect_real else (prob_real < 0.5)
        return int(hits.sum())

    # fake images may have been generated for a random target pose
    fake_pose_label = label_random_pose if pose_random else label_pose

    # (log tag, discriminator output, pose label, is "real" the right answer)
    # NOTE: for the generator output, a hit is still counted when the
    # discriminator says "fake", exactly as for the fake output.
    sources = (
        ("REAL", output_real, label_pose, True),
        ("FAKE", output_fake, fake_pose_label, False),
        ("GENERATOR", output_generator, fake_pose_label, False),
    )

    for tag, output, pose_target, expect_real in sources:
        id_correct = count_class_hits(output[:, :id_end], label_id)
        pose_correct = count_class_hits(output[:, id_end:pose_end], pose_target)
        gan_correct = count_gan_hits(output, expect_real)

        logger.debug("[{}] accuracy on ID = {} ({}/{})".format(tag, id_correct/float(batch_size), id_correct, batch_size))
        logger.debug("[{}] accuracy on pose = {} ({}/{})".format(tag, pose_correct/float(batch_size), pose_correct, batch_size))
        logger.debug("[{}] accuracy on real/fake = {} ({}/{})".format(tag, gan_correct/float(batch_size), gan_correct, batch_size))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment