Commit 9e2d1da5 authored by Guillaume HEUSCH

[trainers] fix handling of the last (possibly smaller) batch in the conditional GAN trainer

parent c1b4c3a7
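The change below stops relying on the configured batch size inside the training loop and instead derives it from the batch actually returned by the data loader, since the last batch is smaller whenever the dataset size is not a multiple of the batch size. A minimal sketch of that pattern (the toy dataset and sizes are illustrative, not taken from this trainer):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# toy dataset: 10 images with a configured batch size of 4 -> batches of 4, 4 and 2
dataset = TensorDataset(torch.randn(10, 3, 64, 64))
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

for (images,) in dataloader:
    # WARNING: the last batch can be smaller than the configured size
    batch_size = len(images)
    # size per-batch tensors from the actual batch, not the configured batch size
    labels = torch.full((batch_size,), 1.0)
    assert labels.shape[0] == images.shape[0]
```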
@@ -61,7 +61,17 @@ class ConditionalGANTrainer(object):
  self.label = torch.FloatTensor(batch_size)
  # to generate samples
- self.fixed_noise = torch.FloatTensor(12, noise_dim, 1, 1).normal_(0, 1)
+ noise = torch.FloatTensor(self.conditional_dim, noise_dim).normal_(0, 1)
+ oh = numpy.zeros((self.conditional_dim, self.conditional_dim))
+ for pose in range(self.conditional_dim):
+     oh[pose, pose] = 1
+ one_hot = torch.FloatTensor(oh)
+ input_generator_examples = torch.FloatTensor(self.conditional_dim, (self.noise_dim + self.conditional_dim))
+ for k in range(self.conditional_dim):
+     input_generator_examples[k] = torch.cat((noise[k], one_hot[k]), 0)
+ self.fixed_noise = torch.FloatTensor(conditional_dim, noise_dim + conditional_dim, 1, 1)
+ self.fixed_noise.resize_(self.conditional_dim, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator_examples)
+ self.fixed_noise = Variable(self.fixed_noise)
  self.criterion = nn.BCELoss()
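For reference, the fixed generator input introduced above (one noise vector per pose, concatenated with the one-hot encoding of that pose) can be written more compactly with torch.eye; the dimensions below are illustrative, and the (N, noise_dim + conditional_dim, 1, 1) shape is assumed to match what the generator expects:

```python
import torch

noise_dim, conditional_dim = 100, 13  # illustrative values

noise = torch.randn(conditional_dim, noise_dim)    # one noise vector per pose
one_hot = torch.eye(conditional_dim)               # one-hot encoding of every pose
fixed_noise = torch.cat((noise, one_hot), dim=1)   # (conditional_dim, noise_dim + conditional_dim)
fixed_noise = fixed_noise.view(conditional_dim, noise_dim + conditional_dim, 1, 1)
```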
@@ -71,7 +81,7 @@ class ConditionalGANTrainer(object):
      self.netG.cuda()
      self.criterion.cuda()
      self.input, self.label = self.input.cuda(), self.label.cuda()
-     self.noise, self.fixed_noise = self.noise.cuda(), self.fixed_noise.cuda()
+     self.conditional_noise, self.fixed_noise = self.conditional_noise.cuda(), self.fixed_noise.cuda()
  bob.core.log.set_verbosity_level(logger, verbosity_level)
@@ -119,21 +129,22 @@ class ConditionalGANTrainer(object):
  poses = data['pose']
  image_size = real_images[1].size()
  if self.use_gpu:
      real_images = real_images.cuda()
+ # WARNING: the last batch could be smaller than the provided size
+ batch_size = len(real_images)
  # build the additional feature maps corresponding to the conditioning variable
- temp = real_images.numpy()
- cm = numpy.zeros((self.batch_size, self.conditional_dim, temp.shape[2], temp.shape[3]))
- for k in range(self.batch_size):
+ cm = numpy.zeros((batch_size, self.conditional_dim, self.image_size[1], self.image_size[2]))
+ for k in range(batch_size):
      cm[k, int(poses[k]), :, :] = 1
  conditional_maps = torch.FloatTensor(cm)
  # append the conditional feature maps to the original images
  input_discriminator = torch.FloatTensor(self.batch_size, (image_size[0] + self.conditional_dim), image_size[1], image_size[2])
- for k in range(self.batch_size):
+ for k in range(batch_size):
      input_discriminator[k] = torch.cat((real_images[k], conditional_maps[k]), 0)
  if self.use_gpu:
      input_discriminator = input_discriminator.cuda()
  self.input.resize_as_(input_discriminator).copy_(input_discriminator)
  inputv = Variable(self.input)
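The per-sample conditioning above (an all-ones feature map at the channel matching each image's pose, appended to the image channels before the discriminator) can also be written without the explicit Python loop; a sketch assuming NCHW image tensors and integer pose labels in [0, conditional_dim):

```python
import torch

def build_discriminator_input(images, poses, conditional_dim):
    # images: (N, C, H, W); poses: (N,) integer pose labels
    n, _, h, w = images.shape
    cm = torch.zeros(n, conditional_dim, h, w, device=images.device)
    # set the feature map matching each sample's pose to 1
    cm[torch.arange(n, device=images.device), poses.long().to(images.device)] = 1
    # append the conditional feature maps to the image channels
    return torch.cat((images, cm), dim=1)  # (N, C + conditional_dim, H, W)
```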
@@ -144,30 +155,41 @@ class ConditionalGANTrainer(object):
  D_x = output.data.mean()
  # train with fake
- noise = torch.FloatTensor(self.batch_size, self.noise_dim)
+ noise = torch.FloatTensor(batch_size, self.noise_dim)
  noise.normal_(0, 1)
  # generate the one hot pose encoding
- oh = numpy.zeros((self.batch_size, self.conditional_dim))
- for k in range(self.batch_size):
+ oh = numpy.zeros((batch_size, self.conditional_dim))
+ for k in range(batch_size):
      oh[k, int(poses[k])] = 1
  one_hot = torch.FloatTensor(oh)
  # concatenate that with the noise
  input_generator = torch.FloatTensor(self.batch_size, (self.noise_dim + self.conditional_dim))
- for k in range(self.batch_size):
+ for k in range(batch_size):
      input_generator[k] = torch.cat((noise[k], one_hot[k]), 0)
  if self.use_gpu:
      input_generator = input_generator.cuda()
  self.conditional_noise.resize_(self.batch_size, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator)
  noisev = Variable(self.conditional_noise)
  fake = self.netG(noisev)
  # build conditional fakes
  fake_images = fake.data
  if self.use_gpu:
      fake = fake.cpu()
      fake_images = fake.data
      fake = fake.cuda()
  input_discriminator_fake = torch.FloatTensor(self.batch_size, (image_size[0] + self.conditional_dim), image_size[1], image_size[2])
- for k in range(self.batch_size):
+ for k in range(batch_size):
      input_discriminator_fake[k] = torch.cat((fake_images[k], conditional_maps[k]), 0)
  if self.use_gpu:
      input_discriminator_fake = input_discriminator_fake.cuda()
  fake_input_v = Variable(input_discriminator_fake)
  labelv = Variable(self.label.fill_(fake_label))
  output = self.netD(fake_input_v)
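The generator input assembled in this hunk (per-sample noise concatenated with the one-hot pose encoding, reshaped to the (N, noise_dim + conditional_dim, 1, 1) layout the generator consumes) follows the same pattern; a loop-free sketch, with the batch size taken from the pose labels of the current batch:

```python
import torch

def build_generator_input(poses, noise_dim, conditional_dim):
    # poses: (N,) integer pose labels for the current (possibly smaller) batch
    n = poses.shape[0]
    noise = torch.randn(n, noise_dim)
    one_hot = torch.zeros(n, conditional_dim)
    one_hot[torch.arange(n), poses.long()] = 1
    # concatenate noise and conditioning, then reshape for the generator
    return torch.cat((noise, one_hot), dim=1).view(n, noise_dim + conditional_dim, 1, 1)
```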
@@ -190,20 +212,11 @@ class ConditionalGANTrainer(object):
      end = time.time()
      logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0], (end-start)))
  # save generated images at every epoch
- #input_generator_examples =
- #poses = range(self.conditional_dim)
- #for pose in pose:
- #  oh = numpy.zeros(self.conditional_dim)
- #  oh[pose] = 1
- #  one_hot = torch.FloatTensor(oh)
- #fake = self.netG(self.fixed_noise)
- #vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True)
+ fake_examples = self.netG(self.fixed_noise)
+ vutils.save_image(fake_examples.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True)
- ## do checkpointing
- #torch.save(self.netG.state_dict(), '%s/netG_epoch_%d.pth' % (output_dir, epoch))
- #torch.save(self.netD.state_dict(), '%s/netD_epoch_%d.pth' % (output_dir, epoch))
+ # do checkpointing
+ torch.save(self.netG.state_dict(), '%s/netG_epoch_%d.pth' % (output_dir, epoch))
+ torch.save(self.netD.state_dict(), '%s/netD_epoch_%d.pth' % (output_dir, epoch))
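The checkpoints written above are plain state dicts, so they can be reloaded later into a freshly constructed network with load_state_dict; a usage sketch, where the tiny nn.Sequential stands in for the actual generator architecture the trainer was built with:

```python
import torch
import torch.nn as nn

# stand-in generator: the real one is whatever network the trainer was constructed with
netG = nn.Sequential(nn.ConvTranspose2d(113, 64, 4), nn.Tanh())

# save and reload one epoch's checkpoint (same pattern as the trainer's torch.save calls)
torch.save(netG.state_dict(), 'netG_epoch_50.pth')
netG.load_state_dict(torch.load('netG_epoch_50.pth'))
netG.eval()  # switch to evaluation mode before sampling
```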