Commit 72f9922d authored by Guillaume HEUSCH

[architecture, trainer] cleaned the code, added docstring

parent 1fc07678
@@ -6,17 +6,43 @@ import torch
 import torch.nn as nn
 
 def weights_init(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        m.weight.data.normal_(0.0, 0.02)
-    elif classname.find('BatchNorm') != -1:
-        m.weight.data.normal_(1.0, 0.02)
-        m.bias.data.fill_(0)
+    """
+    Weights initialization
+
+    **Parameters**
+
+    m:
+        The model
+    """
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        m.weight.data.normal_(0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        m.weight.data.normal_(1.0, 0.02)
+        m.bias.data.fill_(0)
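For context, this initializer follows the usual DCGAN recipe and is meant to be applied recursively to every submodule. A minimal usage sketch (the call itself is outside this hunk, and the constructor arguments are only illustrative):

    # hypothetical example: nn.Module.apply visits every submodule and
    # re-initializes Conv and BatchNorm layers with the statistics above
    netG = ConditionalGAN_generator(noise_dim=100, conditional_dim=13)
    netG.apply(weights_init)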
-class ConditionalGAN_generator2(nn.Module):
+class ConditionalGAN_generator(nn.Module):
+    """
+    Class defining the Conditional GAN generator.
+
+    **Parameters**
+
+    noise_dim: int
+        The dimension of the noise.
+    conditional_dim: int
+        The dimension of the conditioning variable.
+    channels: int
+        The number of channels in the input image (default: 3).
+    ngpu: int
+        The number of GPU (default: 1)
+    """
     def __init__(self, noise_dim, conditional_dim, channels=3, ngpu=1):
-        super(ConditionalGAN_generator2, self).__init__()
+        super(ConditionalGAN_generator, self).__init__()
         self.ngpu = ngpu
         self.conditional_dim = conditional_dim
@@ -52,14 +78,13 @@ class ConditionalGAN_generator2(nn.Module):
         **Parameters**
 
-        z: pyTorch Tensor
+        z: pyTorch Variable
             The minibatch of noise.
-        y: int
-            The conditional one hot encoded vector.
+        y: pyTorch Variable
+            The conditional one hot encoded vector for the minibatch.
         """
         generator_input = torch.cat((z, y), 1)
         if isinstance(generator_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
             output = nn.parallel.data_parallel(self.main, generator_input, range(self.ngpu))
         else:
@@ -67,16 +92,24 @@ class ConditionalGAN_generator2(nn.Module):
         return output
 
     def one_hot_vector(self, y):
        one_hot = torch.FloatTensor(self.minibatch_size, self.conditional_dim, 1, 1).zero_()
        for k in range(self.minibatch_size):
            one_hot[k, y[k]] = 1
        return one_hot
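The forward pass simply concatenates the noise and the one-hot condition along the channel dimension before running the deconvolutional stack. A hedged sketch of how the generator would be called, where the batch size, dimensions and Variable-era API are assumptions matching the code above:

    import torch
    from torch.autograd import Variable   # pre-0.4 API, as used in this file

    batch_size, noise_dim, conditional_dim = 4, 100, 13              # illustrative sizes
    netG = ConditionalGAN_generator(noise_dim, conditional_dim)
    z = Variable(torch.randn(batch_size, noise_dim, 1, 1))           # minibatch of noise
    y = torch.FloatTensor(batch_size, conditional_dim, 1, 1).zero_()
    for k in range(batch_size):
        y[k, k % conditional_dim] = 1                                # one-hot condition (e.g. a pose label)
    fake_images = netG(z, Variable(y))                               # batch_size x channels x H x W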
-class ConditionalGAN_discriminator2(nn.Module):
+class ConditionalGAN_discriminator(nn.Module):
+    """
+    Class defining the Conditional GAN discriminator.
+
+    **Parameters**
+
+    conditional_dim: int
+        The dimension of the conditioning variable.
+    channels: int
+        The number of channels in the input image (default: 3).
+    ngpu: int
+        The number of GPU (default: 1)
+    """
     def __init__(self, conditional_dim, channels=3, ngpu=1):
-        super(ConditionalGAN_discriminator2, self).__init__()
+        super(ConditionalGAN_discriminator, self).__init__()
         self.conditional_dim = conditional_dim
         self.ngpu = ngpu
@@ -111,17 +144,15 @@ class ConditionalGAN_discriminator2(nn.Module):
         **Parameters**
 
-        image: pyTorch Tensor
+        images: pyTorch Variable
             The minibatch of input images.
-        y: int
-            The conditional feature maps.
+        y: pyTorch Variable
+            The corresponding conditional feature maps.
         """
         input_discriminator = torch.cat((images, y), 1)
         if isinstance(input_discriminator.data, torch.cuda.FloatTensor) and self.ngpu > 1:
             output = nn.parallel.data_parallel(self.main, input_discriminator, range(self.ngpu))
         else:
             output = self.main(input_discriminator)
         return output.view(-1, 1).squeeze(1)
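On the discriminator side the condition is not a vector but a stack of conditional_dim feature maps, one plane per class, concatenated with the image channels. A sketch under the assumption of 64x64 RGB inputs (sizes and labels are hypothetical):

    import torch
    from torch.autograd import Variable

    conditional_dim, batch_size = 13, 4                      # illustrative sizes
    netD = ConditionalGAN_discriminator(conditional_dim)
    images = Variable(torch.randn(batch_size, 3, 64, 64))    # assumed 64x64 RGB images
    labels = [0, 1, 2, 3]                                    # hypothetical condition labels
    y_maps = torch.FloatTensor(batch_size, conditional_dim, 64, 64).zero_()
    for k in range(batch_size):
        y_maps[k, labels[k], :, :] = 1                       # plane of ones for the active condition
    decision = netD(images, Variable(y_maps))                # probability that each image is real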
@@ -2,6 +2,7 @@
 # encoding: utf-8
 
 import numpy
+import time
 
 import torch
 import torch.nn as nn
@@ -9,15 +10,10 @@ import torch.optim as optim
 from torch.autograd import Variable
 import torchvision.utils as vutils
 
 import bob.core
 logger = bob.core.log.setup("bob.learn.pytorch")
 
-import time
-from matplotlib import pyplot
 
-class ConditionalGANTrainer2(object):
+class ConditionalGANTrainer(object):
     """
     Class to train a Conditional GAN
@@ -64,21 +60,19 @@ class ConditionalGANTrainer2(object):
         self.fixed_one_hot = torch.FloatTensor(self.conditional_dim, self.conditional_dim, 1, 1).zero_()
         for k in range(self.conditional_dim):
             self.fixed_one_hot[k, k] = 1
 
-        # TODO: figuring out the CPU/GPU thing - Guillaume HEUSCH, 17-11-2017
-        self.fixed_noise = Variable(self.fixed_noise)
-        self.fixed_one_hot = Variable(self.fixed_one_hot)
 
         # binary cross-entropy loss
         self.criterion = nn.BCELoss()
 
         # move stuff to GPU if needed
         if self.use_gpu:
             self.netD.cuda()
             self.netG.cuda()
             self.criterion.cuda()
-            #self_fixed_noise = self.fixed_noise.cuda()
-            #self_fixed_one_hot = self.fixed_one_hot.cuda()
+        self.fixed_noise = Variable(self.fixed_noise)
+        self.fixed_one_hot = Variable(self.fixed_one_hot)
 
     def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out'):
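learning_rate and beta1 do not appear in the lines shown below; presumably they feed the usual DCGAN-style Adam optimizers built at the top of train(). A sketch of what that setup typically looks like (an assumption, not part of this diff):

    import torch.optim as optim

    # hypothetical optimizer setup inside train(): one Adam optimizer per network
    optimizerD = optim.Adam(self.netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
    optimizerG = optim.Adam(self.netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))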
@@ -131,7 +125,8 @@ class ConditionalGANTrainer2(object):
             for k in range(batch_size):
                 one_hot_feature_maps[k, poses[k], :, :] = 1
                 one_hot_vector[k, poses[k]] = 1
 
+            # move stuff to GPU if needed
             if self.use_gpu:
                 real_images = real_images.cuda()
                 label = label.cuda()
@@ -139,7 +134,6 @@ class ConditionalGANTrainer2(object):
                 one_hot_feature_maps = one_hot_feature_maps.cuda()
                 one_hot_vector = one_hot_vector.cuda()
 
             # =============
             # DISCRIMINATOR
             # =============
@@ -149,11 +143,6 @@ class ConditionalGANTrainer2(object):
             label.resize_(batch_size).fill_(real_label)
             imagev = Variable(real_images)
             one_hot_fmv = Variable(one_hot_feature_maps)
-
-            #from matplotlib import pyplot
-            #pyplot.title("Pose {}".format(poses[0]))
-            #pyplot.imshow(numpy.rollaxis(numpy.rollaxis(first_image, 2),2))
-            #pyplot.show()
             labelv = Variable(label)
             output_real = self.netD(imagev, one_hot_fmv)
             errD_real = self.criterion(output_real, labelv)
@@ -167,9 +156,9 @@ class ConditionalGANTrainer2(object):
             output_fake = self.netD(fake, one_hot_fmv)
             errD_fake = self.criterion(output_fake, labelv)
             errD_fake.backward(retain_graph=True)
-            errD = errD_real + errD_fake
 
             # perform optimization (i.e. update discriminator parameters)
+            errD = errD_real + errD_fake
             optimizerD.step()
 
             # =========
......
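The excerpt stops at the header of the next block, presumably the generator update. For reference only, the standard conditional-DCGAN generator step trains G so that D labels its fakes as real; this is a sketch of that usual recipe, not the code added or removed by this commit:

    # hypothetical generator step, mirroring the usual conditional DCGAN recipe
    self.netG.zero_grad()
    labelv = Variable(label.fill_(real_label))           # generator wants its fakes classified as real
    output_generated = self.netD(fake, one_hot_fmv)      # re-score the fake images
    errG = self.criterion(output_generated, labelv)
    errG.backward()
    optimizerG.step()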