Commit fe20611e authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[trainers] fixed CPU/GPU issues when moving variables, fixed bug in conditional maps construction

parent 5a205ee8
......@@ -76,7 +76,6 @@ class ConditionalGANTrainer(object):
input_generator_examples[k] = torch.cat((noise[k], one_hot[k]), 0)
self.fixed_noise = torch.FloatTensor(conditional_dim, noise_dim + conditional_dim, 1, 1)
self.fixed_noise.resize_(self.conditional_dim, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator_examples)
self.fixed_noise = Variable(self.fixed_noise)
# binary cross-entropy loss
self.criterion = nn.BCELoss()
......@@ -86,6 +85,11 @@ class ConditionalGANTrainer(object):
self.netD.cuda()
self.netG.cuda()
self.criterion.cuda()
self_fixed_noise = self.fixed_noise.cuda()
self.fixed_noise_v = Variable(self.fixed_noise)
if self.use_gpu:
self.fixed_noise_v = self.fixed_noise_v.cuda()
bob.core.log.set_verbosity_level(logger, verbosity_level)
......@@ -118,9 +122,6 @@ class ConditionalGANTrainer(object):
optimizerD = optim.Adam(self.netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(self.netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))
if self.use_gpu:
label = label.cuda()
self.conditional_noise, self.fixed_noise = self.conditional_noise.cuda(), self.fixed_noise.cuda()
for epoch in range(n_epochs):
......@@ -136,9 +137,12 @@ class ConditionalGANTrainer(object):
batch_size = len(real_images)
# create the Tensors with the right batch size
discriminator_input = torch.FloatTensor(batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
input_discriminator = torch.FloatTensor(batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
conditional_noise = torch.FloatTensor(batch_size, self.noise_dim + self.conditional_dim, 1, 1)
label = torch.FloatTensor(batch_size)
if self.use_gpu:
label = label.cuda()
conditional_noise = conditional_noise.cuda()
# =============
# DISCRIMINATOR
......@@ -150,11 +154,10 @@ class ConditionalGANTrainer(object):
# build the additional feature maps corresponding to the conditioning variable
cm = numpy.zeros((batch_size, self.conditional_dim, self.image_size[1], self.image_size[2]))
for k in range(batch_size):
cm[:, int(poses[k]), :, :] = 1
cm[k, int(poses[k]), :, :] = 1
conditional_maps = torch.FloatTensor(cm)
# append the conditional feature maps to the original images
input_discriminator = torch.FloatTensor(batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
for k in range(batch_size):
input_discriminator[k] = torch.cat((real_images[k], conditional_maps[k]), 0)
......@@ -162,9 +165,14 @@ class ConditionalGANTrainer(object):
input_discriminator = input_discriminator.cuda()
# train with real
#self.input.resize_as_(input_discriminator).copy_(input_discriminator)
label.resize_(batch_size).fill_(real_label)
inputv = Variable(discriminator_input)
inputv = Variable(input_discriminator)
#from matplotlib import pyplot
#pyplot.title("Pose {}".format(poses[0]))
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(first_image, 2),2))
#pyplot.show()
labelv = Variable(label)
output = self.netD(inputv)
errD_real = self.criterion(output, labelv)
......@@ -193,7 +201,7 @@ class ConditionalGANTrainer(object):
conditional_noise = conditional_noise.cuda()
# generate conditioned fake images
conditional_noise.resize_(self.batch_size, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator)
conditional_noise.resize_(batch_size, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator)
noisev = Variable(conditional_noise)
fake = self.netG(noisev)
......@@ -204,7 +212,7 @@ class ConditionalGANTrainer(object):
fake_images = fake.data
fake = fake.cuda()
input_discriminator_fake = torch.FloatTensor(self.batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
input_discriminator_fake = torch.FloatTensor(batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
for k in range(batch_size):
input_discriminator_fake[k] = torch.cat((fake_images[k], conditional_maps[k]), 0)
......@@ -218,6 +226,7 @@ class ConditionalGANTrainer(object):
errD_fake = self.criterion(output, labelv)
errD_fake.backward()
errD = errD_real + errD_fake
#print errD_fake.grad_fn.next_functions[0][0]
# perform optimization (i.e. update discriminator parameters)
optimizerD.step()
......@@ -236,7 +245,7 @@ class ConditionalGANTrainer(object):
logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0], (end-start)))
# save generated images at every epoch
fake_examples = self.netG(self.fixed_noise)
fake_examples = self.netG(self.fixed_noise_v)
vutils.save_image(fake_examples.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True)
# do checkpointing
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment