Skip to content
Snippets Groups Projects
Commit dffdf7c4 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainer] DR-GAN: function to get fixed data and loading models, fixed the...

[trainer] DR-GAN: function to get fixed data and loading models, fixed the images saved at each epoch (embedding the input image)
parent 2a2d9b20
No related branches found
No related tags found
No related merge requests found
......@@ -68,16 +68,6 @@ class DRGANTrainer(object):
self.number_of_ids = self.discriminator.number_of_ids
# fixed conditional noise - used to generate samples (one for each value of the conditional variable)
self.fixed_noise = torch.FloatTensor(self.conditional_dim, noise_dim, 1, 1).normal_(0, 1)
self.fixed_one_hot = torch.FloatTensor(self.conditional_dim, self.conditional_dim, 1, 1).zero_()
for k in range(self.conditional_dim):
self.fixed_one_hot[k, k] = 1
# TODO: figuring out the CPU/GPU thing - Guillaume HEUSCH, 17-11-2017
self.fixed_noise = Variable(self.fixed_noise)
self.fixed_one_hot = Variable(self.fixed_one_hot)
# binary cross-entropy loss
self.criterion_gan = nn.BCELoss()
self.criterion_pose = nn.CrossEntropyLoss() # index is expected as target (and not one-hot)
......@@ -92,6 +82,124 @@ class DRGANTrainer(object):
self.criterion_pose.cuda()
self.criterion_id.cuda()
def get_fixed_data(self, dataset, plot=False):
"""
Function to get fixed data, used to generate samples at every epoch.
**Parameters**
dataset: pyTorch Dataset
The dataset where the image should come from.
**Returns**
fixed_image: pyTorch Variable
The input image to encode/decode
fixed_noise: pyTorch Variable
The noise to inject in the decoder
"""
pose = 0
counter = 0
while pose != 6:
pose = dataset[counter]['pose']
fixed_index = counter
counter += 1
fixed_image = torch.Tensor(torch.Size(self.image_size))
fixed_image.copy_(dataset[counter]['image'])
# plot the image if asked for
if plot:
fixed_id = dataset[counter]['id']
fixed_pose = dataset[counter]['pose']
from matplotlib import pyplot
pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(((fixed_image.numpy()+1)/2.), 2),2))
pyplot.show()
# expand the fixed image, so that to have a batch for all possible poses
fixed_image = fixed_image.expand(self.conditional_dim, self.image_size[0], self.image_size[1], self.image_size[2])
fixed_image = Variable(fixed_image)
# the noise to sample from
fixed_noise = torch.FloatTensor(self.conditional_dim, self.noise_dim, 1, 1).normal_(0, 1)
fixed_noise = Variable(fixed_noise)
# build the set of one-hot encoded pose to sample from
fixed_one_hot = torch.FloatTensor(self.conditional_dim, self.conditional_dim, 1, 1).zero_()
for k in range(self.conditional_dim):
fixed_one_hot[k, k] = 1
fixed_one_hot = Variable(fixed_one_hot)
return fixed_image, fixed_noise, fixed_one_hot
def load_models(self, models_dir):
"""
Function to load existing models (last saved ones)
**Parameters**
model_dir: path
dir where existing models are stored.
**Returns**
start_epoch: int
"""
import os
def getKey(item):
"""
Return the time of last modification. Used to sort saved models
**Parameters**
item: file
The file from which you woud get the modification time
**Returns**
The last modification time of the file
"""
return os.path.getmtime(item)
# to store model files
discriminator_files = []
encoder_files = []
decoder_files = []
start_epoch = 0
# populate lists of models
for f in os.listdir(models_dir):
if 'discriminator' in f:
discriminator_files.append(os.path.join(models_dir, f))
if 'encoder' in f:
encoder_files.append(os.path.join(models_dir, f))
if 'decoder' in f:
decoder_files.append(os.path.join(models_dir, f))
# if some models already exists ...
if len(discriminator_files) > 0:
# sort according to modification time
discriminator_files.sort(key=getKey)
discriminator_file = discriminator_files[-1]
encoder_file = encoder_files[-1]
decoder_file = decoder_files[-1]
# load models
self.discriminator.load_state_dict(torch.load(discriminator_file, map_location=lambda storage, loc: storage))
self.encoder.load_state_dict(torch.load(encoder_file, map_location=lambda storage, loc: storage))
self.decoder.load_state_dict(torch.load(decoder_file, map_location=lambda storage, loc: storage))
# get the start epoch
start_epoch = int((discriminator_file.split("_")[-1]).split(".")[0]) + 1
logger.info("Models from epoch {} loaded ! Starting at epoch {}".format(start_epoch - 1, start_epoch))
else:
logger.info("No exisiting models were found in {}".format(models_dir))
return start_epoch
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out', plot=False, save_sample=10, pose_random=True):
"""
......@@ -131,7 +239,13 @@ class DRGANTrainer(object):
bob.io.base.create_directories_safe(models_dir)
bob.io.base.create_directories_safe(log_dir)
# be sure to save samples at each epoch at least
# check if some models already exists, in which case load the last ones
import os
start_epoch = 0
if os.path.isdir(models_dir):
start_epoch = self.load_models(models_dir)
# be sure to save samples at each epoch (at least)
if save_sample >= len(dataloader):
save_sample = len(dataloader) - 1
......@@ -144,56 +258,22 @@ class DRGANTrainer(object):
optimizerD = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(generator_params, lr=learning_rate, betas=(beta1, 0.999))
# ==================================================================================================
# TODO: fix this - Guillaume HEUSCH, 04-12-2017
pose = 0
counter = 0
while pose != 6:
pose = dataloader.dataset[counter]['pose']
fixed_index = counter
counter += 1
fixed_image = torch.Tensor(torch.Size(self.image_size))
fixed_image.copy_(dataloader.dataset[counter]['image'])
vutils.save_image(fixed_image, '%s/fixed_id.png' % (output_dir), normalize=True)
# plot the image if asked for
if plot:
fixed_id = dataloader.dataset[counter]['id']
fixed_pose = dataloader.dataset[counter]['pose']
from matplotlib import pyplot
pyplot.title("ID -> {}, pose {}".format(fixed_id, fixed_pose))
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(((fixed_image.numpy()+1)/2.), 2),2))
pyplot.show()
# expand the fixed image, so that to have a batch for all possible poses
fixed_image = fixed_image.expand(self.conditional_dim, self.image_size[0], self.image_size[1], self.image_size[2])
fixed_image = Variable(fixed_image)
# the noise to sample from
fixed_noise = torch.FloatTensor(self.conditional_dim, self.noise_dim, 1, 1).normal_(0, 1)
fixed_noise = Variable(fixed_noise)
# get fixed data to sample from after each epoch
fixed_image, fixed_noise, fixed_one_hot = self.get_fixed_data(dataloader.dataset, plot)
# build the set of one-hot encoded pose to sample from
fixed_one_hot = torch.FloatTensor(self.conditional_dim, self.conditional_dim, 1, 1).zero_()
for k in range(self.conditional_dim):
fixed_one_hot[k, k] = 1
fixed_one_hot = Variable(fixed_one_hot)
# ==================================================================================================
# statistics
# statistics - losses
discriminator_loss = []
generator_loss = []
# ================
# === LET'S GO ===
# ================
for epoch in range(n_epochs):
for epoch in range(start_epoch, n_epochs):
for i, data in enumerate(dataloader, 0):
start = time.time()
# get the data, pose and id labels
# get the batch data, pose and id labels
real_images = data['image']
poses = data['pose']
ids = data['id']
......@@ -305,19 +385,18 @@ class DRGANTrainer(object):
# perform optimization (i.e. update discriminator parameters)
errG = errG_id + errG_pose + errG_gan
optimizerG.step()
end = time.time()
# === END OF OPTIMIZATION ===
# log stuff
end = time.time()
logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0], (end-start)))
discriminator_loss.append(errD.data[0])
generator_loss.append(errG.data[0])
self.check_batch_statistics(output_real, output_fake, output_generated, ids, poses, random_poses, batch_size, log_dir)
# =====================
# SAVE GENERATED IMAGES
# =====================
# ==========
# SAVE STUFF
# ==========
if ((i % save_sample) == 0) and (i > 0):
# get a random example in this minibatch
......@@ -369,7 +448,10 @@ class DRGANTrainer(object):
fixed_encoded_id = self.encoder(fixed_image)
fake_examples = self.decoder(fixed_noise, fixed_one_hot, fixed_encoded_id)
vutils.save_image(fake_examples.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True)
to_save = torch.Tensor(torch.Size((self.conditional_dim + 1, self.image_size[0], self.image_size[1], self.image_size[2])))
to_save[0] = fixed_image.data[0]
to_save[1:] = fake_examples.data
vutils.save_image(to_save, '%s/fake_samples_epoch_%03d.png' % (images_dir, epoch), normalize=True)
if torch.cuda.is_available():
self.encoder = self.encoder.cuda()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment