Skip to content
Snippets Groups Projects
Commit 9600fea9 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainers] made sure that the fixed image used for sampling is frontal, save...

[trainers] made sure that the fixed image used for sampling is frontal, save generated samples during training
parent 9a562679
Branches
Tags
No related merge requests found
......@@ -117,9 +117,27 @@ class DRGANTrainer(object):
optimizerD = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(generator_params, lr=learning_rate, betas=(beta1, 0.999))
# get fixed images and noise for sampling
fixed_image = dataloader.dataset[0]['image']
# get fixed image, fixed noise and conditional pose for sampling
print "number of images = {}".format(len(dataloader.dataset))
pose = 0
counter = 0
while pose != 6:
pose = dataloader.dataset[counter]['pose']
fixed_index = counter
counter += 1
fixed_image = dataloader.dataset[counter]['image']
vutils.save_image(fixed_image, '%s/fixed_id.png' % (output_dir), normalize=True)
#fixed_id = dataloader.dataset[counter]['id']
#fixed_pose = dataloader.dataset[counter]['pose']
#from matplotlib import pyplot
#pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fixed_image.numpy(), 2),2))
#pyplot.show()
fixed_image = fixed_image.expand(self.conditional_dim, self.image_size[0], self.image_size[1], self.image_size[2])
fixed_image = Variable(fixed_image)
fixed_noise = torch.FloatTensor(self.conditional_dim, self.noise_dim, 1, 1).normal_(0, 1)
fixed_noise = Variable(fixed_noise)
......@@ -129,14 +147,6 @@ class DRGANTrainer(object):
fixed_one_hot[k, k] = 1
fixed_one_hot = Variable(fixed_one_hot)
fixed_id = dataloader.dataset[0]['id']
fixed_pose = dataloader.dataset[0]['pose']
#from matplotlib import pyplot
#pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(fixed_image, 2),2))
#pyplot.show()
number_of_ids = self.discriminator.number_of_ids
......@@ -206,6 +216,9 @@ class DRGANTrainer(object):
# encode the identity
encoded_ids = self.encoder(imagev)
fake = self.decoder(noisev, one_hot_vv, encoded_ids)
if (i % 10) == 0:
vutils.save_image(fake.data, '%s/generated_images_epoch_%03d_minibatch_%03d.png' % (output_dir, epoch, i), normalize=True)
#from matplotlib import pyplot
#for k in range(batch_size):
# pyplot.title("ID -> {}, pose {}".format(ids[k], poses[k]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment