Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bob.learn.pytorch
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
This is an archived project. Repository and other project resources are read-only.
Show more breadcrumbs
bob
bob.learn.pytorch
Commits
7c489023
Commit
7c489023
authored
6 years ago
by
Guillaume HEUSCH
Browse files
Options
Downloads
Patches
Plain Diff
[architectures] fixed docstrings in ConditionalGAN
parent
74c22c44
No related branches found
No related tags found
1 merge request
!4
Resolve "Add GANs"
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bob/learn/pytorch/architectures/ConditionalGAN.py
+64
-52
64 additions, 52 deletions
bob/learn/pytorch/architectures/ConditionalGAN.py
with
64 additions
and
52 deletions
bob/learn/pytorch/architectures/ConditionalGAN.py
+
64
−
52
View file @
7c489023
...
@@ -5,43 +5,37 @@
...
@@ -5,43 +5,37 @@
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
def
weights_init
(
m
):
"""
Weights initialization
**Parameters**
m:
The model
"""
classname
=
m
.
__class__
.
__name__
if
classname
.
find
(
'
Conv
'
)
!=
-
1
:
m
.
weight
.
data
.
normal_
(
0.0
,
0.02
)
elif
classname
.
find
(
'
BatchNorm
'
)
!=
-
1
:
m
.
weight
.
data
.
normal_
(
1.0
,
0.02
)
m
.
bias
.
data
.
fill_
(
0
)
class
ConditionalGAN_generator
(
nn
.
Module
):
class
ConditionalGAN_generator
(
nn
.
Module
):
"""
"""
Class implementating the conditional GAN generator
Class defining the Conditional GAN generator.
**Parameters**
noise_dim: int
This network is introduced in the following publication:
The dimension of the noise.
Mehdi Mirza, Simon Osindero:
"
Conditional Generative Adversarial Nets
"
conditional_dim: int
Attributes
The dimension of the conditioning variable.
----------
ngpu : int
The number of available GPU devices
main : :py:class:`torch.nn.Sequential`
The sequential container
channels: int
"""
The number of channels in the input image (default: 3).
ngpu: int
The number of GPU (default: 1)
"""
def
__init__
(
self
,
noise_dim
,
conditional_dim
,
channels
=
3
,
ngpu
=
1
):
def
__init__
(
self
,
noise_dim
,
conditional_dim
,
channels
=
3
,
ngpu
=
1
):
"""
Init function
Parameters
----------
noise_dim : int
The dimension of the noise
conditional_dim : int
The dimension of the conditioning variable
channels : int
The number of channels in the image
ngpu : int
The number of available GPU devices
"""
super
(
ConditionalGAN_generator
,
self
).
__init__
()
super
(
ConditionalGAN_generator
,
self
).
__init__
()
self
.
ngpu
=
ngpu
self
.
ngpu
=
ngpu
self
.
conditional_dim
=
conditional_dim
self
.
conditional_dim
=
conditional_dim
...
@@ -73,16 +67,20 @@ class ConditionalGAN_generator(nn.Module):
...
@@ -73,16 +67,20 @@ class ConditionalGAN_generator(nn.Module):
)
)
def
forward
(
self
,
z
,
y
):
def
forward
(
self
,
z
,
y
):
"""
"""
Forward function
Forward function for the generator.
**
Parameters
**
Parameters
----------
z: py
Torch
Variable
z
:
:
py
:class: `torch.autograd.
Variable
`
The minibatch of noise.
The minibatch of noise.
y : :py:class: `torch.autograd.Variable`
y: pyTorch Variable
The conditional one hot encoded vector for the minibatch.
The conditional one hot encoded vector for the minibatch.
Returns
-------
:py:class:`torch.Tensor`
the output of the generator (i.e. an image)
"""
"""
generator_input
=
torch
.
cat
((
z
,
y
),
1
)
generator_input
=
torch
.
cat
((
z
,
y
),
1
)
if
isinstance
(
generator_input
.
data
,
torch
.
cuda
.
FloatTensor
)
and
self
.
ngpu
>
1
:
if
isinstance
(
generator_input
.
data
,
torch
.
cuda
.
FloatTensor
)
and
self
.
ngpu
>
1
:
...
@@ -93,22 +91,33 @@ class ConditionalGAN_generator(nn.Module):
...
@@ -93,22 +91,33 @@ class ConditionalGAN_generator(nn.Module):
class
ConditionalGAN_discriminator
(
nn
.
Module
):
class
ConditionalGAN_discriminator
(
nn
.
Module
):
"""
"""
Class implementating the conditional GAN discriminator
Class defining the Conditional GAN discriminator.
**Parameters**
Attributes
----------
conditional_dim: int
conditional_dim: int
The dimension of the conditioning variable.
The dimension of the conditioning variable.
channels: int
channels: int
The number of channels in the input image (default: 3).
The number of channels in the input image (default: 3).
ngpu : int
The number of available GPU devices
main : :py:class:`torch.nn.Sequential`
The sequential container
ngpu: int
The number of GPU (default: 1)
"""
"""
def
__init__
(
self
,
conditional_dim
,
channels
=
3
,
ngpu
=
1
):
def
__init__
(
self
,
conditional_dim
,
channels
=
3
,
ngpu
=
1
):
"""
Init function
Parameters
----------
conditional_dim: int
The dimension of the conditioning variable.
channels: int
The number of channels in the input image (default: 3).
ngpu : int
The number of available GPU devices
"""
super
(
ConditionalGAN_discriminator
,
self
).
__init__
()
super
(
ConditionalGAN_discriminator
,
self
).
__init__
()
self
.
conditional_dim
=
conditional_dim
self
.
conditional_dim
=
conditional_dim
self
.
ngpu
=
ngpu
self
.
ngpu
=
ngpu
...
@@ -139,16 +148,19 @@ class ConditionalGAN_discriminator(nn.Module):
...
@@ -139,16 +148,19 @@ class ConditionalGAN_discriminator(nn.Module):
def
forward
(
self
,
images
,
y
):
def
forward
(
self
,
images
,
y
):
"""
"""
Forward function
Forward function for the discriminator.
**
Parameters
**
Parameters
----------
images: py
Torch
Variable
images
:
:
py
:class: `torch.autograd.
Variable
`
The minibatch of input images.
The minibatch of input images.
y : :py:class: `torch.autograd.Variable`
y: pyTorch Variable
The corresponding conditional feature maps.
The corresponding conditional feature maps.
Returns
-------
:py:class:`torch.Tensor`
the output of the discriminator
"""
"""
input_discriminator
=
torch
.
cat
((
images
,
y
),
1
)
input_discriminator
=
torch
.
cat
((
images
,
y
),
1
)
if
isinstance
(
input_discriminator
.
data
,
torch
.
cuda
.
FloatTensor
)
and
self
.
ngpu
>
1
:
if
isinstance
(
input_discriminator
.
data
,
torch
.
cuda
.
FloatTensor
)
and
self
.
ngpu
>
1
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment