From 7e76f93c9bde7e61916cec4df52e47b19f5d75d0 Mon Sep 17 00:00:00 2001
From: Guillaume HEUSCH <guillaume.heusch@idiap.ch>
Date: Wed, 6 Dec 2017 16:30:38 +0100
Subject: [PATCH] [script] added script to sample from a DR-GAN

---
 bob/learn/pytorch/scripts/sample_drgan.py | 154 ++++++++++++++++++++++
 setup.py                                  |   1 +
 2 files changed, 155 insertions(+)
 create mode 100644 bob/learn/pytorch/scripts/sample_drgan.py

diff --git a/bob/learn/pytorch/scripts/sample_drgan.py b/bob/learn/pytorch/scripts/sample_drgan.py
new file mode 100644
index 0000000..7953f81
--- /dev/null
+++ b/bob/learn/pytorch/scripts/sample_drgan.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+
+""" Sample from a DR-GAN 
+
+Usage:
+  %(prog)s <input_image> <encoder> <decoder> [--target-pose=<int>] [--light] 
+           [--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>] 
+           [--output-dir=<path>][--verbose ...] [--plot]
+
+Options:
+  -h, --help                    Show this screen.
+  -V, --version                 Show version.
+  -l, --light                   Use a lighter architecture that the original. 
+  -p, --target-pose=<int>       the target pose of the generated image. [default: 6]
+  -L, --latent-dim=<int>        the dimension of the encoded ID [default: 320]
+  -n, --noise-dim=<int>         the dimension of the noise [default: 50]
+  -c, --conditional-dim=<int>   the dimension of the conditioning variable [default: 13]
+  -o, --output-dir=<path>       Dir to save the logs, models and images [default: ./samples/] 
+  -v, --verbose                 Increase the verbosity (may appear multiple times).
+  -P, --plot                    Show the generated image. 
+
+Example:
+
+  To generate a sample of the provided input image with the target pose 
+
+    $ %(prog)s <input_image> --target-pose 6 --epochs 25 --output-dir samples 
+
+See '%(prog)s --help' for more information.
+
+"""
+
+import os, sys
+import pkg_resources
+
+import bob.core
+logger = bob.core.log.setup("bob.learn.pytorch")
+
+from docopt import docopt
+
+version = pkg_resources.require('bob.learn.pytorch')[0].version
+
+import numpy
+import bob.io.base
+import bob.io.image
+
+# torch
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+import torchvision.utils as vutils
+from torch.autograd import Variable
+
+# data and architecture from the package
+from bob.learn.pytorch.datasets import RollChannels
+from bob.learn.pytorch.datasets import ToTensor
+from bob.learn.pytorch.datasets import Normalize
+
+from bob.learn.pytorch.architectures import weights_init
+
+
+def main(user_input=None):
+  
+  # Parse the command-line arguments
+  if user_input is not None:
+      arguments = user_input
+  else:
+      arguments = sys.argv[1:]
+
+  prog = os.path.basename(sys.argv[0])
+  completions = dict(prog=prog, version=version,)
+  args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,)
+
+  # verbosity
+  verbosity_level = args['--verbose']
+  bob.core.log.set_verbosity_level(logger, verbosity_level)
+
+  # get the arguments
+  encoder_path = args['<encoder>']
+  decoder_path = args['<decoder>']
+  
+  noise_dim = int(args['--noise-dim'])
+  latent_dim = int(args['--latent-dim'])
+  conditional_dim = int(args['--conditional-dim'])
+  output_dir = str(args['--output-dir'])
+  plot = bool(args['--plot'])
+
+  if bool(args['--light']):
+    from bob.learn.pytorch.architectures import DRGAN_encoder as drgan_encoder
+    from bob.learn.pytorch.architectures import DRGAN_decoder as drgan_decoder
+  else:
+    from bob.learn.pytorch.architectures import DRGANOriginal_encoder as drgan_encoder
+    from bob.learn.pytorch.architectures import DRGANOriginal_decoder as drgan_decoder
+
+  # process on the arguments / options
+  bob.io.base.create_directories_safe(output_dir)
+
+  # ============
+  # === DATA ===
+  # ============
+  input_image = bob.io.base.read(args['<input_image>'])
+  print input_image.shape
+
+  if bool(args['--plot']):
+    from matplotlib import pyplot
+    pyplot.title("Input Image")
+    pyplot.imshow(numpy.rollaxis(numpy.rollaxis(input_image, 2),2))
+    pyplot.show() 
+
+  # check if we have the right image size
+  if bool(args['--light']):
+    assert (input_image.shape == (3, 64, 64)), "Using the DRGAN light model, image size shoud be [3x64x64] (CxHxW)"
+  else:
+    assert input_image.shape == (3, 96, 96), "Using the DRGAN model, image size shoud be [3x96x96] (CxHxW)"
+
+
+  # ===============
+  # === NETWORK ===
+  # =============== 
+  encoder = drgan_encoder(input_image.shape, latent_dim)
+  encoder.load_state_dict(torch.load(encoder_path, map_location=lambda storage, loc: storage)) 
+
+  decoder = drgan_decoder(input_image.shape, noise_dim, latent_dim, conditional_dim)
+  decoder.load_state_dict(torch.load(decoder_path, map_location=lambda storage, loc: storage)) 
+
+  # ================
+  # === GENERATE ===
+  # ================
+  
+  # encode
+  input_image = numpy.rollaxis(numpy.rollaxis(input_image, 2),2)
+  to_tensor = transforms.ToTensor()
+  norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+  input_image = to_tensor(input_image)
+  input_image = norm(input_image)
+  input_image = input_image.unsqueeze(0)
+  encoded_id = encoder.forward(Variable(input_image))
+
+  # decode
+  noise = torch.FloatTensor(1, noise_dim, 1, 1).normal_(0, 1)
+  one_hot_vector = torch.FloatTensor(1, conditional_dim, 1, 1).zero_()
+  one_hot_vector[0, int(args['--target-pose'])] = 1
+  generated = decoder(Variable(noise), Variable(one_hot_vector), encoded_id)
+  generated = generated.squeeze(0)
+  generated_image = (generated.data + 1)/2.
+
+  if bool(args['--plot']):
+    from matplotlib import pyplot
+    pyplot.title("Generated Image")
+    pyplot.imshow(numpy.rollaxis(numpy.rollaxis(generated_image.numpy(), 2),2))
+    pyplot.show()
+
+
diff --git a/setup.py b/setup.py
index 0086f66..2495cab 100644
--- a/setup.py
+++ b/setup.py
@@ -81,6 +81,7 @@ setup(
         'train_drgan_mpie_casia.py = bob.learn.pytorch.scripts.train_drgan_mpie_casia:main', 
         'show_training_images.py = bob.learn.pytorch.scripts.show_training_images:main', 
         'show_training_stats.py = bob.learn.pytorch.scripts.show_training_stats:main', 
+        'sample_drgan.py = bob.learn.pytorch.scripts.sample_drgan:main', 
       ],
 
 
-- 
GitLab