Skip to content
Snippets Groups Projects
Commit 6515dc97 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[scripts] modified the training scripts (more options)

parent 885e51ed
Branches
Tags
No related merge requests found
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
Usage: Usage:
%(prog)s [--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>] %(prog)s [--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>]
[--batch-size=<int>] [--epochs=<int>] [--sample=<int>] [--light] [--batch-size=<int>] [--epochs=<int>] [--sample=<int>] [--light]
[--fixed-pose] [--keep-model=<int>] [--dropout]
[--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...] [--plot] [--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...] [--plot]
Options: Options:
...@@ -18,8 +19,11 @@ Options: ...@@ -18,8 +19,11 @@ Options:
-b, --batch-size=<int> The size of your mini-batch [default: 64] -b, --batch-size=<int> The size of your mini-batch [default: 64]
-e, --epochs=<int> The number of training epochs [default: 50] -e, --epochs=<int> The number of training epochs [default: 50]
-s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 200] -s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 200]
-L, --light Use a lighter architecture (similar as DCGAN)
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./drgan-mpie-casia/] -o, --output-dir=<path> Dir to save the logs, models and images [default: ./drgan-mpie-casia/]
-L, --light Use a lighter architecture (similar as DCGAN)
-d, --dropout Apply dropout
-k, --keep-model=<int> To only keep the saved model every at every X epochs (plus the latest one) [default: 10]
-F, --fixed-pose Try to generate the same pose than the current training example.
-g, --use-gpu Use the GPU -g, --use-gpu Use the GPU
-S, --seed=<int> The random seed [default: 3] -S, --seed=<int> The random seed [default: 3]
-v, --verbose Increase the verbosity (may appear multiple times). -v, --verbose Increase the verbosity (may appear multiple times).
...@@ -50,17 +54,11 @@ import bob.io.base ...@@ -50,17 +54,11 @@ import bob.io.base
# torch # torch
import torch import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
# data and architecture from the package # data and architecture from the package
from bob.learn.pytorch.datasets import MultiPIEDataset from bob.learn.pytorch.datasets import MultiPIEDataset
from bob.learn.pytorch.datasets import CasiaDataset from bob.learn.pytorch.datasets import CasiaDataset
#from torch.utils.data import ConcatDataset
from bob.learn.pytorch.datasets import ConcatDataset from bob.learn.pytorch.datasets import ConcatDataset
from bob.learn.pytorch.datasets import RollChannels from bob.learn.pytorch.datasets import RollChannels
...@@ -99,8 +97,14 @@ def main(user_input=None): ...@@ -99,8 +97,14 @@ def main(user_input=None):
seed = int(args['--seed']) seed = int(args['--seed'])
use_gpu = bool(args['--use-gpu']) use_gpu = bool(args['--use-gpu'])
plot = bool(args['--plot']) plot = bool(args['--plot'])
random_pose = not(bool(args['--fixed-pose']))
keep_model = int(args['--keep-model'])
dropout = bool(args['--dropout'])
if bool(args['--light']): if bool(args['--light']):
if dropout:
logger.error("The light architecture does not support dropout - so drop this option ;)")
sys.exit()
from bob.learn.pytorch.architectures import DRGAN_encoder as drgan_encoder from bob.learn.pytorch.architectures import DRGAN_encoder as drgan_encoder
from bob.learn.pytorch.architectures import DRGAN_decoder as drgan_decoder from bob.learn.pytorch.architectures import DRGAN_decoder as drgan_decoder
from bob.learn.pytorch.architectures import DRGAN_discriminator as drgan_discriminator from bob.learn.pytorch.architectures import DRGAN_discriminator as drgan_discriminator
...@@ -161,6 +165,8 @@ def main(user_input=None): ...@@ -161,6 +165,8 @@ def main(user_input=None):
# === NETWORK === # === NETWORK ===
# =============== # ===============
encoder = drgan_encoder(image_size, latent_dim) encoder = drgan_encoder(image_size, latent_dim)
if dropout:
encoder = drgan_encoder(image_size, latent_dim, dropout=dropout)
encoder.apply(weights_init) encoder.apply(weights_init)
logger.info("Encoder architecture: {}".format(encoder)) logger.info("Encoder architecture: {}".format(encoder))
...@@ -181,4 +187,4 @@ def main(user_input=None): ...@@ -181,4 +187,4 @@ def main(user_input=None):
# =============== # ===============
trainer = DRGANTrainer(encoder, decoder, discriminator, image_size, batch_size=batch_size, trainer = DRGANTrainer(encoder, decoder, discriminator, image_size, batch_size=batch_size,
noise_dim=noise_dim, conditional_dim=conditional_dim, latent_dim=latent_dim, use_gpu=use_gpu, verbosity_level=verbosity_level) noise_dim=noise_dim, conditional_dim=conditional_dim, latent_dim=latent_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, plot=plot, save_sample=sample) trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, plot=plot, save_sample=sample, pose_random=random_pose, keep_model=keep_model)
...@@ -150,9 +150,6 @@ def main(user_input=None): ...@@ -150,9 +150,6 @@ def main(user_input=None):
# =============== # ===============
# === NETWORK === # === NETWORK ===
# =============== # ===============
ngpu = 1 # usually we don't have more than one GPU
encoder = drgan_encoder(image_size, latent_dim) encoder = drgan_encoder(image_size, latent_dim)
if dropout: if dropout:
encoder = drgan_encoder(image_size, latent_dim, dropout=dropout) encoder = drgan_encoder(image_size, latent_dim, dropout=dropout)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment