Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bob.learn.pytorch
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
This is an archived project. Repository and other project resources are read-only.
Show more breadcrumbs
bob
bob.learn.pytorch
Commits
8b4ac689
Commit
8b4ac689
authored
7 years ago
by
Guillaume HEUSCH
Browse files
Options
Downloads
Patches
Plain Diff
[trainer] added the initial implementation of the improved Wasserstein trainer
parent
1ac5e8fc
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bob/learn/pytorch/trainers/ImprovedWasserteinCGANTrainer.py
+261
-0
261 additions, 0 deletions
bob/learn/pytorch/trainers/ImprovedWasserteinCGANTrainer.py
with
261 additions
and
0 deletions
bob/learn/pytorch/trainers/ImprovedWasserteinCGANTrainer.py
0 → 100644
+
261
−
0
View file @
8b4ac689
#!/usr/bin/env python
# encoding: utf-8
import
numpy
import
time
import
torch
import
torch.nn
as
nn
import
torch.optim
as
optim
from
torch.autograd
import
Variable
import
torchvision.utils
as
vutils
import
bob.core
logger
=
bob
.
core
.
log
.
setup
(
"
bob.learn.pytorch
"
)
class
IWCGAN
(
object
):
"""
Class to train a Conditional GAN, using the Improved Wasserstein Training method
**Parameters**
generator: pytorch nn.Module
The generator network
discriminator: pytorch nn.Module
The discriminator network
image_size: list
The size of the images in this format: [channels,height, width]
batch_size: int
The size of your minibatch
noise_dim: int
The dimension of the noise (input to the generator)
conditional_dim: int
The dimension of the conditioning variable
n_critic_update: int
The number of critic (discriminator) iterations per generator iterations.
Lambda: int
The regularization weight (gradient penalty).
use_gpu: boolean
If you would like to use the gpu
verbosity_level: int
The level of verbosity output to stdout
"""
def
__init__
(
self
,
netG
,
netD
,
image_size
,
batch_size
=
64
,
noise_dim
=
100
,
conditional_dim
=
13
,
n_critic_update
=
5
,
Lambda
=
10
,
use_gpu
=
False
,
verbosity_level
=
2
):
bob
.
core
.
log
.
set_verbosity_level
(
logger
,
verbosity_level
)
self
.
netG
=
netG
self
.
netD
=
netD
self
.
image_size
=
image_size
self
.
batch_size
=
batch_size
self
.
noise_dim
=
noise_dim
self
.
conditional_dim
=
conditional_dim
self
.
n_critic_update
=
n_critic_update
self
.
Lambda
=
Lambda
self
.
use_gpu
=
use_gpu
# fixed conditional noise - used to generate samples (one for each value of the conditional variable)
self
.
fixed_noise
=
torch
.
FloatTensor
(
self
.
conditional_dim
,
noise_dim
,
1
,
1
).
normal_
(
0
,
1
)
self
.
fixed_one_hot
=
torch
.
FloatTensor
(
self
.
conditional_dim
,
self
.
conditional_dim
,
1
,
1
).
zero_
()
for
k
in
range
(
self
.
conditional_dim
):
self
.
fixed_one_hot
[
k
,
k
]
=
1
# TODO: figuring out the CPU/GPU thing - Guillaume HEUSCH, 17-11-2017
self
.
fixed_noise
=
Variable
(
self
.
fixed_noise
)
self
.
fixed_one_hot
=
Variable
(
self
.
fixed_one_hot
)
# binary cross-entropy loss
self
.
criterion
=
nn
.
BCELoss
()
# move stuff to GPU if needed
if
self
.
use_gpu
:
self
.
netD
.
cuda
()
self
.
netG
.
cuda
()
self
.
criterion
.
cuda
()
def
calc_gradient_penalty
(
real_data
,
fake_data
,
one_hot
,
batch_size
):
"""
Computes the gradient penalty term.
**Parameters**
real_data:
The batch of real images.
fake_data:
The batch of generated (fake) images.
one_hot:
The batch of feature maps to append to the discriminator input.
batch_size: int
The size of the minibatch.
"""
alpha
=
torch
.
rand
(
batch_size
,
1
)
alpha
=
alpha
.
expand
(
batch_size
,
real_data
.
nelement
()
/
batch_size
).
contiguous
().
view
(
batch_size
,
self
.
image_size
[
0
],
self
.
image_size
[
1
],
self
.
image_size
[
2
])
alpha
=
alpha
.
cuda
()
if
self
.
use_gpu
else
alpha
interpolates
=
alpha
*
real_data
+
((
1
-
alpha
)
*
fake_data
)
if
use_gpu
:
interpolates
=
interpolates
.
cuda
()
interpolates
=
autograd
.
Variable
(
interpolates
,
requires_grad
=
True
)
disc_interpolates
=
self
.
netD
(
interpolates
,
one_hot
)
gradients
=
autograd
.
grad
(
outputs
=
disc_interpolates
,
inputs
=
interpolates
,
grad_outputs
=
torch
.
ones
(
disc_interpolates
.
size
()).
cuda
()
if
use_gpu
else
torch
.
ones
(
disc_interpolates
.
size
()),
create_graph
=
True
,
retain_graph
=
True
,
only_inputs
=
True
)[
0
]
gradient_penalty
=
((
gradients
.
norm
(
2
,
dim
=
1
)
-
1
)
**
2
).
mean
()
*
self
.
Lambda
return
gradient_penalty
def
train
(
self
,
dataloader
,
n_iterations
=
100000
,
learning_rate
=
0.0001
,
beta1
=
0.5
,
output_dir
=
'
out
'
):
"""
Function that performs the training.
**Parameters**
dataloader: pytorch DataLoader
The dataloader for your data
n_iterations: int
The number of iterations you would like to train for
learning_rate: float
The learning rate for Adam optimizer
beta1: float
The beta1 for Adam optimizer
output_dir: path
The directory where you would like to output images and models
"""
# setup optimizer
optimizerD
=
optim
.
Adam
(
self
.
netD
.
parameters
(),
lr
=
1e-4
,
betas
=
(
beta1
,
0.9
))
optimizerG
=
optim
.
Adam
(
self
.
netG
.
parameters
(),
lr
=
1e-4
,
betas
=
(
beta1
,
0.9
))
one
=
torch
.
FloatTensor
([
1
])
mone
=
one
*
-
1
if
use_gpu
:
one
=
one
.
cuda
()
mone
=
mone
.
cuda
()
# let's go
for
iteration
in
range
(
n_iterations
):
start
=
time
.
time
()
# =============
# DISCRIMINATOR
# =============
for
p
in
self
.
netD
.
parameters
():
p
.
requires_grad
=
True
for
k
in
range
(
n_critic_update
):
self
.
netD
.
zero_grad
()
# get the data and pose labels
data
=
dataloader
.
next
()
real_images
=
data
[
'
image
'
]
poses
=
data
[
'
pose
'
]
# WARNING: the last batch could be smaller than the provided size
batch_size
=
len
(
real_images
)
# create the Tensors with the right batch size
noise
=
torch
.
FloatTensor
(
batch_size
,
self
.
noise_dim
,
1
,
1
).
normal_
(
0
,
1
)
# create the one hot conditional vector (generator) and feature maps (discriminator)
one_hot_feature_maps
=
torch
.
FloatTensor
(
batch_size
,
self
.
conditional_dim
,
self
.
image_size
[
1
],
self
.
image_size
[
2
]).
zero_
()
one_hot_vector
=
torch
.
FloatTensor
(
batch_size
,
self
.
conditional_dim
,
1
,
1
).
zero_
()
for
k
in
range
(
batch_size
):
one_hot_feature_maps
[
k
,
poses
[
k
],
:,
:]
=
1
one_hot_vector
[
k
,
poses
[
k
]]
=
1
# move stuff to GPU if needed
if
self
.
use_gpu
:
real_images
=
real_images
.
cuda
()
noise
=
noise
.
cuda
()
one_hot_feature_maps
=
one_hot_feature_maps
.
cuda
()
one_hot_vector
=
one_hot_vector
.
cuda
()
# === REAL DATA ===
imagev
=
Variable
(
real_images
)
one_hot_fmv
=
Variable
(
one_hot_feature_maps
)
output_real
=
self
.
netD
(
imagev
,
one_hot_fmv
)
output_real
=
output_real
.
mean
()
output_real
.
backward
(
mone
)
# === FAKE DATA ===
noisev
=
Variable
(
noise
,
volatile
=
True
)
one_hot_vv
=
Variable
(
one_hot_vector
)
fake
=
Variable
(
self
.
netG
(
noisev
,
one_hot_vv
).
data
)
input_fakev
=
fake
output_fake
=
self
.
netD
(
input_fakev
,
one_hot_fmv
)
output_fake
=
output_fake
.
mean
()
output_fake
.
backward
(
one
)
gradient_penalty
=
calc_gradient_penalty
(
imagev
.
data
,
fake
.
data
,
one_hot_feature_maps
)
gradient_penalty
.
backward
()
D_cost
=
D_fake
-
D_real
+
gradient_penalty
optimizerD
.
step
()
# =========
# GENERATOR
# =========
for
p
in
netD
.
parameters
():
p
.
requires_grad
=
False
self
.
netG
.
zero_grad
()
noise
=
torch
.
FloatTensor
(
batch_size
,
self
.
noise_dim
,
1
,
1
).
normal_
(
0
,
1
)
data
=
dataloader
.
next
()
poses
=
data
[
'
pose
'
]
one_hot_feature_maps
=
torch
.
FloatTensor
(
batch_size
,
self
.
conditional_dim
,
self
.
image_size
[
1
],
self
.
image_size
[
2
]).
zero_
()
for
k
in
range
(
batch_size
):
one_hot_feature_maps
[
k
,
poses
[
k
],
:,
:]
=
1
if
self
.
use_gpu
:
noise
=
noise
.
cuda
()
one_hot_feature_maps
=
one_hot_feature_maps
.
cuda
()
noisev
=
Variable
(
noise
)
one_hot_fmv
=
Variable
(
one_hot_feature_maps
)
fake
=
netG
(
noisev
)
G
=
netD
(
fake
,
one_hot_fmv
)
G
=
G
.
mean
()
G
.
backward
(
mone
)
G_cost
=
-
G
optimizerG
.
step
()
end
=
time
.
time
()
logger
.
info
(
"
[{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})
"
.
format
(
iteration
,
n_iterations
,
D_cost
.
data
,
G_cost
.
data
,
(
end
-
start
)))
# save sample every 100 iterations
if
iteration
%
100
==
99
:
fake_examples
=
self
.
netG
(
self
.
fixed_noise
,
self
.
fixed_one_hot
)
vutils
.
save_image
(
fake_examples
.
data
,
'
%s/fake_samples_epoch_%03d.png
'
%
(
output_dir
,
epoch
),
normalize
=
True
)
# save model every 1000 iterations
if
iteration
%
1000
==
999
:
torch
.
save
(
self
.
netG
.
state_dict
(),
'
%s/netG_epoch_%d.pth
'
%
(
output_dir
,
epoch
))
torch
.
save
(
self
.
netD
.
state_dict
(),
'
%s/netD_epoch_%d.pth
'
%
(
output_dir
,
epoch
))
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment