Skip to content
Snippets Groups Projects
Commit 6515dc97 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[scripts] modified the training scripts (more options)

parent 885e51ed
Branches
Tags
No related merge requests found
......@@ -7,6 +7,7 @@
Usage:
%(prog)s [--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>]
[--batch-size=<int>] [--epochs=<int>] [--sample=<int>] [--light]
[--fixed-pose] [--keep-model=<int>] [--dropout]
[--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...] [--plot]
Options:
......@@ -18,8 +19,11 @@ Options:
-b, --batch-size=<int> The size of your mini-batch [default: 64]
-e, --epochs=<int> The number of training epochs [default: 50]
-s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 200]
-L, --light Use a lighter architecture (similar as DCGAN)
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./drgan-mpie-casia/]
-L, --light Use a lighter architecture (similar as DCGAN)
-d, --dropout Apply dropout
-k, --keep-model=<int> To only keep the saved model every at every X epochs (plus the latest one) [default: 10]
-F, --fixed-pose Try to generate the same pose than the current training example.
-g, --use-gpu Use the GPU
-S, --seed=<int> The random seed [default: 3]
-v, --verbose Increase the verbosity (may appear multiple times).
......@@ -50,17 +54,11 @@ import bob.io.base
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets import MultiPIEDataset
from bob.learn.pytorch.datasets import CasiaDataset
#from torch.utils.data import ConcatDataset
from bob.learn.pytorch.datasets import ConcatDataset
from bob.learn.pytorch.datasets import RollChannels
......@@ -99,8 +97,14 @@ def main(user_input=None):
seed = int(args['--seed'])
use_gpu = bool(args['--use-gpu'])
plot = bool(args['--plot'])
random_pose = not(bool(args['--fixed-pose']))
keep_model = int(args['--keep-model'])
dropout = bool(args['--dropout'])
if bool(args['--light']):
if dropout:
logger.error("The light architecture does not support dropout - so drop this option ;)")
sys.exit()
from bob.learn.pytorch.architectures import DRGAN_encoder as drgan_encoder
from bob.learn.pytorch.architectures import DRGAN_decoder as drgan_decoder
from bob.learn.pytorch.architectures import DRGAN_discriminator as drgan_discriminator
......@@ -161,6 +165,8 @@ def main(user_input=None):
# === NETWORK ===
# ===============
encoder = drgan_encoder(image_size, latent_dim)
if dropout:
encoder = drgan_encoder(image_size, latent_dim, dropout=dropout)
encoder.apply(weights_init)
logger.info("Encoder architecture: {}".format(encoder))
......@@ -181,4 +187,4 @@ def main(user_input=None):
# ===============
trainer = DRGANTrainer(encoder, decoder, discriminator, image_size, batch_size=batch_size,
noise_dim=noise_dim, conditional_dim=conditional_dim, latent_dim=latent_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, plot=plot, save_sample=sample)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, plot=plot, save_sample=sample, pose_random=random_pose, keep_model=keep_model)
......@@ -150,9 +150,6 @@ def main(user_input=None):
# ===============
# === NETWORK ===
# ===============
ngpu = 1 # usually we don't have more than one GPU
encoder = drgan_encoder(image_size, latent_dim)
if dropout:
encoder = drgan_encoder(image_size, latent_dim, dropout=dropout)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment