Skip to content
Snippets Groups Projects
Commit ab58cf14 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[architectures] add some docstring in the DCGAN trainer

parent 9b70f326
Branches
Tags
No related merge requests found
......@@ -12,14 +12,33 @@ import torchvision.utils as vutils
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
import time
class DCGANTrainer(object):
"""
Class to train a DCGAN
"""
**Parameters**
generator: pytorch nn.Module
The generator network
discriminator: pytorch nn.Module
The discriminator network
batch_size: int
The size of your minibatch
noise_dim: int
The dimension of the noise (input to the generator)
use_gpu: boolean
If you would like to use the gpu
verbosity_level: int
The level of verbosity output to stdout
"""
def __init__(self, netG, netD, batch_size=64, noise_dim=100, use_gpu=False, verbosity_level=2):
self.netG = netG
......@@ -32,7 +51,6 @@ class DCGANTrainer(object):
self.noise = torch.FloatTensor(batch_size, noise_dim, 1, 1)
self.fixed_noise = torch.FloatTensor(batch_size, noise_dim, 1, 1).normal_(0, 1)
self.label = torch.FloatTensor(batch_size)
self.fixed_noise = Variable(self.fixed_noise)
......@@ -49,7 +67,26 @@ class DCGANTrainer(object):
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out'):
"""
Function that performs the training.
**Parameters**
dataloader: pytorch DataLoader
The dataloader for your data
n_epochs: int
The number of epochs you would like to train for
learning_rate: float
The learning rate for Adam optimizer
beta1: float
The beta1 for Adam optimizer
output_dir: path
The directory where you would like to output images and models
"""
real_label = 1
fake_label = 0
......@@ -60,6 +97,8 @@ class DCGANTrainer(object):
for epoch in range(n_epochs):
for i, data in enumerate(dataloader, 0):
start = time.time()
# ===========================================================
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
# ===========================================================
......@@ -103,7 +142,8 @@ class DCGANTrainer(object):
D_G_z2 = output.data.mean()
optimizerG.step()
logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {}".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0]))
end = time.time()
logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0], (end-start)))
# save generated images at every epoch
fake = self.netG(self.fixed_noise)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment