This is an archived project. Repository and other project resources are read-only.

bob / bob.learn.pytorch · Merge request !4: Resolve "Add GANs"
Merged: Guillaume HEUSCH requested to merge 4-add-gans into master 6 years ago

Overview 0 · Commits 26 · Pipelines 8 · Changes 1

Closes #4 (closed)
Edited 6 years ago by Guillaume HEUSCH
Merge request reports

Viewing commit 225fb644: [trainers] fixed docstrings in ConditionalGANTrainer
Guillaume HEUSCH authored 6 years ago
1 file changed: +44 −30
bob/learn/pytorch/trainers/ConditionalGANTrainer.py (+44 −30)
```diff
 #!/usr/bin/env python
 # encoding: utf-8

-import numpy
-import time
 import torch
 import torch.nn as nn
@@ -13,38 +11,58 @@ import torchvision.utils as vutils
 import bob.core
 logger = bob.core.log.setup("bob.learn.pytorch")

-import time

 class ConditionalGANTrainer(object):
   """
   Class to train a Conditional GAN

-  **Parameters**
-
-  generator: pytorch nn.Module
+  Attributes
+  ----------
+  generator : :py:class:`torch.nn.Module`
     The generator network
-  discriminator: pytorch nn.Module
+  discriminator : :py:class:`torch.nn.Module`
     The discriminator network
-  image_size: list
+  image_size: list of int
     The size of the images in this format: [channels,height, width]
   batch_size: int
     The size of your minibatch
   noise_dim: int
     The dimension of the noise (input to the generator)
   conditional_dim: int
     The dimension of the conditioning variable
-  use_gpu: boolean
+  use_gpu: bool
     If you would like to use the gpu
   verbosity_level: int
     The level of verbosity output to stdout
+  fixed_noise : :py:class:`torch.Tensor`
+    The fixed input noise to the generator.
+  fixed_one_hot : :py:class:`torch.Tensor`
+    The set of fixed one-hot encoded conditioning variable
+  criterion : :py:class:`torch.nn.BCELoss`
+    The binary cross-entropy loss
   """

   def __init__(self, netG, netD, image_size, batch_size=64, noise_dim=100, conditional_dim=13, use_gpu=False, verbosity_level=2):
     """
     Init function

+    Parameters
+    ----------
+    netG : :py:class:`torch.nn.Module`
+      The generator network
+    netD : :py:class:`torch.nn.Module`
+      The discriminator network
+    image_size: list of int
+      The size of the images in this format: [channels,height, width]
+    batch_size: int
+      The size of your minibatch
+    noise_dim: int
+      The dimension of the noise (input to the generator)
+    conditional_dim: int
+      The dimension of the conditioning variable
+    use_gpu: bool
+      If you would like to use the gpu
+    verbosity_level: int
+      The level of verbosity output to stdout
     """
     bob.core.log.set_verbosity_level(logger, verbosity_level)
     self.netG = netG
```
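For orientation, here is a minimal usage sketch assembled only from the constructor signature and docstrings shown above. Everything besides that documented signature is an assumption: the import path is inferred from the file path of this diff, the placeholder networks stand in for whatever generator/discriminator architectures the package actually ships, and the dummy dataloader (and the batch format it yields) is purely illustrative.

```python
# Hypothetical usage sketch; only the ConditionalGANTrainer signature comes from this MR.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Import path inferred from bob/learn/pytorch/trainers/ConditionalGANTrainer.py (assumption).
from bob.learn.pytorch.trainers.ConditionalGANTrainer import ConditionalGANTrainer

# Placeholder modules: the generator maps a (noise + one-hot condition) vector to an image,
# the discriminator scores an image stacked with condition feature maps. The real trainer
# may expect the conditional-GAN architectures shipped with the package instead.
netG = nn.Sequential(nn.ConvTranspose2d(100 + 13, 3, kernel_size=64), nn.Tanh())
netD = nn.Sequential(nn.Conv2d(3 + 13, 1, kernel_size=64), nn.Sigmoid())

trainer = ConditionalGANTrainer(
    netG, netD,
    image_size=[3, 64, 64],   # [channels, height, width], as documented above
    batch_size=64,
    noise_dim=100,
    conditional_dim=13,       # dimensionality of the one-hot conditioning variable
    use_gpu=False,
    verbosity_level=2,
)

# Dummy data: images plus an integer conditioning label per image. The exact batch
# format train() expects is not visible in this commit, so this is an assumption too.
dataset = TensorDataset(torch.randn(256, 3, 64, 64), torch.randint(0, 13, (256,)))
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

trainer.train(dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out')
```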
```diff
@@ -76,25 +94,21 @@ class ConditionalGANTrainer(object):
   def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out'):
     """
-    Function that performs the training.
+    trains the Conditional GAN.

-    **Parameters**
-
-    dataloader: pytorch DataLoader
+    Parameters
+    ----------
+    dataloader: :py:class:`torch.utils.data.DataLoader`
       The dataloader for your data
     n_epochs: int
       The number of epochs you would like to train for
     learning_rate: float
       The learning rate for Adam optimizer
     beta1: float
       The beta1 for Adam optimizer
-    output_dir: path
+    output_dir: str
       The directory where you would like to output images and models
     """
     real_label = 1
     fake_label = 0
```
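The `real_label = 1` / `fake_label = 0` constants, the Adam `learning_rate` and `beta1` arguments, and the `criterion` attribute (`torch.nn.BCELoss`) documented above all point to the standard conditional-GAN objective. The sketch below shows what one discriminator/generator update of that recipe typically looks like; it reuses the hypothetical `netG`, `netD` and `dataloader` from the previous sketch and is a generic illustration, not the body of `train()` from this file.

```python
# Generic conditional-GAN update step, for illustration only; the actual loop in
# ConditionalGANTrainer.train() is not part of this diff. Assumes netG, netD and
# dataloader from the previous sketch.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

criterion = nn.BCELoss()                       # the documented criterion
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

for images, labels in dataloader:
    batch_size = images.size(0)
    real = torch.full((batch_size,), 1.0)      # real_label
    fake = torch.full((batch_size,), 0.0)      # fake_label

    # One-hot encode the conditioning variable and tile it as extra feature maps.
    one_hot = F.one_hot(labels, num_classes=13).float()
    cond_maps = one_hot.view(-1, 13, 1, 1).expand(-1, 13, 64, 64)

    # Discriminator step: push real images towards 1 and generated images towards 0.
    optimizerD.zero_grad()
    loss_real = criterion(netD(torch.cat((images, cond_maps), dim=1)).view(-1), real)
    noise = torch.randn(batch_size, 100, 1, 1)
    fakes = netG(torch.cat((noise, one_hot.view(-1, 13, 1, 1)), dim=1))
    loss_fake = criterion(netD(torch.cat((fakes.detach(), cond_maps), dim=1)).view(-1), fake)
    (loss_real + loss_fake).backward()
    optimizerD.step()

    # Generator step: make the discriminator score generated images as real.
    optimizerG.zero_grad()
    loss_g = criterion(netD(torch.cat((fakes, cond_maps), dim=1)).view(-1), real)
    loss_g.backward()
    optimizerG.step()
```

The `fixed_noise` / `fixed_one_hot` attributes and the `torchvision.utils` import visible in the hunk header suggest the usual extra step of saving images generated from a fixed batch of noise and conditions after each epoch, so progress can be inspected in `output_dir`.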