Commit 5a205ee8 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[trainer] commented the code, fixed a bug with real labels, and another one with the batch size

parent a0b02c1c
......@@ -56,11 +56,16 @@ class ConditionalGANTrainer(object):
self.conditional_dim = conditional_dim
self.use_gpu = use_gpu
self.input = torch.FloatTensor(batch_size, (image_size[0] + conditional_dim), image_size[1], image_size[2])
self.conditional_noise = torch.FloatTensor(batch_size, noise_dim + conditional_dim, 1, 1)
self.label = torch.FloatTensor(batch_size)
# real image + concatentation of conditonal feature maps
#self.input = torch.FloatTensor(batch_size, (image_size[0] + conditional_dim), image_size[1], image_size[2])
# noise, + one hot encoding of the conditional variable
#self.conditional_noise = torch.FloatTensor(batch_size, noise_dim + conditional_dim, 1, 1)
# real/fake labels
#self.label = torch.FloatTensor(batch_size)
# to generate samples
# fixed conditional noise - used to generate samples (one for each value of the conditional variable)
noise = torch.FloatTensor(self.conditional_dim, noise_dim).normal_(0, 1)
oh = numpy.zeros((self.conditional_dim, self.conditional_dim))
for pose in range(self.conditional_dim):
......@@ -69,19 +74,18 @@ class ConditionalGANTrainer(object):
input_generator_examples = torch.FloatTensor(self.conditional_dim, (self.noise_dim + self.conditional_dim))
for k in range(self.conditional_dim):
input_generator_examples[k] = torch.cat((noise[k], one_hot[k]), 0)
self.fixed_noise = torch.FloatTensor(conditional_dim, noise_dim + conditional_dim, 1, 1)
self.fixed_noise.resize_(self.conditional_dim, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator_examples)
self.fixed_noise = Variable(self.fixed_noise)
# binary cross-entropy loss
self.criterion = nn.BCELoss()
# move stuff to GPU if needed
if self.use_gpu:
self.netD.cuda()
self.netG.cuda()
self.criterion.cuda()
self.input, self.label = self.input.cuda(), self.label.cuda()
self.conditional_noise, self.fixed_noise = self.conditional_noise.cuda(), self.fixed_noise.cuda()
bob.core.log.set_verbosity_level(logger, verbosity_level)
......@@ -113,26 +117,36 @@ class ConditionalGANTrainer(object):
# setup optimizer
optimizerD = optim.Adam(self.netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(self.netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))
if self.use_gpu:
label = label.cuda()
self.conditional_noise, self.fixed_noise = self.conditional_noise.cuda(), self.fixed_noise.cuda()
for epoch in range(n_epochs):
for i, data in enumerate(dataloader, 0):
start = time.time()
# =============
# DISCRIMINATOR
# =============
# train with real
self.netD.zero_grad()
# get the data and pose labels
real_images = data['image']
poses = data['pose']
image_size = real_images[1].size()
# WARNING: the last batch could be smaller than the provided size
batch_size = len(real_images)
# create the Tensors with the right batch size
discriminator_input = torch.FloatTensor(batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
conditional_noise = torch.FloatTensor(batch_size, self.noise_dim + self.conditional_dim, 1, 1)
label = torch.FloatTensor(batch_size)
# =============
# DISCRIMINATOR
# =============
self.netD.zero_grad()
# === REAL DATA ===
# build the additional feature maps corresponding to the conditioning variable
cm = numpy.zeros((batch_size, self.conditional_dim, self.image_size[1], self.image_size[2]))
for k in range(batch_size):
......@@ -140,75 +154,82 @@ class ConditionalGANTrainer(object):
conditional_maps = torch.FloatTensor(cm)
# append the conditional feature maps to the original images
input_discriminator = torch.FloatTensor(self.batch_size, (image_size[0] + self.conditional_dim), image_size[1], image_size[2])
input_discriminator = torch.FloatTensor(batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
for k in range(batch_size):
input_discriminator[k] = torch.cat((real_images[k], conditional_maps[k]), 0)
if self.use_gpu:
input_discriminator = input_discriminator.cuda()
self.input.resize_as_(input_discriminator).copy_(input_discriminator)
inputv = Variable(self.input)
labelv = Variable(self.label)
# train with real
#self.input.resize_as_(input_discriminator).copy_(input_discriminator)
label.resize_(batch_size).fill_(real_label)
inputv = Variable(discriminator_input)
labelv = Variable(label)
output = self.netD(inputv)
errD_real = self.criterion(output, labelv)
errD_real.backward()
D_x = output.data.mean()
# train with fake
# === FAKE DATA ===
# get the noise
noise = torch.FloatTensor(batch_size, self.noise_dim)
noise.normal_(0, 1)
# generate the one hot pose encoding
# generate the one hot pose encoding vector
oh = numpy.zeros((batch_size, self.conditional_dim))
for k in range(batch_size):
oh[k, int(poses[k])] = 1
one_hot = torch.FloatTensor(oh)
# concatenate that with the noise
input_generator = torch.FloatTensor(self.batch_size, (self.noise_dim + self.conditional_dim))
# concatenate the one hot vector with the noise
input_generator = torch.FloatTensor(batch_size, (self.noise_dim + self.conditional_dim))
for k in range(batch_size):
input_generator[k] = torch.cat((noise[k], one_hot[k]), 0)
if self.use_gpu:
input_generator = input_generator.cuda()
self.conditional_noise.resize_(self.batch_size, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator)
noisev = Variable(self.conditional_noise)
input_generator = input_generator.cuda()
conditional_noise = conditional_noise.cuda()
# generate conditioned fake images
conditional_noise.resize_(self.batch_size, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator)
noisev = Variable(conditional_noise)
fake = self.netG(noisev)
# build conditional fakes
# build conditional fakes (i.e. generated images + conditional feature maps)
fake_images = fake.data
if self.use_gpu:
fake = fake.cpu()
fake_images = fake.data
fake = fake.cuda()
input_discriminator_fake = torch.FloatTensor(self.batch_size, (image_size[0] + self.conditional_dim), image_size[1], image_size[2])
input_discriminator_fake = torch.FloatTensor(self.batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
for k in range(batch_size):
input_discriminator_fake[k] = torch.cat((fake_images[k], conditional_maps[k]), 0)
if self.use_gpu:
input_discriminator_fake = input_discriminator_fake.cuda()
# train with fake
fake_input_v = Variable(input_discriminator_fake)
labelv = Variable(self.label.fill_(fake_label))
labelv = Variable(label.fill_(fake_label))
output = self.netD(fake_input_v)
errD_fake = self.criterion(output, labelv)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
# perform optimization (i.e. update discriminator parameters)
optimizerD.step()
# =========================================
# (2) Update G network: maximize log(D(G(z)))
# =========================================
# =========
# GENERATOR
# =========
self.netG.zero_grad()
labelv = Variable(self.label.fill_(real_label)) # fake labels are real for generator cost
labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost
output = self.netD(fake_input_v)
errG = self.criterion(output, labelv)
errG.backward()
D_G_z2 = output.data.mean()
optimizerG.step()
end = time.time()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment