Commit 225fb644 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainers] fixed docstrings in ConditionalGANTrainer

parent 68ae48c1
#!/usr/bin/env python
# encoding: utf-8
import numpy
import time
import torch
import torch.nn as nn
......@@ -13,38 +11,58 @@ import torchvision.utils as vutils
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
class ConditionalGANTrainer(object):
"""
Class to train a Conditional GAN
import time
**Parameters**
class ConditionalGANTrainer(object):
"""Class to train a Conditional GAN
generator: pytorch nn.Module
Attributes
----------
generator : :py:class:`torch.nn.Module`
The generator network
discriminator: pytorch nn.Module
discriminator : :py:class:`torch.nn.Module`
The discriminator network
image_size: list
image_size: list of int
The size of the images in this format: [channels,height, width]
batch_size: int
The size of your minibatch
noise_dim: int
The dimension of the noise (input to the generator)
conditional_dim: int
The dimension of the conditioning variable
use_gpu: boolean
use_gpu: bool
If you would like to use the gpu
verbosity_level: int
The level of verbosity output to stdout
fixed_noise : :py:class:`torch.Tensor`
The fixed input noise to the generator.
fixed_one_hot : :py:class:`torch.Tensor`
The set of fixed one-hot encoded conditioning variable
criterion : :py:class:`torch.nn.BCELoss`
The binary cross-entropy loss
"""
def __init__(self, netG, netD, image_size, batch_size=64, noise_dim=100, conditional_dim=13, use_gpu=False, verbosity_level=2):
"""Init function
Parameters
----------
netG : :py:class:`torch.nn.Module`
The generator network
netD : :py:class:`torch.nn.Module`
The discriminator network
image_size: list of int
The size of the images in this format: [channels,height, width]
batch_size: int
The size of your minibatch
noise_dim: int
The dimension of the noise (input to the generator)
conditional_dim: int
The dimension of the conditioning variable
use_gpu: bool
If you would like to use the gpu
verbosity_level: int
The level of verbosity output to stdout
"""
bob.core.log.set_verbosity_level(logger, verbosity_level)
self.netG = netG
......@@ -76,25 +94,21 @@ class ConditionalGANTrainer(object):
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out'):
"""
Function that performs the training.
**Parameters**
"""trains the Conditional GAN.
dataloader: pytorch DataLoader
Parameters
----------
dataloader: :py:class:`torch.utils.data.DataLoader`
The dataloader for your data
n_epochs: int
The number of epochs you would like to train for
learning_rate: float
The learning rate for Adam optimizer
beta1: float
The beta1 for Adam optimizer
output_dir: path
output_dir: str
The directory where you would like to output images and models
"""
real_label = 1
fake_label = 0
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment