Commit 7c489023 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[architectures] fixed docstrings in ConditionalGAN

parent 74c22c44
......@@ -5,43 +5,37 @@
import torch
import torch.nn as nn
def weights_init(m):
"""
Weights initialization
**Parameters**
m:
The model
"""
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class ConditionalGAN_generator(nn.Module):
"""
Class defining the Conditional GAN generator.
**Parameters**
""" Class implementating the conditional GAN generator
noise_dim: int
The dimension of the noise.
This network is introduced in the following publication:
Mehdi Mirza, Simon Osindero: "Conditional Generative Adversarial Nets"
conditional_dim: int
The dimension of the conditioning variable.
Attributes
----------
ngpu : int
The number of available GPU devices
main : :py:class:`torch.nn.Sequential`
The sequential container
channels: int
The number of channels in the input image (default: 3).
ngpu: int
The number of GPU (default: 1)
"""
"""
def __init__(self, noise_dim, conditional_dim, channels=3, ngpu=1):
"""Init function
Parameters
----------
noise_dim : int
The dimension of the noise
conditional_dim : int
The dimension of the conditioning variable
channels : int
The number of channels in the image
ngpu : int
The number of available GPU devices
"""
super(ConditionalGAN_generator, self).__init__()
self.ngpu = ngpu
self.conditional_dim = conditional_dim
......@@ -73,16 +67,20 @@ class ConditionalGAN_generator(nn.Module):
)
def forward(self, z, y):
"""
Forward function for the generator.
"""Forward function
**Parameters**
z: pyTorch Variable
Parameters
----------
z : :py:class: `torch.autograd.Variable`
The minibatch of noise.
y: pyTorch Variable
y : :py:class: `torch.autograd.Variable`
The conditional one hot encoded vector for the minibatch.
Returns
-------
:py:class:`torch.Tensor`
the output of the generator (i.e. an image)
"""
generator_input = torch.cat((z, y), 1)
if isinstance(generator_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
......@@ -93,22 +91,33 @@ class ConditionalGAN_generator(nn.Module):
class ConditionalGAN_discriminator(nn.Module):
"""
Class defining the Conditional GAN discriminator.
**Parameters**
""" Class implementating the conditional GAN discriminator
Attributes
----------
conditional_dim: int
The dimension of the conditioning variable.
channels: int
The number of channels in the input image (default: 3).
ngpu : int
The number of available GPU devices
main : :py:class:`torch.nn.Sequential`
The sequential container
ngpu: int
The number of GPU (default: 1)
"""
def __init__(self, conditional_dim, channels=3, ngpu=1):
"""Init function
Parameters
----------
conditional_dim: int
The dimension of the conditioning variable.
channels: int
The number of channels in the input image (default: 3).
ngpu : int
The number of available GPU devices
"""
super(ConditionalGAN_discriminator, self).__init__()
self.conditional_dim = conditional_dim
self.ngpu = ngpu
......@@ -139,16 +148,19 @@ class ConditionalGAN_discriminator(nn.Module):
def forward(self, images, y):
"""
Forward function for the discriminator.
"""Forward function
**Parameters**
images: pyTorch Variable
Parameters
----------
images : :py:class: `torch.autograd.Variable`
The minibatch of input images.
y: pyTorch Variable
y : :py:class: `torch.autograd.Variable`
The corresponding conditional feature maps.
Returns
-------
:py:class:`torch.Tensor`
the output of the discriminator
"""
input_discriminator = torch.cat((images, y), 1)
if isinstance(input_discriminator.data, torch.cuda.FloatTensor) and self.ngpu > 1:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment