Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bob.learn.pytorch
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
This is an archived project. Repository and other project resources are read-only.
Show more breadcrumbs
bob
bob.learn.pytorch
Commits
d81b0a07
Commit
d81b0a07
authored
7 years ago
by
Guillaume HEUSCH
Browse files
Options
Downloads
Patches
Plain Diff
[architectures] added dropout in the encoder/decoder, like in the original DR-GAN implementation
parent
df5062d4
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bob/learn/pytorch/architectures/DRGANOriginal.py
+54
-15
54 additions, 15 deletions
bob/learn/pytorch/architectures/DRGANOriginal.py
with
54 additions
and
15 deletions
bob/learn/pytorch/architectures/DRGANOriginal.py
+
54
−
15
View file @
d81b0a07
...
...
@@ -27,8 +27,11 @@ class DRGANOriginal_encoder(nn.Module):
latent_dim: int
The dimension of the encoded ID
"""
def
__init__
(
self
,
image_size
,
latent_dim
):
def
__init__
(
self
,
image_size
,
latent_dim
,
is_training
=
True
,
dropout
=
False
):
self
.
is_training
=
is_training
self
.
dropout
=
dropout
# conv2d(in_channels, out_channels (i.e. number of feature maps), kernel size, stride, padding)
super
(
DRGANOriginal_encoder
,
self
).
__init__
()
...
...
@@ -97,9 +100,7 @@ class DRGANOriginal_encoder(nn.Module):
# ------------------------------------------
# average pool
nn
.
AvgPool2d
(
6
,
stride
=
1
)
# dropout ?
nn
.
AvgPool2d
(
6
,
stride
=
1
),
)
...
...
@@ -113,10 +114,17 @@ class DRGANOriginal_encoder(nn.Module):
The minibatch of images to encode.
"""
if
isinstance
(
x
.
data
,
torch
.
cuda
.
FloatTensor
)
and
self
.
ngpu
>
1
:
output
=
nn
.
parallel
.
data_parallel
(
self
.
main
,
x
,
range
(
self
.
ngpu
))
avgpool
=
nn
.
parallel
.
data_parallel
(
self
.
main
,
x
,
range
(
self
.
ngpu
))
else
:
output
=
self
.
main
(
x
)
avgpool
=
self
.
main
(
x
)
# dropout
if
(
self
.
is_training
and
self
.
dropout
):
dropout
=
nn
.
Dropout2d
(
p
=
0.4
)
output
=
dropout
(
avgpool
)
else
:
output
=
avgpool
return
output
...
...
@@ -142,14 +150,14 @@ class DRGANOriginal_decoder(nn.Module):
super
(
DRGANOriginal_decoder
,
self
).
__init__
()
self
.
noise_dim
=
noise_dim
self
.
conditional_dim
=
conditional_dim
self
.
latent_dim
=
latent_dim
self
.
ngpu
=
1
# usually, we don't have more than one GPU
self
.
main
=
nn
.
Sequential
(
# input is Z+ID+C , going into a convolution, output is 320x6x6
nn
.
ConvTranspose2d
((
noise_dim
+
latent_dim
+
conditional_dim
),
320
,
6
,
1
,
0
,
bias
=
False
),
# dropout ?
nn
.
BatchNorm2d
(
320
),
nn
.
ELU
(
inplace
=
True
),
# input is 320x6x6, output is 160x6x6
nn
.
ConvTranspose2d
(
320
,
160
,
3
,
1
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
160
),
...
...
@@ -158,7 +166,6 @@ class DRGANOriginal_decoder(nn.Module):
nn
.
ConvTranspose2d
(
160
,
256
,
3
,
1
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
256
),
nn
.
ELU
(
inplace
=
True
),
# size OK
# ------------------------------------------
# input is 256x6x6, output is 256x12x12
nn
.
ConvTranspose2d
(
256
,
256
,
3
,
2
,
1
,
output_padding
=
1
,
bias
=
False
),
...
...
@@ -230,10 +237,42 @@ class DRGANOriginal_decoder(nn.Module):
The encoded ID for the minibatch
"""
decoder_input
=
torch
.
cat
((
z
,
y
,
f
),
1
)
# linear transform to build a hypercube as input to deconv layers
#
# input is noise_dim + conditional_dim + latent_dim
# output is (latent_dim x 6 x 6)
#
# Dropout + BatchNorm + ELU are applied to the cube
# squeeze, apply linear layer, and unsqueeze
# specify the ops
lin
=
nn
.
Linear
((
self
.
noise_dim
+
self
.
conditional_dim
+
self
.
latent_dim
),
(
self
.
latent_dim
*
6
*
6
))
dropout
=
nn
.
Dropout
(
p
=
0.4
)
bn
=
nn
.
BatchNorm2d
(
320
)
elu
=
nn
.
ELU
(
inplace
=
True
)
if
torch
.
cuda
.
is_available
():
lin
=
lin
.
cuda
()
dropout
=
dropout
.
cuda
()
bn
=
bn
.
cuda
()
elu
=
elu
.
cuda
()
decoder_input
=
torch
.
squeeze
(
decoder_input
)
projected
=
lin
(
decoder_input
)
projected
=
projected
.
unsqueeze
(
2
)
projected
=
projected
.
unsqueeze
(
3
)
dropped
=
dropout
(
projected
)
reshaped
=
dropped
.
view
(
-
1
,
self
.
latent_dim
,
6
,
6
)
hypercube
=
elu
(
bn
(
reshaped
))
# deconv layers
if
isinstance
(
decoder_input
.
data
,
torch
.
cuda
.
FloatTensor
)
and
self
.
ngpu
>
1
:
output
=
nn
.
parallel
.
data_parallel
(
self
.
main
,
decoder_input
,
range
(
self
.
ngpu
))
output
=
nn
.
parallel
.
data_parallel
(
self
.
main
,
hypercube
,
range
(
self
.
ngpu
))
else
:
output
=
self
.
main
(
decoder_input
)
output
=
self
.
main
(
hypercube
)
return
output
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment