Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.pytorch
Commits
225fb644
Commit
225fb644
authored
Jul 24, 2018
by
Guillaume HEUSCH
Browse files
[trainers] fixed docstrings in ConditionalGANTrainer
parent
68ae48c1
Changes
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/pytorch/trainers/ConditionalGANTrainer.py
View file @
225fb644
#!/usr/bin/env python
# encoding: utf-8
import
numpy
import
time
import
torch
import
torch.nn
as
nn
...
...
@@ -13,38 +11,58 @@ import torchvision.utils as vutils
import
bob.core
logger
=
bob
.
core
.
log
.
setup
(
"bob.learn.pytorch"
)
class
ConditionalGANTrainer
(
object
):
"""
Class to train a Conditional GAN
import
time
**Parameters**
class
ConditionalGANTrainer
(
object
):
"""Class to train a Conditional GAN
generator: pytorch nn.Module
Attributes
----------
generator : :py:class:`torch.nn.Module`
The generator network
discriminator: pytorch nn.Module
discriminator : :py:class:`torch.nn.Module`
The discriminator network
image_size: list
image_size: list of int
The size of the images in this format: [channels,height, width]
batch_size: int
The size of your minibatch
noise_dim: int
The dimension of the noise (input to the generator)
conditional_dim: int
The dimension of the conditioning variable
use_gpu: boolean
use_gpu: bool
If you would like to use the gpu
verbosity_level: int
The level of verbosity output to stdout
fixed_noise : :py:class:`torch.Tensor`
The fixed input noise to the generator.
fixed_one_hot : :py:class:`torch.Tensor`
The set of fixed one-hot encoded conditioning variable
criterion : :py:class:`torch.nn.BCELoss`
The binary cross-entropy loss
"""
def
__init__
(
self
,
netG
,
netD
,
image_size
,
batch_size
=
64
,
noise_dim
=
100
,
conditional_dim
=
13
,
use_gpu
=
False
,
verbosity_level
=
2
):
"""Init function
Parameters
----------
netG : :py:class:`torch.nn.Module`
The generator network
netD : :py:class:`torch.nn.Module`
The discriminator network
image_size: list of int
The size of the images in this format: [channels,height, width]
batch_size: int
The size of your minibatch
noise_dim: int
The dimension of the noise (input to the generator)
conditional_dim: int
The dimension of the conditioning variable
use_gpu: bool
If you would like to use the gpu
verbosity_level: int
The level of verbosity output to stdout
"""
bob
.
core
.
log
.
set_verbosity_level
(
logger
,
verbosity_level
)
self
.
netG
=
netG
...
...
@@ -76,25 +94,21 @@ class ConditionalGANTrainer(object):
def
train
(
self
,
dataloader
,
n_epochs
=
10
,
learning_rate
=
0.0002
,
beta1
=
0.5
,
output_dir
=
'out'
):
"""
Function that performs the training.
**Parameters**
"""trains the Conditional GAN.
dataloader: pytorch DataLoader
Parameters
----------
dataloader: :py:class:`torch.utils.data.DataLoader`
The dataloader for your data
n_epochs: int
The number of epochs you would like to train for
learning_rate: float
The learning rate for Adam optimizer
beta1: float
The beta1 for Adam optimizer
output_dir: path
output_dir: str
The directory where you would like to output images and models
"""
real_label
=
1
fake_label
=
0
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment