Skip to content
Snippets Groups Projects
Commit 9600fea9 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainers] made sure that the fixed image used for sampling is frontal, save...

[trainers] made sure that the fixed image used for sampling is frontal, save generated samples during training
parent 9a562679
No related branches found
No related tags found
No related merge requests found
......@@ -117,9 +117,27 @@ class DRGANTrainer(object):
optimizerD = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(generator_params, lr=learning_rate, betas=(beta1, 0.999))
# get fixed images and noise for sampling
fixed_image = dataloader.dataset[0]['image']
# get fixed image, fixed noise and conditional pose for sampling
print "number of images = {}".format(len(dataloader.dataset))
pose = 0
counter = 0
while pose != 6:
pose = dataloader.dataset[counter]['pose']
fixed_index = counter
counter += 1
fixed_image = dataloader.dataset[counter]['image']
vutils.save_image(fixed_image, '%s/fixed_id.png' % (output_dir), normalize=True)
#fixed_id = dataloader.dataset[counter]['id']
#fixed_pose = dataloader.dataset[counter]['pose']
#from matplotlib import pyplot
#pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fixed_image.numpy(), 2),2))
#pyplot.show()
fixed_image = fixed_image.expand(self.conditional_dim, self.image_size[0], self.image_size[1], self.image_size[2])
fixed_image = Variable(fixed_image)
fixed_noise = torch.FloatTensor(self.conditional_dim, self.noise_dim, 1, 1).normal_(0, 1)
fixed_noise = Variable(fixed_noise)
......@@ -129,14 +147,6 @@ class DRGANTrainer(object):
fixed_one_hot[k, k] = 1
fixed_one_hot = Variable(fixed_one_hot)
fixed_id = dataloader.dataset[0]['id']
fixed_pose = dataloader.dataset[0]['pose']
#from matplotlib import pyplot
#pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fixed_image, 2),2))
#pyplot.show()
number_of_ids = self.discriminator.number_of_ids
......@@ -206,6 +216,9 @@ class DRGANTrainer(object):
# encode the identity
encoded_ids = self.encoder(imagev)
fake = self.decoder(noisev, one_hot_vv, encoded_ids)
if (i % 10) == 0:
vutils.save_image(fake.data, '%s/generated_images_epoch_%03d_minibatch_%03d.png' % (output_dir, epoch, i), normalize=True)
#from matplotlib import pyplot
#for k in range(batch_size):
# pyplot.title("ID -> {}, pose {}".format(ids[k], poses[k]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment