Commit d8b6a02a authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[architecture] reorganized the code

parent dad4c652
......@@ -14,84 +14,82 @@ def weights_init(m):
m.bias.data.fill_(0)
class _netG(nn.Module):
def __init__(self, ngpu):
super(_netG, self).__init__()
self.ngpu = ngpu
class DCGAN_generator(nn.Module):
def __init__(self, ngpu):
super(DCGAN_generator, self).__init__()
self.ngpu = ngpu
# just to test - will soon be args
nz = 100
ngf = 64
nc = 3
# just to test - will soon be args
nz = 100
ngf = 64
nc = 3
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
)
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
)
def forward(self, input):
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output
def forward(self, input):
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output
class _netD(nn.Module):
def __init__(self, ngpu):
super(_netD, self).__init__()
self.ngpu = ngpu
class DCGAN_discriminator(nn.Module):
def __init__(self, ngpu):
super(DCGAN_discriminator, self).__init__()
self.ngpu = ngpu
# just to test - will soon be args
ndf = 64
nc = 3
# just to test - will soon be args
ndf = 64
nc = 3
self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output.view(-1, 1).squeeze(1)
self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output.view(-1, 1).squeeze(1)
from .DCGAN import DCGAN_generator
from .DCGAN import DCGAN_discriminator
from .DCGAN import weights_init
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
......@@ -57,9 +57,9 @@ from torch.autograd import Variable
from bob.learn.pytorch.datasets import MultiPIEDataset
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.architectures.DCGAN import _netG
from bob.learn.pytorch.architectures.DCGAN import _netD
from bob.learn.pytorch.architectures.DCGAN import weights_init
from bob.learn.pytorch.architectures import DCGAN_generator
from bob.learn.pytorch.architectures import DCGAN_discriminator
from bob.learn.pytorch.architectures import weights_init
from bob.learn.pytorch.trainers.DCGANTrainer import DCGANTrainer
......@@ -125,17 +125,17 @@ def main(user_input=None):
# ===============
ngpu = 1 # usually we don't have more than one GPU
netG = _netG(ngpu)
netG.apply(weights_init)
logger.info("Generator architecture: {}".format(netG))
generator = DCGAN_generator(ngpu)
generator.apply(weights_init)
logger.info("Generator architecture: {}".format(generator))
netD = _netD(ngpu)
netD.apply(weights_init)
logger.info("Discriminator architecture: {}".format(netD))
discrminator = DCGAN_discriminator(ngpu)
discrminator.apply(weights_init)
logger.info("Discriminator architecture: {}".format(discrminator))
# ===============
# === TRAINER ===
# ===============
trainer = DCGANTrainer(netG, netD, batch_size=batch_size, noise_dim=noise_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer = DCGANTrainer(generator, discrminator, batch_size=batch_size, noise_dim=noise_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment