Skip to content
Snippets Groups Projects
Commit 9a562679 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainers] finished the DRGAN trainer

parent f627be08
No related branches found
No related tags found
No related merge requests found
...@@ -81,7 +81,8 @@ class DRGANTrainer(object): ...@@ -81,7 +81,8 @@ class DRGANTrainer(object):
# move stuff to GPU if needed # move stuff to GPU if needed
if self.use_gpu: if self.use_gpu:
self.discriminator.cuda() self.discriminator.cuda()
self.netG.cuda() self.encoder.cuda()
self.decoder.cuda()
self.criterion_gan.cuda() self.criterion_gan.cuda()
self.criterion_pose.cuda() self.criterion_pose.cuda()
self.criterion_id.cuda() self.criterion_id.cuda()
...@@ -116,15 +117,26 @@ class DRGANTrainer(object): ...@@ -116,15 +117,26 @@ class DRGANTrainer(object):
optimizerD = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999)) optimizerD = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(generator_params, lr=learning_rate, betas=(beta1, 0.999)) optimizerG = optim.Adam(generator_params, lr=learning_rate, betas=(beta1, 0.999))
# get a fixed encoded id for sampling # get fixed images and noise for sampling
fixed_image = dataloader.dataset[0]['image'].numpy() fixed_image = dataloader.dataset[0]['image']
fixed_image = fixed_image.expand(self.conditional_dim, self.image_size[0], self.image_size[1], self.image_size[2])
fixed_noise = torch.FloatTensor(self.conditional_dim, self.noise_dim, 1, 1).normal_(0, 1)
fixed_noise = Variable(fixed_noise)
fixed_one_hot = torch.FloatTensor(self.conditional_dim, self.conditional_dim, 1, 1).zero_()
for k in range(self.conditional_dim):
fixed_one_hot[k, k] = 1
fixed_one_hot = Variable(fixed_one_hot)
fixed_id = dataloader.dataset[0]['id'] fixed_id = dataloader.dataset[0]['id']
fixed_pose = dataloader.dataset[0]['pose'] fixed_pose = dataloader.dataset[0]['pose']
from matplotlib import pyplot #from matplotlib import pyplot
pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose)) #pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fixed_image, 2),2)) #pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fixed_image, 2),2))
pyplot.show() #pyplot.show()
number_of_ids = self.discriminator.number_of_ids number_of_ids = self.discriminator.number_of_ids
...@@ -138,6 +150,8 @@ class DRGANTrainer(object): ...@@ -138,6 +150,8 @@ class DRGANTrainer(object):
poses = data['pose'] poses = data['pose']
ids = data['id'] ids = data['id']
if max(ids) >= number_of_ids:
logger.error("Something is wrong here: I have an ID with index {}, and the number of IDs is {}".format(max(ids), number_of_ids))
# WARNING: the last batch could be smaller than the provided size # WARNING: the last batch could be smaller than the provided size
batch_size = len(real_images) batch_size = len(real_images)
...@@ -166,45 +180,79 @@ class DRGANTrainer(object): ...@@ -166,45 +180,79 @@ class DRGANTrainer(object):
self.discriminator.zero_grad() self.discriminator.zero_grad()
# === REAL DATA === # === REAL DATA ===
label_gan.resize_(batch_size).fill_(real_label)
imagev = Variable(real_images) imagev = Variable(real_images)
label_gan.resize_(batch_size).fill_(real_label)
label_gan_v = Variable(label_gan) label_gan_v = Variable(label_gan)
label_pose_v = Variable(poses) label_pose_v = Variable(poses)
label_id_v = Variable(ids) label_id_v = Variable(ids)
output_real = self.discriminator(imagev) output_real = self.discriminator(imagev)
errD_id = self.criterion_id(output_real[:, :number_of_ids], label_id_v) errD_real_id = self.criterion_id(output_real[:, :number_of_ids], label_id_v)
errD_pose = self.criterion_pose(output_real[:, number_of_ids:(number_of_ids + self.conditional_dim)], label_pose_v) errD_real_pose = self.criterion_pose(output_real[:, number_of_ids:(number_of_ids + self.conditional_dim)], label_pose_v)
errD_gan = self.criterion_gan(output_real[:, -1], label_gan_v) errD_real_gan = self.criterion_gan(output_real[:, -1], label_gan_v)
print errD_id logger.debug("[REAL] error on ID = {}".format(errD_real_id.data[0]))
print errD_pose logger.debug("[REAL] error on pose = {}".format(errD_real_pose.data[0]))
print errD_gan logger.debug("[REAL] error on fake/real = {}".format(errD_real_gan.data[0]))
import sys
sys.exit()
errD_real_id.backward(retain_graph=True)
errD_real_pose.backward(retain_graph=True)
errD_real_gan.backward(retain_graph=True)
# === FAKE DATA === # === FAKE DATA ===
noisev = Variable(noise) noisev = Variable(noise)
one_hot_vv = Variable(one_hot_vector) one_hot_vv = Variable(one_hot_vector)
fake = self.netG(noisev, one_hot_vv)
labelv = Variable(label.fill_(fake_label)) # encode the identity
output_fake = self.discriminator(fake, one_hot_fmv) encoded_ids = self.encoder(imagev)
errD_fake = self.criterion(output_fake, labelv) fake = self.decoder(noisev, one_hot_vv, encoded_ids)
errD_fake.backward(retain_graph=True) #from matplotlib import pyplot
#for k in range(batch_size):
# pyplot.title("ID -> {}, pose {}".format(ids[k], poses[k]))
# pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fake[k].data.numpy(), 2),2))
# pyplot.show()
label_gan_v = Variable(label_gan.fill_(fake_label))
output_fake = self.discriminator(fake)
errD_fake_id = self.criterion_id(output_fake[:, :number_of_ids], label_id_v)
errD_fake_pose = self.criterion_pose(output_fake[:, number_of_ids:(number_of_ids + self.conditional_dim)], label_pose_v)
errD_fake_gan = self.criterion_gan(output_fake[:, -1], label_gan_v)
logger.debug("[FAKE] error on ID = {}".format(errD_fake_id.data[0]))
logger.debug("[FAKE] error on pose = {}".format(errD_fake_pose.data[0]))
logger.debug("[FAKE] error on fake/fake = {}".format(errD_fake_gan.data[0]))
errD_fake_id.backward(retain_graph=True)
errD_fake_pose.backward(retain_graph=True)
errD_fake_gan.backward(retain_graph=True)
# perform optimization (i.e. update discriminator parameters) # perform optimization (i.e. update discriminator parameters)
errD = errD_real + errD_fake errD = errD_real_id + errD_real_pose + (errD_real_gan + errD_fake_gan)
optimizerD.step() optimizerD.step()
# ========= # =========
# GENERATOR # GENERATOR
# ========= # =========
self.netG.zero_grad() self.encoder.zero_grad()
labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost self.decoder.zero_grad()
output_generated = self.discriminator(fake, one_hot_fmv) label_gan_v = Variable(label_gan.fill_(real_label)) # fake labels are real for generator cost
errG = self.criterion(output_generated, labelv) output_generated = self.discriminator(fake)
errG.backward()
errG_id = self.criterion_id(output_generated[:, :number_of_ids], label_id_v)
errG_pose = self.criterion_pose(output_generated[:, number_of_ids:(number_of_ids + self.conditional_dim)], label_pose_v)
errG_gan = self.criterion_gan(output_generated[:, -1], label_gan_v)
logger.debug("[GENERATOR] error on ID = {}".format(errG_id.data[0]))
logger.debug("[GENERATOR] error on pose = {}".format(errG_pose.data[0]))
logger.debug("[GENERATOR] error on fake/fake = {}".format(errG_gan.data[0]))
errG_id.backward(retain_graph=True)
errG_pose.backward(retain_graph=True)
errG_gan.backward(retain_graph=True)
# perform optimization (i.e. update discriminator parameters)
errG = errG_id + errG_pose + errG_gan
optimizerG.step() optimizerG.step()
end = time.time() end = time.time()
...@@ -216,11 +264,19 @@ class DRGANTrainer(object): ...@@ -216,11 +264,19 @@ class DRGANTrainer(object):
# tried to move tensors, variables on the GPU -> does not work # tried to move tensors, variables on the GPU -> does not work
# let the tensors on the CPU -> does not work # let the tensors on the CPU -> does not work
# => model has to be brought back to the CPU :/ # => model has to be brought back to the CPU :/
self.netG = self.netG.cpu() self.encoder = self.encoder.cpu()
fake_examples = self.netG(self.fixed_noise, self.fixed_one_hot) self.decoder = self.decoder.cpu()
self.netG = self.netG.cuda()
fixed_imagev = Variable(fixed_image)
fixed_encoded_id = self.encoder(fixed_imagev)
fake_examples = self.decoder(fixed_noise, fixed_one_hot, fixed_encoded_id)
vutils.save_image(fake_examples.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True) vutils.save_image(fake_examples.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True)
if torch.cuda.is_available():
self.encoder = self.encoder.cuda()
self.decoder = self.decoder.cuda()
# do checkpointing # do checkpointing
torch.save(self.netG.state_dict(), '%s/netG_epoch_%d.pth' % (output_dir, epoch)) torch.save(self.encoder.state_dict(), '%s/encoder_epoch_%d.pth' % (output_dir, epoch))
torch.save(self.decoder.state_dict(), '%s/decoder_epoch_%d.pth' % (output_dir, epoch))
torch.save(self.discriminator.state_dict(), '%s/discriminator_epoch_%d.pth' % (output_dir, epoch)) torch.save(self.discriminator.state_dict(), '%s/discriminator_epoch_%d.pth' % (output_dir, epoch))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment