Skip to content
Snippets Groups Projects
Commit 7c489023 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[architectures] fixed docstrings in ConditionalGAN

parent 74c22c44
No related branches found
No related tags found
1 merge request!4Resolve "Add GANs"
...@@ -5,43 +5,37 @@ ...@@ -5,43 +5,37 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
def weights_init(m):
"""
Weights initialization
**Parameters**
m:
The model
"""
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class ConditionalGAN_generator(nn.Module): class ConditionalGAN_generator(nn.Module):
""" """ Class implementating the conditional GAN generator
Class defining the Conditional GAN generator.
**Parameters**
noise_dim: int This network is introduced in the following publication:
The dimension of the noise. Mehdi Mirza, Simon Osindero: "Conditional Generative Adversarial Nets"
conditional_dim: int Attributes
The dimension of the conditioning variable. ----------
ngpu : int
The number of available GPU devices
main : :py:class:`torch.nn.Sequential`
The sequential container
channels: int """
The number of channels in the input image (default: 3).
ngpu: int
The number of GPU (default: 1)
"""
def __init__(self, noise_dim, conditional_dim, channels=3, ngpu=1): def __init__(self, noise_dim, conditional_dim, channels=3, ngpu=1):
"""Init function
Parameters
----------
noise_dim : int
The dimension of the noise
conditional_dim : int
The dimension of the conditioning variable
channels : int
The number of channels in the image
ngpu : int
The number of available GPU devices
"""
super(ConditionalGAN_generator, self).__init__() super(ConditionalGAN_generator, self).__init__()
self.ngpu = ngpu self.ngpu = ngpu
self.conditional_dim = conditional_dim self.conditional_dim = conditional_dim
...@@ -73,16 +67,20 @@ class ConditionalGAN_generator(nn.Module): ...@@ -73,16 +67,20 @@ class ConditionalGAN_generator(nn.Module):
) )
def forward(self, z, y): def forward(self, z, y):
""" """Forward function
Forward function for the generator.
**Parameters** Parameters
----------
z: pyTorch Variable z : :py:class: `torch.autograd.Variable`
The minibatch of noise. The minibatch of noise.
y : :py:class: `torch.autograd.Variable`
y: pyTorch Variable
The conditional one hot encoded vector for the minibatch. The conditional one hot encoded vector for the minibatch.
Returns
-------
:py:class:`torch.Tensor`
the output of the generator (i.e. an image)
""" """
generator_input = torch.cat((z, y), 1) generator_input = torch.cat((z, y), 1)
if isinstance(generator_input.data, torch.cuda.FloatTensor) and self.ngpu > 1: if isinstance(generator_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
...@@ -93,22 +91,33 @@ class ConditionalGAN_generator(nn.Module): ...@@ -93,22 +91,33 @@ class ConditionalGAN_generator(nn.Module):
class ConditionalGAN_discriminator(nn.Module): class ConditionalGAN_discriminator(nn.Module):
""" """ Class implementating the conditional GAN discriminator
Class defining the Conditional GAN discriminator.
**Parameters**
Attributes
----------
conditional_dim: int conditional_dim: int
The dimension of the conditioning variable. The dimension of the conditioning variable.
channels: int channels: int
The number of channels in the input image (default: 3). The number of channels in the input image (default: 3).
ngpu : int
The number of available GPU devices
main : :py:class:`torch.nn.Sequential`
The sequential container
ngpu: int
The number of GPU (default: 1)
""" """
def __init__(self, conditional_dim, channels=3, ngpu=1): def __init__(self, conditional_dim, channels=3, ngpu=1):
"""Init function
Parameters
----------
conditional_dim: int
The dimension of the conditioning variable.
channels: int
The number of channels in the input image (default: 3).
ngpu : int
The number of available GPU devices
"""
super(ConditionalGAN_discriminator, self).__init__() super(ConditionalGAN_discriminator, self).__init__()
self.conditional_dim = conditional_dim self.conditional_dim = conditional_dim
self.ngpu = ngpu self.ngpu = ngpu
...@@ -139,16 +148,19 @@ class ConditionalGAN_discriminator(nn.Module): ...@@ -139,16 +148,19 @@ class ConditionalGAN_discriminator(nn.Module):
def forward(self, images, y): def forward(self, images, y):
""" """Forward function
Forward function for the discriminator.
**Parameters** Parameters
----------
images: pyTorch Variable images : :py:class: `torch.autograd.Variable`
The minibatch of input images. The minibatch of input images.
y : :py:class: `torch.autograd.Variable`
y: pyTorch Variable
The corresponding conditional feature maps. The corresponding conditional feature maps.
Returns
-------
:py:class:`torch.Tensor`
the output of the discriminator
""" """
input_discriminator = torch.cat((images, y), 1) input_discriminator = torch.cat((images, y), 1)
if isinstance(input_discriminator.data, torch.cuda.FloatTensor) and self.ngpu > 1: if isinstance(input_discriminator.data, torch.cuda.FloatTensor) and self.ngpu > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment