diff --git a/bob/learn/pytorch/scripts/sample_drgan.py b/bob/learn/pytorch/scripts/sample_drgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..7953f81a9335d0752cef42a41756c9d564c53669
--- /dev/null
+++ b/bob/learn/pytorch/scripts/sample_drgan.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+
+""" Sample from a DR-GAN
+
+Usage:
+  %(prog)s <input_image> <encoder> <decoder> [--target-pose=<int>] [--light]
+           [--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>]
+           [--output-dir=<path>] [--verbose ...] [--plot]
+
+Options:
+  -h, --help                    Show this screen.
+  -V, --version                 Show version.
+  -l, --light                   Use a lighter architecture than the original.
+  -p, --target-pose=<int>       the target pose of the generated image. [default: 6]
+  -L, --latent-dim=<int>        the dimension of the encoded ID [default: 320]
+  -n, --noise-dim=<int>         the dimension of the noise [default: 50]
+  -c, --conditional-dim=<int>   the dimension of the conditioning variable [default: 13]
+  -o, --output-dir=<path>       Dir to save the logs, models and images [default: ./samples/]
+  -v, --verbose                 Increase the verbosity (may appear multiple times).
+  -P, --plot                    Show the generated image.
+
+Example:
+
+  To generate a sample of the provided input image with the target pose
+
+    $ %(prog)s <input_image> <encoder> <decoder> --target-pose 6 --output-dir samples
+
+See '%(prog)s --help' for more information.
+
+"""
+
+import os
+import sys
+import pkg_resources
+
+import bob.core
+logger = bob.core.log.setup("bob.learn.pytorch")
+
+from docopt import docopt
+
+version = pkg_resources.require('bob.learn.pytorch')[0].version
+
+import numpy
+import bob.io.base
+import bob.io.image
+
+# torch
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+import torchvision.utils as vutils
+from torch.autograd import Variable
+
+# data and architecture from the package
+from bob.learn.pytorch.datasets import RollChannels
+from bob.learn.pytorch.datasets import ToTensor
+from bob.learn.pytorch.datasets import Normalize
+
+from bob.learn.pytorch.architectures import weights_init
+
+
+def _to_hwc(image):
+    """Return a (H, W, C) view of a 3D (C, H, W) array."""
+    return numpy.transpose(image, (1, 2, 0))
+
+
+def _show(image, title):
+    """Display an image with matplotlib under the given title."""
+    # imported lazily so that matplotlib is only needed when plotting
+    from matplotlib import pyplot
+    pyplot.title(title)
+    pyplot.imshow(image)
+    pyplot.show()
+
+
+def main(user_input=None):
+    """Generate an image of the input identity under a target pose.
+
+    The input image is encoded with a trained DR-GAN encoder, and the
+    resulting code is decoded together with random noise and a one-hot pose
+    vector.
+
+    Parameters
+    ----------
+    user_input : list of str, optional
+        Command-line arguments; ``sys.argv[1:]`` is used when None.
+
+    Raises
+    ------
+    ValueError
+        If the target pose is outside ``[0, conditional_dim)`` or if the
+        image shape does not match the selected architecture.
+    """
+
+    # Parse the command-line arguments
+    arguments = sys.argv[1:] if user_input is None else user_input
+
+    prog = os.path.basename(sys.argv[0])
+    completions = dict(prog=prog, version=version)
+    args = docopt(__doc__ % completions, argv=arguments,
+                  version='Sample DR-GAN (%s)' % version)
+
+    # verbosity
+    verbosity_level = args['--verbose']
+    bob.core.log.set_verbosity_level(logger, verbosity_level)
+
+    # get the arguments
+    encoder_path = args['<encoder>']
+    decoder_path = args['<decoder>']
+
+    noise_dim = int(args['--noise-dim'])
+    latent_dim = int(args['--latent-dim'])
+    conditional_dim = int(args['--conditional-dim'])
+    target_pose = int(args['--target-pose'])
+    output_dir = str(args['--output-dir'])
+    plot = bool(args['--plot'])
+    light = bool(args['--light'])
+
+    # a negative pose would silently wrap around when indexing the one-hot
+    # vector, and a too large one would only fail after loading the models
+    if not 0 <= target_pose < conditional_dim:
+        raise ValueError("--target-pose should be in [0, %d), got %d"
+                         % (conditional_dim, target_pose))
+
+    if light:
+        from bob.learn.pytorch.architectures import DRGAN_encoder as drgan_encoder
+        from bob.learn.pytorch.architectures import DRGAN_decoder as drgan_decoder
+        expected_shape = (3, 64, 64)
+        shape_error = "Using the DRGAN light model, image size should be [3x64x64] (CxHxW)"
+    else:
+        from bob.learn.pytorch.architectures import DRGANOriginal_encoder as drgan_encoder
+        from bob.learn.pytorch.architectures import DRGANOriginal_decoder as drgan_decoder
+        expected_shape = (3, 96, 96)
+        shape_error = "Using the DRGAN model, image size should be [3x96x96] (CxHxW)"
+
+    # process on the arguments / options
+    bob.io.base.create_directories_safe(output_dir)
+
+    # ============
+    # === DATA ===
+    # ============
+    input_image = bob.io.base.read(args['<input_image>'])
+    print(input_image.shape)
+
+    # check the image size first: the plotting and encoding code below assume
+    # a 3D (C, H, W) array. An explicit exception is used instead of an
+    # assert so that the check survives "python -O".
+    if input_image.shape != expected_shape:
+        raise ValueError(shape_error)
+
+    if plot:
+        _show(_to_hwc(input_image), "Input Image")
+
+    # ===============
+    # === NETWORK ===
+    # ===============
+    # map_location returns each storage as deserialized (on the CPU), so the
+    # models are not moved back to the device they were saved from
+    encoder = drgan_encoder(input_image.shape, latent_dim)
+    encoder.load_state_dict(torch.load(encoder_path, map_location=lambda storage, loc: storage))
+
+    decoder = drgan_decoder(input_image.shape, noise_dim, latent_dim, conditional_dim)
+    decoder.load_state_dict(torch.load(decoder_path, map_location=lambda storage, loc: storage))
+
+    # NOTE(review): the networks are never switched to eval() mode - confirm
+    # whether this is intended if they contain batch-norm or dropout layers.
+
+    # ================
+    # === GENERATE ===
+    # ================
+
+    # encode
+    to_tensor = transforms.ToTensor()
+    norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    network_input = norm(to_tensor(_to_hwc(input_image))).unsqueeze(0)
+    encoded_id = encoder(Variable(network_input))
+
+    # decode
+    noise = torch.FloatTensor(1, noise_dim, 1, 1).normal_(0, 1)
+    one_hot_vector = torch.FloatTensor(1, conditional_dim, 1, 1).zero_()
+    one_hot_vector[0, target_pose] = 1
+    generated = decoder(Variable(noise), Variable(one_hot_vector), encoded_id)
+    generated = generated.squeeze(0)
+    # inverse of the (0.5, 0.5) normalization that was applied to the input
+    generated_image = (generated.data + 1) / 2.
+
+    # NOTE(review): the generated image is only displayed (with --plot) and
+    # nothing is written to output_dir - confirm whether it should be saved.
+
+    if plot:
+        _show(_to_hwc(generated_image.numpy()), "Generated Image")
diff --git a/setup.py b/setup.py
index 0086f66ae03614ccf50a1710dbb772ec4383476e..2495cab51fb4ffb5043984f6840af466affd65fe 100644
--- a/setup.py
+++ b/setup.py
@@ -81,6 +81,7 @@ setup(
         'train_drgan_mpie_casia.py = bob.learn.pytorch.scripts.train_drgan_mpie_casia:main',
         'show_training_images.py = bob.learn.pytorch.scripts.show_training_images:main',
         'show_training_stats.py = bob.learn.pytorch.scripts.show_training_stats:main',
+        'sample_drgan.py = bob.learn.pytorch.scripts.sample_drgan:main',
       ],