Skip to content
Snippets Groups Projects
Commit a84be26c authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainers] added some comments and debug code

parent 334b9616
Branches
Tags
No related merge requests found
......@@ -15,6 +15,7 @@ logger = bob.core.log.setup("bob.learn.pytorch")
import time
from matplotlib import pyplot
class ConditionalGANTrainer(object):
"""
......@@ -76,7 +77,9 @@ class ConditionalGANTrainer(object):
input_generator_examples[k] = torch.cat((noise[k], one_hot[k]), 0)
self.fixed_noise = torch.FloatTensor(conditional_dim, noise_dim + conditional_dim, 1, 1)
self.fixed_noise.resize_(self.conditional_dim, (self.noise_dim + self.conditional_dim), 1, 1).copy_(input_generator_examples)
self.fixed_noise_v = Variable(self.fixed_noise)
# binary cross-entropy loss
self.criterion = nn.BCELoss()
......@@ -86,9 +89,6 @@ class ConditionalGANTrainer(object):
self.netG.cuda()
self.criterion.cuda()
self_fixed_noise = self.fixed_noise.cuda()
self.fixed_noise_v = Variable(self.fixed_noise)
if self.use_gpu:
self.fixed_noise_v = self.fixed_noise_v.cuda()
bob.core.log.set_verbosity_level(logger, verbosity_level)
......@@ -122,8 +122,6 @@ class ConditionalGANTrainer(object):
optimizerD = optim.Adam(self.netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(self.netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))
for epoch in range(n_epochs):
for i, data in enumerate(dataloader, 0):
......@@ -174,10 +172,9 @@ class ConditionalGANTrainer(object):
#pyplot.show()
labelv = Variable(label)
output = self.netD(inputv)
errD_real = self.criterion(output, labelv)
output_real = self.netD(inputv)
errD_real = self.criterion(output_real, labelv)
errD_real.backward()
D_x = output.data.mean()
# === FAKE DATA ===
......@@ -212,6 +209,11 @@ class ConditionalGANTrainer(object):
fake_images = fake.data
fake = fake.cuda()
#from matplotlib import pyplot
#first_fake = (fake_images[0].numpy() + 1)/2.
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(first_fake, 2),2))
#pyplot.show()
input_discriminator_fake = torch.FloatTensor(batch_size, (self.image_size[0] + self.conditional_dim), self.image_size[1], self.image_size[2])
for k in range(batch_size):
input_discriminator_fake[k] = torch.cat((fake_images[k], conditional_maps[k]), 0)
......@@ -222,11 +224,10 @@ class ConditionalGANTrainer(object):
# train with fake
fake_input_v = Variable(input_discriminator_fake)
labelv = Variable(label.fill_(fake_label))
output = self.netD(fake_input_v)
errD_fake = self.criterion(output, labelv)
output_fake = self.netD(fake_input_v)
errD_fake = self.criterion(output_fake, labelv)
errD_fake.backward()
errD = errD_real + errD_fake
#print errD_fake.grad_fn.next_functions[0][0]
# perform optimization (i.e. update discriminator parameters)
optimizerD.step()
......
......@@ -124,7 +124,7 @@ class DCGANTrainer(object):
noisev = Variable(self.noise)
fake = self.netG(noisev)
labelv = Variable(self.label.fill_(fake_label))
output = self.netD(fake.detach())
output = self.netD(fake.detach()) # detach() -> done for speed, not correctness (PyTorch github's issue says so ...)
errD_fake = self.criterion(output, labelv)
errD_fake.backward()
D_G_z1 = output.data.mean()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment