Skip to content
Snippets Groups Projects
Commit 5a205ee8 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainer] commented the code, fixed a bug with real labels, and another one with the batch size

parent a0b02c1c
No related branches found
No related tags found
No related merge requests found
...@@ -56,11 +56,16 @@ class ConditionalGANTrainer(object): ...@@ -56,11 +56,16 @@ class ConditionalGANTrainer(object):
self.conditional_dim = conditional_dim self.conditional_dim = conditional_dim
self.use_gpu = use_gpu self.use_gpu = use_gpu
self.input = torch.FloatTensor(batch_size, (image_size[0] + conditional_dim), image_size[1], image_size[2]) # real image + concatentation of conditonal feature maps
self.conditional_noise = torch.FloatTensor(batch_size, noise_dim + conditional_dim, 1, 1) #self.input = torch.FloatTensor(batch_size, (image_size[0] + conditional_dim), image_size[1], image_size[2])
self.label = torch.FloatTensor(batch_size)
# noise, + one hot encoding of the conditional variable
#self.conditional_noise = torch.FloatTensor(batch_size, noise_dim + conditional_dim, 1, 1)
# real/fake labels
#self.label = torch.FloatTensor(batch_size)
# to generate samples # fixed conditional noise - used to generate samples (one for each value of the conditional variable)
noise = torch.FloatTensor(self.conditional_dim, noise_dim).normal_(0, 1) noise = torch.FloatTensor(self.conditional_dim, noise_dim).normal_(0, 1)
oh = numpy.zeros((self.conditional_dim, self.conditional_dim)) oh = numpy.zeros((self.conditional_dim, self.conditional_dim))
for pose in range(self.conditional_dim): for pose in range(self.conditional_dim):
...@@ -69,19 +74,18 @@ class ConditionalGANTrainer(object): ...@@ -69,19 +74,18 @@ class ConditionalGANTrainer(object):
input_generator_examples = torch.FloatTensor(self.conditional_dim, (self.noise_dim + self.conditional_dim)) input_generator_examples = torch.FloatTensor(self.conditional_dim, (self.noise_dim + self.conditional_dim))
for k in range(self.conditional_dim): for k in range(self.conditional_dim):
input_generator_examples[k] = torch.cat((noise[k], one_hot[k]), 0) input_generator_examples[k] = torch.cat((noise[k], one_hot[k]), 0)
self.fixed_noise = torch.FloatTensor(conditional_dim, noise_dim + conditional_dim, 1, 1) self.fixed_noise = torch.FloatTensor(conditional_dim, noise_dim + conditional_dim, 1, 1)
self.fixed_noise.resize_(self.conditional_dim, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator_examples) self.fixed_noise.resize_(self.conditional_dim, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator_examples)
self.fixed_noise = Variable(self.fixed_noise) self.fixed_noise = Variable(self.fixed_noise)
# binary cross-entropy loss
self.criterion = nn.BCELoss() self.criterion = nn.BCELoss()
# move stuff to GPU if needed
if self.use_gpu: if self.use_gpu:
self.netD.cuda() self.netD.cuda()
self.netG.cuda() self.netG.cuda()
self.criterion.cuda() self.criterion.cuda()
self.input, self.label = self.input.cuda(), self.label.cuda()
self.conditional_noise, self.fixed_noise = self.conditional_noise.cuda(), self.fixed_noise.cuda()
bob.core.log.set_verbosity_level(logger, verbosity_level) bob.core.log.set_verbosity_level(logger, verbosity_level)
...@@ -113,26 +117,36 @@ class ConditionalGANTrainer(object): ...@@ -113,26 +117,36 @@ class ConditionalGANTrainer(object):
# setup optimizer # setup optimizer
optimizerD = optim.Adam(self.netD.parameters(), lr=learning_rate, betas=(beta1, 0.999)) optimizerD = optim.Adam(self.netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(self.netG.parameters(), lr=learning_rate, betas=(beta1, 0.999)) optimizerG = optim.Adam(self.netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))
if self.use_gpu:
label = label.cuda()
self.conditional_noise, self.fixed_noise = self.conditional_noise.cuda(), self.fixed_noise.cuda()
for epoch in range(n_epochs): for epoch in range(n_epochs):
for i, data in enumerate(dataloader, 0): for i, data in enumerate(dataloader, 0):
start = time.time() start = time.time()
# ============= # get the data and pose labels
# DISCRIMINATOR
# =============
# train with real
self.netD.zero_grad()
real_images = data['image'] real_images = data['image']
poses = data['pose'] poses = data['pose']
image_size = real_images[1].size()
# WARNING: the last batch could be smaller than the provided size # WARNING: the last batch could be smaller than the provided size
batch_size = len(real_images) batch_size = len(real_images)
# create the Tensors with the right batch size
discriminator_input = torch.FloatTensor(batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
conditional_noise = torch.FloatTensor(batch_size, self.noise_dim + self.conditional_dim, 1, 1)
label = torch.FloatTensor(batch_size)
# =============
# DISCRIMINATOR
# =============
self.netD.zero_grad()
# === REAL DATA ===
# build the additional feature maps corresponding to the conditioning variable # build the additional feature maps corresponding to the conditioning variable
cm = numpy.zeros((batch_size, self.conditional_dim, self.image_size[1], self.image_size[2])) cm = numpy.zeros((batch_size, self.conditional_dim, self.image_size[1], self.image_size[2]))
for k in range(batch_size): for k in range(batch_size):
...@@ -140,75 +154,82 @@ class ConditionalGANTrainer(object): ...@@ -140,75 +154,82 @@ class ConditionalGANTrainer(object):
conditional_maps = torch.FloatTensor(cm) conditional_maps = torch.FloatTensor(cm)
# append the conditional feature maps to the original images # append the conditional feature maps to the original images
input_discriminator = torch.FloatTensor(self.batch_size, (image_size[0] + self.conditional_dim), image_size[1], image_size[2]) input_discriminator = torch.FloatTensor(batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
for k in range(batch_size): for k in range(batch_size):
input_discriminator[k] = torch.cat((real_images[k], conditional_maps[k]), 0) input_discriminator[k] = torch.cat((real_images[k], conditional_maps[k]), 0)
if self.use_gpu: if self.use_gpu:
input_discriminator = input_discriminator.cuda() input_discriminator = input_discriminator.cuda()
self.input.resize_as_(input_discriminator).copy_(input_discriminator) # train with real
inputv = Variable(self.input) #self.input.resize_as_(input_discriminator).copy_(input_discriminator)
labelv = Variable(self.label) label.resize_(batch_size).fill_(real_label)
inputv = Variable(discriminator_input)
labelv = Variable(label)
output = self.netD(inputv) output = self.netD(inputv)
errD_real = self.criterion(output, labelv) errD_real = self.criterion(output, labelv)
errD_real.backward() errD_real.backward()
D_x = output.data.mean() D_x = output.data.mean()
# train with fake # === FAKE DATA ===
# get the noise
noise = torch.FloatTensor(batch_size, self.noise_dim) noise = torch.FloatTensor(batch_size, self.noise_dim)
noise.normal_(0, 1) noise.normal_(0, 1)
# generate the one hot pose encoding # generate the one hot pose encoding vector
oh = numpy.zeros((batch_size, self.conditional_dim)) oh = numpy.zeros((batch_size, self.conditional_dim))
for k in range(batch_size): for k in range(batch_size):
oh[k, int(poses[k])] = 1 oh[k, int(poses[k])] = 1
one_hot = torch.FloatTensor(oh) one_hot = torch.FloatTensor(oh)
# concatenate that with the noise # concatenate the one hot vector with the noise
input_generator = torch.FloatTensor(self.batch_size, (self.noise_dim + self.conditional_dim)) input_generator = torch.FloatTensor(batch_size, (self.noise_dim + self.conditional_dim))
for k in range(batch_size): for k in range(batch_size):
input_generator[k] = torch.cat((noise[k], one_hot[k]), 0) input_generator[k] = torch.cat((noise[k], one_hot[k]), 0)
if self.use_gpu: if self.use_gpu:
input_generator = input_generator.cuda() input_generator = input_generator.cuda()
conditional_noise = conditional_noise.cuda()
self.conditional_noise.resize_(self.batch_size, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator)
noisev = Variable(self.conditional_noise) # generate conditioned fake images
conditional_noise.resize_(self.batch_size, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator)
noisev = Variable(conditional_noise)
fake = self.netG(noisev) fake = self.netG(noisev)
# build conditional fakes # build conditional fakes (i.e. generated images + conditional feature maps)
fake_images = fake.data fake_images = fake.data
if self.use_gpu: if self.use_gpu:
fake = fake.cpu() fake = fake.cpu()
fake_images = fake.data fake_images = fake.data
fake = fake.cuda() fake = fake.cuda()
input_discriminator_fake = torch.FloatTensor(self.batch_size, (image_size[0] + self.conditional_dim), image_size[1], image_size[2]) input_discriminator_fake = torch.FloatTensor(self.batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
for k in range(batch_size): for k in range(batch_size):
input_discriminator_fake[k] = torch.cat((fake_images[k], conditional_maps[k]), 0) input_discriminator_fake[k] = torch.cat((fake_images[k], conditional_maps[k]), 0)
if self.use_gpu: if self.use_gpu:
input_discriminator_fake = input_discriminator_fake.cuda() input_discriminator_fake = input_discriminator_fake.cuda()
# train with fake
fake_input_v = Variable(input_discriminator_fake) fake_input_v = Variable(input_discriminator_fake)
labelv = Variable(self.label.fill_(fake_label)) labelv = Variable(label.fill_(fake_label))
output = self.netD(fake_input_v) output = self.netD(fake_input_v)
errD_fake = self.criterion(output, labelv) errD_fake = self.criterion(output, labelv)
errD_fake.backward() errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake errD = errD_real + errD_fake
# perform optimization (i.e. update discriminator parameters)
optimizerD.step() optimizerD.step()
# ========================================= # =========
# (2) Update G network: maximize log(D(G(z))) # GENERATOR
# ========================================= # =========
self.netG.zero_grad() self.netG.zero_grad()
labelv = Variable(self.label.fill_(real_label)) # fake labels are real for generator cost labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost
output = self.netD(fake_input_v) output = self.netD(fake_input_v)
errG = self.criterion(output, labelv) errG = self.criterion(output, labelv)
errG.backward() errG.backward()
D_G_z2 = output.data.mean()
optimizerG.step() optimizerG.step()
end = time.time() end = time.time()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment