Commit 2221e9af authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[script, trainer] various fixes to DCGAN

parent 039cd6e1
Pipeline #22429 passed with stage
in 21 minutes and 59 seconds
......@@ -69,6 +69,7 @@ def main(user_input=None):
configuration = load([os.path.join(args['<configuration>'])])
# get various parameters, either from config file or command-line
noise_dim = get_parameter(args, configuration, 'noise_dim', 100)
batch_size = get_parameter(args, configuration, 'batch_size', 64)
epochs = get_parameter(args, configuration, 'epochs', 20)
sample = get_parameter(args, configuration, 'sample', 1e10)
......@@ -83,7 +84,7 @@ def main(user_input=None):
# print parameters
logger.debug("Batch size = {}".format(batch_size))
logger.debug("Epochs = {}".format(epochs))
logger.debug("Sample = {}".format(learning_rate))
logger.debug("Sample = {}".format(sample))
logger.debug("Seed = {}".format(seed))
logger.debug("Output directory = {}".format(output_dir))
logger.debug("Use GPU = {}".format(use_gpu))
......@@ -98,14 +99,14 @@ def main(user_input=None):
# get data
if hasattr(configuration, 'dataset'):
dataloader = torch.utils.data.DataLoader(configuration.dataset, batch_size=batch_size, shuffle=True)
logger.info("There are {} training images from {} categories".format(len(configuration.dataset)))
logger.info("There are {} training images".format(len(configuration.dataset)))
else:
logger.error("Please provide a dataset in your configuration file !")
sys.exit()
# train the model
if hasattr(configuration, 'generator') and hasattr(configuration, 'discriminator'):
trainer = DCGANTrainer(generator, discriminator, batch_size=batch_size, noise_dim=noise_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer = DCGANTrainer(configuration.generator, configuration.discriminator, batch_size=batch_size, noise_dim=noise_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
else:
logger.error("Please provide both a generator and a discriminator in your configuration file !")
......
......@@ -158,7 +158,8 @@ class DCGANTrainer(object):
optimizerG.step()
end = time.time()
logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0], (end-start)))
#logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0], (end-start)))
logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.item(), errG.item(), (end-start)))
# save generated images at every epoch
fake = self.netG(self.fixed_noise)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment