Skip to content
Snippets Groups Projects
Commit a7abd4e9 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[script] modified the script to train on Multi-PIE to take into account renaming

parent 72f9922d
Branches
Tags
No related merge requests found
......@@ -8,7 +8,6 @@ Usage:
%(prog)s [--noise-dim=<int>] [--conditional-dim=<int>]
[--batch-size=<int>] [--epochs=<int>] [--sample=<int>]
[--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...]
[--second]
Options:
-h, --help Show this screen.
......@@ -22,7 +21,6 @@ Options:
-g, --use-gpu Use the GPU
-S, --seed=<int> The random seed [default: 3]
-v, --verbose Increase the verbosity (may appear multiple times).
--second Use the "direct" architecture.
Example:
......@@ -63,7 +61,9 @@ from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.architectures import weights_init
from bob.learn.pytorch.architectures import ConditionalGAN_generator as cgenerator
from bob.learn.pytorch.architectures import ConditionalGAN_discriminator as cdiscriminator
from bob.learn.pytorch.trainers import ConditionalGANTrainer as ctrainer
def main(user_input=None):
......@@ -75,7 +75,7 @@ def main(user_input=None):
prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train conditional GAN (%s)' % version,)
# verbosity
verbosity_level = args['--verbose']
......@@ -105,15 +105,6 @@ def main(user_input=None):
bob.io.base.create_directories_safe(log_dir)
bob.io.base.create_directories_safe(images_dir)
if bool(args['--second']):
from bob.learn.pytorch.architectures import ConditionalGAN_generator2 as cgenerator
from bob.learn.pytorch.architectures import ConditionalGAN_discriminator2 as cdiscriminator
from bob.learn.pytorch.trainers import ConditionalGANTrainer2 as ctrainer
else:
from bob.learn.pytorch.architectures import ConditionalGAN_generator as cgenerator
from bob.learn.pytorch.architectures import ConditionalGAN_discriminator as cdiscriminator
from bob.learn.pytorch.trainers import ConditionalGANTrainer as ctrainer
# ============
# === DATA ===
# ============
......@@ -146,7 +137,6 @@ def main(user_input=None):
discriminator.apply(weights_init)
logger.info("Discriminator architecture: {}".format(discriminator))
# ===============
# === TRAINER ===
# ===============
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment