Skip to content
Snippets Groups Projects
Commit 9982b093 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainers] added trainer for DCGAN

parent 5fc29a09
Branches
Tags v2.1.3
1 merge request!4Resolve "Add GANs"
...@@ -8,39 +8,60 @@ import torch.optim as optim ...@@ -8,39 +8,60 @@ import torch.optim as optim
from torch.autograd import Variable from torch.autograd import Variable
import torchvision.utils as vutils import torchvision.utils as vutils
import bob.core import bob.core
logger = bob.core.log.setup("bob.learn.pytorch") logger = bob.core.log.setup("bob.learn.pytorch")
import time import time
class DCGANTrainer(object): class DCGANTrainer(object):
""" """Class to train a DCGAN
Class to train a DCGAN
**Parameters** Attributes
----------
generator: pytorch nn.Module netG : :py:class:`torch.nn.Module`
The generator network The generator network
netD : :py:class:`torch.nn.Module`
discriminator: pytorch nn.Module
The discriminator network The discriminator network
batch_size: int batch_size: int
The size of your minibatch The size of your minibatch
noise_dim: int noise_dim: int
The dimension of the noise (input to the generator) The dimension of the noise (input to the generator)
use_gpu: bool
use_gpu: boolean
If you would like to use the gpu If you would like to use the gpu
input : :py:class:`torch.Tensor`
verbosity_level: int The input image
The level of verbosity output to stdout noise : :py:class:`torch.Tensor`
The input noise to the generator
fixed_noise : :py:class:`torch.Tensor`
The fixed input noise to the generator.
Used for generating images to save.
label : :py:class:`torch.Tensor`
label for real/fake images.
criterion : :py:class:`torch.nn.BCELoss`
The binary cross-entropy loss
""" """
def __init__(self, netG, netD, batch_size=64, noise_dim=100, use_gpu=False, verbosity_level=2): def __init__(self, netG, netD, batch_size=64, noise_dim=100, use_gpu=False, verbosity_level=2):
"""Init function
Parameters
----------
generator : :py:class:`torch.nn.Module`
The generator network
discriminator : :py:class:`torch.nn.Module`
The discriminator network
batch_size: int
The size of your minibatch
noise_dim: int
The dimension of the noise (input to the generator)
use_gpu: bool
If you would like to use the gpu
verbosity_level: int
The level of verbosity output to stdout
"""
bob.core.log.set_verbosity_level(logger, verbosity_level)
self.netG = netG self.netG = netG
self.netD = netD self.netD = netD
self.batch_size = batch_size self.batch_size = batch_size
...@@ -63,29 +84,23 @@ class DCGANTrainer(object): ...@@ -63,29 +84,23 @@ class DCGANTrainer(object):
self.input, self.label = self.input.cuda(), self.label.cuda() self.input, self.label = self.input.cuda(), self.label.cuda()
self.noise, self.fixed_noise = self.noise.cuda(), self.fixed_noise.cuda() self.noise, self.fixed_noise = self.noise.cuda(), self.fixed_noise.cuda()
bob.core.log.set_verbosity_level(logger, verbosity_level)
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out'): def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out'):
""" """trains the DCGAN.
Function that performs the training.
**Parameters**
dataloader: pytorch DataLoader Parameters
----------
dataloader: :py:class:`torch.utils.data.DataLoader`
The dataloader for your data The dataloader for your data
n_epochs: int n_epochs: int
The number of epochs you would like to train for The number of epochs you would like to train for
learning_rate: float learning_rate: float
The learning rate for Adam optimizer The learning rate for Adam optimizer
beta1: float beta1: float
The beta1 for Adam optimizer The beta1 for Adam optimizer
output_dir: str
output_dir: path
The directory where you would like to output images and models The directory where you would like to output images and models
""" """
real_label = 1 real_label = 1
fake_label = 0 fake_label = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment