Commit e6a05829 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[script] fixed script to train conditional GAN, added it in setup.py

parent df02a9a3
Pipeline #22457 passed with stage
in 10 minutes and 59 seconds
#!/usr/bin/env python #!/usr/bin/env python
# encoding: utf-8 # encoding: utf-8
#!/usr/bin/env python
# encoding: utf-8
""" Train a Conditional GAN """ Train a Conditional GAN
Usage: Usage:
%(prog)s [--noise-dim=<int>] [--conditional-dim=<int>] %(prog)s <configuration>
[--noise-dim=<int>] [--conditional-dim=<int>]
[--batch-size=<int>] [--epochs=<int>] [--sample=<int>] [--batch-size=<int>] [--epochs=<int>] [--sample=<int>]
[--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...] [--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...]
Arguments:
<configuration> A configuration file, defining the dataset and the network
Options: Options:
-h, --help Show this screen. -h, --help Show this screen.
-V, --version Show version. -V, --version Show version.
...@@ -16,17 +21,19 @@ Options: ...@@ -16,17 +21,19 @@ Options:
-c, --conditional-dim=<int> The dimension of the conditional variable [default: 13] -c, --conditional-dim=<int> The dimension of the conditional variable [default: 13]
-b, --batch-size=<int> The size of your mini-batch [default: 64] -b, --batch-size=<int> The size of your mini-batch [default: 64]
-e, --epochs=<int> The number of training epochs [default: 100] -e, --epochs=<int> The number of training epochs [default: 100]
-s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 100000000000] -s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 1e10]
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./cgan-multipie/] -o, --output-dir=<path> Dir to save the logs, models and images [default: ./conditionalgan/]
-g, --use-gpu Use the GPU -g, --use-gpu Use the GPU
-S, --seed=<int> The random seed [default: 3] -S, --seed=<int> The random seed [default: 3]
-v, --verbose Increase the verbosity (may appear multiple times). -v, --verbose Increase the verbosity (may appear multiple times).
Note that arguments provided directly by command-line will override the ones in the configuration file.
Example: Example:
To run the training process To run the training process
$ %(prog)s --batch-size 64 --epochs 25 --output-dir drgan $ %(prog)s config.py
See '%(prog)s --help' for more information. See '%(prog)s --help' for more information.
...@@ -35,35 +42,19 @@ See '%(prog)s --help' for more information. ...@@ -35,35 +42,19 @@ See '%(prog)s --help' for more information.
import os, sys import os, sys
import pkg_resources import pkg_resources
import bob.core import torch
logger = bob.core.log.setup("bob.learn.pytorch") import numpy
from docopt import docopt from docopt import docopt
version = pkg_resources.require('bob.learn.pytorch')[0].version import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
import numpy
import bob.io.base import bob.io.base
from bob.extension.config import load
from bob.learn.pytorch.trainers import ConditionalGANTrainer
from bob.learn.pytorch.utils import get_parameter
# torch version = pkg_resources.require('bob.learn.pytorch')[0].version
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets import MultiPIEDataset
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.architectures import weights_init
from bob.learn.pytorch.architectures import ConditionalGAN_generator as cgenerator
from bob.learn.pytorch.architectures import ConditionalGAN_discriminator as cdiscriminator
from bob.learn.pytorch.trainers import ConditionalGANTrainer as ctrainer
def main(user_input=None): def main(user_input=None):
...@@ -77,68 +68,57 @@ def main(user_input=None): ...@@ -77,68 +68,57 @@ def main(user_input=None):
completions = dict(prog=prog, version=version,) completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train conditional GAN (%s)' % version,) args = docopt(__doc__ % completions,argv=arguments,version='Train conditional GAN (%s)' % version,)
# verbosity # load configuration file
verbosity_level = args['--verbose'] configuration = load([os.path.join(args['<configuration>'])])
bob.core.log.set_verbosity_level(logger, verbosity_level)
# get various parameters, either from config file or command-line
# get the arguments noise_dim = get_parameter(args, configuration, 'noise_dim', 100)
noise_dim = int(args['--noise-dim']) conditional_dim = get_parameter(args, configuration, 'conditional_dim', 13)
conditional_dim = int(args['--conditional-dim']) batch_size = get_parameter(args, configuration, 'batch_size', 64)
batch_size = int(args['--batch-size']) epochs = get_parameter(args, configuration, 'epochs', 20)
epochs = int(args['--epochs']) sample = get_parameter(args, configuration, 'sample', 1e10)
sample = int(args['--sample']) seed = get_parameter(args, configuration, 'seed', 3)
output_dir = str(args['--output-dir']) output_dir = get_parameter(args, configuration, 'output_dir', 'training')
seed = int(args['--seed']) use_gpu = get_parameter(args, configuration, 'use_gpu', False)
use_gpu = bool(args['--use-gpu']) verbosity_level = get_parameter(args, configuration, 'verbose', 0)
bob.core.log.set_verbosity_level(logger, verbosity_level)
images_dir = os.path.join(output_dir, 'samples') images_dir = os.path.join(output_dir, 'samples')
log_dir = os.path.join(output_dir, 'logs') log_dir = os.path.join(output_dir, 'logs')
model_dir = os.path.join(output_dir, 'models') model_dir = os.path.join(output_dir, 'models')
bob.io.base.create_directories_safe(images_dir)
bob.io.base.create_directories_safe(log_dir)
bob.io.base.create_directories_safe(images_dir)
# print parameters
logger.debug("Noise dimension = {}".format(noise_dim))
logger.debug("Conditional dimension = {}".format(conditional_dim))
logger.debug("Batch size = {}".format(batch_size))
logger.debug("Epochs = {}".format(epochs))
logger.debug("Sample = {}".format(sample))
logger.debug("Seed = {}".format(seed))
logger.debug("Output directory = {}".format(output_dir))
logger.debug("Use GPU = {}".format(use_gpu))
# process on the arguments / options # process on the arguments / options
torch.manual_seed(seed) torch.manual_seed(seed)
if use_gpu: if use_gpu:
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available() and not use_gpu: if torch.cuda.is_available() and not use_gpu:
logger.warn("You have a CUDA device, so you should probably run with --use-gpu") logger.warn("You have a CUDA device, so you should probably run with --use-gpu")
bob.io.base.create_directories_safe(images_dir)
bob.io.base.create_directories_safe(log_dir)
bob.io.base.create_directories_safe(images_dir)
# ============ # get data
# === DATA === if hasattr(configuration, 'dataset'):
# ============ dataloader = torch.utils.data.DataLoader(configuration.dataset, batch_size=batch_size, shuffle=True)
logger.info("There are {} training images".format(len(configuration.dataset)))
# WARNING with the transforms ... act on labels too, at some point, I may have to write my own else:
# Also, in 'ToTensor', there is a reshape performed from: HxWxC to CxHxW logger.error("Please provide a dataset in your configuration file !")
face_dataset = MultiPIEDataset(#root_dir='/Users/guillaumeheusch/work/idiap/data/multipie-cropped-64x64', sys.exit()
root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64',
frontal_only=False,
transform=transforms.Compose([
RollChannels(), # bob to skimage:
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
)
dataloader = torch.utils.data.DataLoader(face_dataset, batch_size=batch_size, shuffle=True)
logger.info("There are {} training images".format(len(face_dataset)))
# ===============
# === NETWORK ===
# ===============
ngpu = 1 # usually we don't have more than one GPU
generator = cgenerator(noise_dim, conditional_dim) # train the model
generator.apply(weights_init) if hasattr(configuration, 'generator') and hasattr(configuration, 'discriminator'):
logger.info("Generator architecture: {}".format(generator)) trainer = ConditionalGANTrainer(configuration.generator, configuration.discriminator, [3, 64, 64], batch_size=batch_size, noise_dim=noise_dim, conditional_dim=conditional_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
discriminator = cdiscriminator(conditional_dim) else:
discriminator.apply(weights_init) logger.error("Please provide both a generator and a discriminator in your configuration file !")
logger.info("Discriminator architecture: {}".format(discriminator)) sys.exit()
# ===============
# === TRAINER ===
# ===============
trainer = ctrainer(generator, discriminator, [3, 64, 64], batch_size=batch_size, noise_dim=noise_dim, conditional_dim=conditional_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
...@@ -47,6 +47,7 @@ from docopt import docopt ...@@ -47,6 +47,7 @@ from docopt import docopt
import bob.core import bob.core
logger = bob.core.log.setup("bob.learn.pytorch") logger = bob.core.log.setup("bob.learn.pytorch")
import bob.io.base
from bob.extension.config import load from bob.extension.config import load
from bob.learn.pytorch.trainers import DCGANTrainer from bob.learn.pytorch.trainers import DCGANTrainer
from bob.learn.pytorch.utils import get_parameter from bob.learn.pytorch.utils import get_parameter
...@@ -82,6 +83,7 @@ def main(user_input=None): ...@@ -82,6 +83,7 @@ def main(user_input=None):
bob.io.base.create_directories_safe(output_dir) bob.io.base.create_directories_safe(output_dir)
# print parameters # print parameters
logger.debug("Noise dimension = {}".format(noise_dim))
logger.debug("Batch size = {}".format(batch_size)) logger.debug("Batch size = {}".format(batch_size))
logger.debug("Epochs = {}".format(epochs)) logger.debug("Epochs = {}".format(epochs))
logger.debug("Sample = {}".format(sample)) logger.debug("Sample = {}".format(sample))
......
...@@ -71,6 +71,7 @@ setup( ...@@ -71,6 +71,7 @@ setup(
'console_scripts' : [ 'console_scripts' : [
'train_cnn.py = bob.learn.pytorch.scripts.train_cnn:main', 'train_cnn.py = bob.learn.pytorch.scripts.train_cnn:main',
'train_dcgan.py = bob.learn.pytorch.scripts.train_dcgan:main', 'train_dcgan.py = bob.learn.pytorch.scripts.train_dcgan:main',
'train_conditionalgan.py = bob.learn.pytorch.scripts.train_conditionalgan:main',
], ],
}, },
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment