Skip to content
Snippets Groups Projects
Commit 9a3839cb authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[script] sample: remove unnecessary imports, added some verbosity and the image saving

parent 85be403f
No related branches found
No related tags found
No related merge requests found
...@@ -7,17 +7,17 @@ ...@@ -7,17 +7,17 @@
Usage: Usage:
%(prog)s <input_image> <encoder> <decoder> [--target-pose=<int>] [--light] %(prog)s <input_image> <encoder> <decoder> [--target-pose=<int>] [--light]
[--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>] [--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>]
[--output-dir=<path>][--verbose ...] [--plot] [--output-file=<path>][--verbose ...] [--plot]
Options: Options:
-h, --help Show this screen. -h, --help Show this screen.
-V, --version Show version. -V, --version Show version.
 -l, --light Use a lighter architecture than the original.
-p, --target-pose=<int> the target pose of the generated image. [default: 6] -p, --target-pose=<int> the target pose of the generated image. [default: 6]
 -l, --light Use a lighter architecture than the original.
-L, --latent-dim=<int> the dimension of the encoded ID [default: 320] -L, --latent-dim=<int> the dimension of the encoded ID [default: 320]
-n, --noise-dim=<int> the dimension of the noise [default: 50] -n, --noise-dim=<int> the dimension of the noise [default: 100]
-c, --conditional-dim=<int> the dimension of the conditioning variable [default: 13] -c, --conditional-dim=<int> the dimension of the conditioning variable [default: 13]
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./samples/] -o, --output-file=<path> Filename of the sampled image [default: ./sampled.png]
-v, --verbose Increase the verbosity (may appear multiple times). -v, --verbose Increase the verbosity (may appear multiple times).
-P, --plot Show the generated image. -P, --plot Show the generated image.
...@@ -25,7 +25,11 @@ Example: ...@@ -25,7 +25,11 @@ Example:
To generate a sample of the provided input image with the target pose To generate a sample of the provided input image with the target pose
$ %(prog)s <input_image> --target-pose 6 --epochs 25 --output-dir samples $ %(prog)s <input_image> path/to/encoder/ path/to/decoder --target-pose 6 --output-file sampled.png
 Note that the encoder and decoder must be PyTorch models
See '%(prog)s --help' for more information. See '%(prog)s --help' for more information.
...@@ -49,16 +53,8 @@ import bob.io.image ...@@ -49,16 +53,8 @@ import bob.io.image
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.architectures import weights_init
def main(user_input=None): def main(user_input=None):
...@@ -70,7 +66,7 @@ def main(user_input=None): ...@@ -70,7 +66,7 @@ def main(user_input=None):
prog = os.path.basename(sys.argv[0]) prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,) completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,) args = docopt(__doc__ % completions,argv=arguments,version='Sample from a DR-GAN (%s)' % version,)
# verbosity # verbosity
verbosity_level = args['--verbose'] verbosity_level = args['--verbose']
...@@ -79,11 +75,10 @@ def main(user_input=None): ...@@ -79,11 +75,10 @@ def main(user_input=None):
# get the arguments # get the arguments
encoder_path = args['<encoder>'] encoder_path = args['<encoder>']
decoder_path = args['<decoder>'] decoder_path = args['<decoder>']
noise_dim = int(args['--noise-dim']) noise_dim = int(args['--noise-dim'])
latent_dim = int(args['--latent-dim']) latent_dim = int(args['--latent-dim'])
conditional_dim = int(args['--conditional-dim']) conditional_dim = int(args['--conditional-dim'])
output_dir = str(args['--output-dir']) output_file = str(args['--output-file'])
plot = bool(args['--plot']) plot = bool(args['--plot'])
if bool(args['--light']): if bool(args['--light']):
...@@ -93,14 +88,14 @@ def main(user_input=None): ...@@ -93,14 +88,14 @@ def main(user_input=None):
from bob.learn.pytorch.architectures import DRGANOriginal_encoder as drgan_encoder from bob.learn.pytorch.architectures import DRGANOriginal_encoder as drgan_encoder
from bob.learn.pytorch.architectures import DRGANOriginal_decoder as drgan_decoder from bob.learn.pytorch.architectures import DRGANOriginal_decoder as drgan_decoder
# process on the arguments / options dirname = os.path.dirname(output_file)
bob.io.base.create_directories_safe(output_dir) bob.io.base.create_directories_safe(dirname)
# ============ # ============
# === DATA === # === DATA ===
# ============ # ============
input_image = bob.io.base.read(args['<input_image>']) input_image = bob.io.base.read(args['<input_image>'])
print input_image.shape logger.info("Processing image: {}".format(args['<input_image>']))
if bool(args['--plot']): if bool(args['--plot']):
from matplotlib import pyplot from matplotlib import pyplot
...@@ -114,20 +109,22 @@ def main(user_input=None): ...@@ -114,20 +109,22 @@ def main(user_input=None):
else: else:
 assert input_image.shape == (3, 96, 96), "Using the DRGAN model, image size should be [3x96x96] (CxHxW)" assert input_image.shape == (3, 96, 96), "Using the DRGAN model, image size should be [3x96x96] (CxHxW)"
# =============== # ===============
# === NETWORK === # === NETWORK ===
# =============== # ===============
encoder = drgan_encoder(input_image.shape, latent_dim) encoder = drgan_encoder(input_image.shape, latent_dim)
encoder.load_state_dict(torch.load(encoder_path, map_location=lambda storage, loc: storage)) encoder.load_state_dict(torch.load(encoder_path, map_location=lambda storage, loc: storage))
logger.info("encoder: {}".format(encoder_path))
decoder = drgan_decoder(input_image.shape, noise_dim, latent_dim, conditional_dim) decoder = drgan_decoder(input_image.shape, noise_dim, latent_dim, conditional_dim)
decoder.load_state_dict(torch.load(decoder_path, map_location=lambda storage, loc: storage)) decoder.load_state_dict(torch.load(decoder_path, map_location=lambda storage, loc: storage))
logger.info("decoder: {}".format(decoder_path))
# ================ # ================
# === GENERATE === # === GENERATE ===
# ================ # ================
logger.info("Generating image with target pose {}".format(args['--target-pose']))
# encode # encode
input_image = numpy.rollaxis(numpy.rollaxis(input_image, 2),2) input_image = numpy.rollaxis(numpy.rollaxis(input_image, 2),2)
to_tensor = transforms.ToTensor() to_tensor = transforms.ToTensor()
...@@ -152,3 +149,7 @@ def main(user_input=None): ...@@ -152,3 +149,7 @@ def main(user_input=None):
pyplot.show() pyplot.show()
# save sampled image
logger.info("Saving image as {}".format(output_file))
bob.io.base.save((generated_image.numpy()*255).astype('uint8'), output_file)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment