Skip to content
Snippets Groups Projects
Commit f02fa509 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainer] add the trainer class

parent 4e2fe20e
No related branches found
No related tags found
No related merge requests found
...@@ -5,18 +5,21 @@ ...@@ -5,18 +5,21 @@
""" Train a DR-GAN """ Train a DR-GAN
Usage: Usage:
%(prog)s [--latent-dim=<int>] %(prog)s [--latent-dim=<int>] [--noise-dim=<int>]
[--batch-size=<int>] [--epochs=<int>] [--sample=<int>] [--batch-size=<int>] [--epochs=<int>] [--sample=<int>]
[--output-dir=<path>] [--verbose ...] [--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...]
Options: Options:
-h, --help Show this screen. -h, --help Show this screen.
-V, --version Show version. -V, --version Show version.
-l, --latent-dim=<int> the dimension of the encoded ID [default: 320] -l, --latent-dim=<int> the dimension of the encoded ID [default: 320]
-n, --noise-dim=<int> the dimension of the noise [default: 100]
-b, --batch-size=<int> The size of your mini-batch [default: 64] -b, --batch-size=<int> The size of your mini-batch [default: 64]
-e, --epochs=<int> The number of training epochs [default: 100] -e, --epochs=<int> The number of training epochs [default: 100]
-s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 100000000000] -s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 100000000000]
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./drgan-light-mpie-casia/] -o, --output-dir=<path> Dir to save the logs, models and images [default: ./drgan-light-mpie-casia/]
-g, --use-gpu Use the GPU
-S, --seed=<int> The random seed [default: 3]
-v, --verbose Increase the verbosity (may appear multiple times). -v, --verbose Increase the verbosity (may appear multiple times).
Example: Example:
...@@ -40,22 +43,27 @@ from docopt import docopt ...@@ -40,22 +43,27 @@ from docopt import docopt
version = pkg_resources.require('bob.learn.pytorch')[0].version version = pkg_resources.require('bob.learn.pytorch')[0].version
import numpy import numpy
import bob.io.base
# torch
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.utils as vutils import torchvision.utils as vutils
from torch.autograd import Variable from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets.multipie import MultiPIEDataset from bob.learn.pytorch.datasets.multipie import MultiPIEDataset
from bob.learn.pytorch.datasets.multipie import RollChannels from bob.learn.pytorch.datasets.multipie import RollChannels
from bob.learn.pytorch.architectures.DCGAN import _netG from bob.learn.pytorch.architectures.DCGAN import _netG
from bob.learn.pytorch.architectures.DCGAN import _netD from bob.learn.pytorch.architectures.DCGAN import _netD
from bob.learn.pytorch.architectures.DCGAN import weights_init from bob.learn.pytorch.architectures.DCGAN import weights_init
from bob.learn.pytorch.trainers.DCGANTrainer import DCGANTrainer
def main(user_input=None): def main(user_input=None):
# Parse the command-line arguments # Parse the command-line arguments
...@@ -72,129 +80,62 @@ def main(user_input=None): ...@@ -72,129 +80,62 @@ def main(user_input=None):
verbosity_level = args['--verbose'] verbosity_level = args['--verbose']
bob.core.log.set_verbosity_level(logger, verbosity_level) bob.core.log.set_verbosity_level(logger, verbosity_level)
# get the parameters # get the arguments
noise_dim = int(args['--noise-dim'])
batch_size = int(args['--batch-size']) batch_size = int(args['--batch-size'])
epochs = int(args['--epochs']) epochs = int(args['--epochs'])
sample = int(args['--sample']) sample = int(args['--sample'])
output_dir = str(args['--output-dir']) output_dir = str(args['--output-dir'])
seed = int(args['--seed'])
use_gpu = bool(args['--use-gpu'])
images_dir = os.path.join(output_dir, 'samples') images_dir = os.path.join(output_dir, 'samples')
log_dir = os.path.join(output_dir, 'logs') log_dir = os.path.join(output_dir, 'logs')
model_dir = os.path.join(output_dir, 'models') model_dir = os.path.join(output_dir, 'models')
try: # process on the arguments / options
os.makedirs(output_dir) torch.manual_seed(seed)
except OSError: if use_gpu:
pass torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available() and not use_gpu:
logger.warn("You have a CUDA device, so you should probably run with --use-gpu")
# data bob.io.base.create_directories_safe(images_dir)
bob.io.base.create_directories_safe(log_dir)
bob.io.base.create_directories_safe(images_dir)
# ============
# === DATA ===
# ============
# WARNING with the transforms ... act on labels too, at some point, I may have to write my own # WARNING with the transforms ... act on labels too, at some point, I may have to write my own
# Also, in 'ToTensor', there is a reshape performed from: HxWxC to CxHxW # Also, in 'ToTensor', there is a reshape performed from: HxWxC to CxHxW
face_dataset = MultiPIEDataset(root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64', frontal_only=True, face_dataset = MultiPIEDataset(root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64',
transform=transforms.Compose([RollChannels(), transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])) frontal_only=True,
logger.info("There are {} training images".format(len(face_dataset))) transform=transforms.Compose([
RollChannels(), # bob to skimage:
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
)
dataloader = torch.utils.data.DataLoader(face_dataset, batch_size=batch_size, shuffle=True) dataloader = torch.utils.data.DataLoader(face_dataset, batch_size=batch_size, shuffle=True)
logger.info("There are {} training images".format(len(face_dataset)))
#from matplotlib import pyplot # ===============
#for i in range(len(face_dataset)): # === NETWORK ===
# sample = face_dataset[i] # ===============
# pyplot.title('Sample {}: ID -> {}, pose ->{}'.format(i, sample['id'], sample['pose'])) ngpu = 1 # usually we don't have more than one GPU
# pyplot.imshow(sample['image'])
# #pyplot.imshow(numpy.rollaxis(numpy.rollaxis(sample['image'], 2),2))
# pyplot.show()
# network
ngpu = 1
netG = _netG(ngpu) netG = _netG(ngpu)
netG.apply(weights_init) netG.apply(weights_init)
print(netG) logger.info("Generator architecture: {}".format(netG))
netD = _netD(ngpu) netD = _netD(ngpu)
netD.apply(weights_init) netD.apply(weights_init)
print(netD) logger.info("Discriminator architecture: {}".format(netD))
criterion = nn.BCELoss()
nz = 100
input = torch.FloatTensor(batch_size, 3, 64, 64)
noise = torch.FloatTensor(batch_size, nz, 1, 1)
fixed_noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)
label = torch.FloatTensor(batch_size)
real_label = 1
fake_label = 0
fixed_noise = Variable(fixed_noise)
# setup optimizer
lr = 0.0002
beta1 = 0.5
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
niter = 10
for epoch in range(niter):
for i, data in enumerate(dataloader, 0):
#print data
#print len(data)
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
# train with real
netD.zero_grad()
real_cpu = data
#print (type(real_cpu))
#print (real_cpu)
batch_size = real_cpu.size(0)
input.resize_as_(real_cpu).copy_(real_cpu)
label.resize_(batch_size).fill_(real_label)
inputv = Variable(input)
labelv = Variable(label)
output = netD(inputv)
errD_real = criterion(output, labelv)
errD_real.backward()
D_x = output.data.mean()
# train with fake
noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
noisev = Variable(noise)
fake = netG(noisev)
labelv = Variable(label.fill_(fake_label))
output = netD(fake.detach())
errD_fake = criterion(output, labelv)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
optimizerD.step()
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost
output = netD(fake)
errG = criterion(output, labelv)
errG.backward()
D_G_z2 = output.data.mean()
optimizerG.step()
print'[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, niter, i, len(dataloader), errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2)
if i % 100 == 0:
vutils.save_image(real_cpu,
'%s/real_samples.png' % output_dir,
normalize=True)
fake = netG(fixed_noise)
vutils.save_image(fake.data,
'%s/fake_samples_epoch_%03d.png' % (output_dir, epoch),
normalize=True)
# do checkpointing
torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (output_dir, epoch))
torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (output_dir, epoch))
# ===============
# === TRAINER ===
# ===============
trainer = DCGANTrainer(netG, netD, batch_size=batch_size, noise_dim=noise_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
class DCGANTrainer(object):
"""
"""
def __init__(self, netG, netD, batch_size=64, noise_dim=100, use_gpu=False, verbosity_level=2):
self.netG = netG
self.netD = netD
self.batch_size = batch_size
self.noise_dim = noise_dim
self.use_gpu = use_gpu
self.input = torch.FloatTensor(batch_size, 3, 64, 64)
self.noise = torch.FloatTensor(batch_size, noise_dim, 1, 1)
self.fixed_noise = torch.FloatTensor(batch_size, noise_dim, 1, 1).normal_(0, 1)
self.label = torch.FloatTensor(batch_size)
self.fixed_noise = Variable(self.fixed_noise)
self.criterion = nn.BCELoss()
if self.use_gpu:
self.netD.cuda()
self.netG.cuda()
self.criterion.cuda()
self.input, self.label = self.input.cuda(), self.label.cuda()
self.noise, self.fixed_noise = self.noise.cuda(), self.fixed_noise.cuda()
bob.core.log.set_verbosity_level(logger, verbosity_level)
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out'):
real_label = 1
fake_label = 0
# setup optimizer
optimizerD = optim.Adam(self.netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(self.netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))
for epoch in range(n_epochs):
for i, data in enumerate(dataloader, 0):
# ===========================================================
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
# ===========================================================
# train with real
self.netD.zero_grad()
real_cpu = data
batch_size = real_cpu.size(0)
if self.use_gpu:
real_cpu = real_cpu.cuda()
self.input.resize_as_(real_cpu).copy_(real_cpu)
self.label.resize_(batch_size).fill_(real_label)
inputv = Variable(self.input)
labelv = Variable(self.label)
output = self.netD(inputv)
errD_real = self.criterion(output, labelv)
errD_real.backward()
D_x = output.data.mean()
# train with fake
self.noise.resize_(batch_size, self.noise_dim, 1, 1).normal_(0, 1)
noisev = Variable(self.noise)
fake = self.netG(noisev)
labelv = Variable(self.label.fill_(fake_label))
output = self.netD(fake.detach())
errD_fake = self.criterion(output, labelv)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
optimizerD.step()
# =========================================
# (2) Update G network: maximize log(D(G(z)))
# =========================================
self.netG.zero_grad()
labelv = Variable(self.label.fill_(real_label)) # fake labels are real for generator cost
output = self.netD(fake)
errG = self.criterion(output, labelv)
errG.backward()
D_G_z2 = output.data.mean()
optimizerG.step()
logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {}".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0]))
# save generated images at every epoch
fake = self.netG(self.fixed_noise)
vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), normalize=True)
# do checkpointing
torch.save(self.netG.state_dict(), '%s/netG_epoch_%d.pth' % (output_dir, epoch))
torch.save(self.netD.state_dict(), '%s/netD_epoch_%d.pth' % (output_dir, epoch))
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment