Skip to content
Snippets Groups Projects
Commit 53d6dfe0 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[script] generic script to train DCGAN

parent 4fa5fa92
No related branches found
No related tags found
1 merge request!4Resolve "Add GANs"
......@@ -103,22 +103,10 @@ def main(user_input=None):
logger.error("Please provide a dataset in your configuration file !")
sys.exit()
# ===============
# === NETWORK ===
# ===============
ngpu = 1 # usually we don't have more than one GPU
generator = DCGAN_generator(ngpu)
generator.apply(weights_init)
logger.info("Generator architecture: {}".format(generator))
discriminator = DCGAN_discriminator(ngpu)
discriminator.apply(weights_init)
logger.info("Discriminator architecture: {}".format(discriminator))
# ===============
# === TRAINER ===
# ===============
trainer = DCGANTrainer(generator, discriminator, batch_size=batch_size, noise_dim=noise_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
# train the model
if hasattr(configuration, 'generator') and hasattr(configuration, 'discriminator'):
trainer = DCGANTrainer(generator, discriminator, batch_size=batch_size, noise_dim=noise_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
else:
logger.error("Please provide both a generator and a discriminator in your configuration file !")
sys.exit()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment