Skip to content
Snippets Groups Projects
Commit 7e76f93c authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[script] added script to sample from a DR-GAN

parent 01543403
Branches
Tags
No related merge requests found
#!/usr/bin/env python
# encoding: utf-8
""" Sample from a DR-GAN
Usage:
%(prog)s <input_image> <encoder> <decoder> [--target-pose=<int>] [--light]
[--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>]
[--output-dir=<path>][--verbose ...] [--plot]
Options:
-h, --help Show this screen.
-V, --version Show version.
-l, --light Use a lighter architecture that the original.
-p, --target-pose=<int> the target pose of the generated image. [default: 6]
-L, --latent-dim=<int> the dimension of the encoded ID [default: 320]
-n, --noise-dim=<int> the dimension of the noise [default: 50]
-c, --conditional-dim=<int> the dimension of the conditioning variable [default: 13]
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./samples/]
-v, --verbose Increase the verbosity (may appear multiple times).
-P, --plot Show the generated image.
Example:
To generate a sample of the provided input image with the target pose
$ %(prog)s <input_image> --target-pose 6 --epochs 25 --output-dir samples
See '%(prog)s --help' for more information.
"""
import os, sys
import pkg_resources
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
from docopt import docopt
version = pkg_resources.require('bob.learn.pytorch')[0].version
import numpy
import bob.io.base
import bob.io.image
# torch
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.architectures import weights_init
def main(user_input=None):
# Parse the command-line arguments
if user_input is not None:
arguments = user_input
else:
arguments = sys.argv[1:]
prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,)
# verbosity
verbosity_level = args['--verbose']
bob.core.log.set_verbosity_level(logger, verbosity_level)
# get the arguments
encoder_path = args['<encoder>']
decoder_path = args['<decoder>']
noise_dim = int(args['--noise-dim'])
latent_dim = int(args['--latent-dim'])
conditional_dim = int(args['--conditional-dim'])
output_dir = str(args['--output-dir'])
plot = bool(args['--plot'])
if bool(args['--light']):
from bob.learn.pytorch.architectures import DRGAN_encoder as drgan_encoder
from bob.learn.pytorch.architectures import DRGAN_decoder as drgan_decoder
else:
from bob.learn.pytorch.architectures import DRGANOriginal_encoder as drgan_encoder
from bob.learn.pytorch.architectures import DRGANOriginal_decoder as drgan_decoder
# process on the arguments / options
bob.io.base.create_directories_safe(output_dir)
# ============
# === DATA ===
# ============
input_image = bob.io.base.read(args['<input_image>'])
print input_image.shape
if bool(args['--plot']):
from matplotlib import pyplot
pyplot.title("Input Image")
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(input_image, 2),2))
pyplot.show()
# check if we have the right image size
if bool(args['--light']):
assert (input_image.shape == (3, 64, 64)), "Using the DRGAN light model, image size shoud be [3x64x64] (CxHxW)"
else:
assert input_image.shape == (3, 96, 96), "Using the DRGAN model, image size shoud be [3x96x96] (CxHxW)"
# ===============
# === NETWORK ===
# ===============
encoder = drgan_encoder(input_image.shape, latent_dim)
encoder.load_state_dict(torch.load(encoder_path, map_location=lambda storage, loc: storage))
decoder = drgan_decoder(input_image.shape, noise_dim, latent_dim, conditional_dim)
decoder.load_state_dict(torch.load(decoder_path, map_location=lambda storage, loc: storage))
# ================
# === GENERATE ===
# ================
# encode
input_image = numpy.rollaxis(numpy.rollaxis(input_image, 2),2)
to_tensor = transforms.ToTensor()
norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
input_image = to_tensor(input_image)
input_image = norm(input_image)
input_image = input_image.unsqueeze(0)
encoded_id = encoder.forward(Variable(input_image))
# decode
noise = torch.FloatTensor(1, noise_dim, 1, 1).normal_(0, 1)
one_hot_vector = torch.FloatTensor(1, conditional_dim, 1, 1).zero_()
one_hot_vector[0, int(args['--target-pose'])] = 1
generated = decoder(Variable(noise), Variable(one_hot_vector), encoded_id)
generated = generated.squeeze(0)
generated_image = (generated.data + 1)/2.
if bool(args['--plot']):
from matplotlib import pyplot
pyplot.title("Generated Image")
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(generated_image.numpy(), 2),2))
pyplot.show()
......@@ -81,6 +81,7 @@ setup(
'train_drgan_mpie_casia.py = bob.learn.pytorch.scripts.train_drgan_mpie_casia:main',
'show_training_images.py = bob.learn.pytorch.scripts.show_training_images:main',
'show_training_stats.py = bob.learn.pytorch.scripts.show_training_stats:main',
'sample_drgan.py = bob.learn.pytorch.scripts.sample_drgan:main',
],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment