Commit ce6dd84c authored by Guillaume HEUSCH

[architectures] modified the original DR-GAN to be as close as possible to the MSU version (projection at the beginning of the decoder)
parent d7610033
@@ -157,6 +157,11 @@ class DRGANOriginal_decoder(nn.Module):
 
         self.ngpu = 1 # usually, we don't have more than one GPU
 
+        self.project = nn.Sequential(
+            nn.Linear((self.noise_dim + self.conditional_dim + self.latent_dim), (self.latent_dim * 6 * 6)),
+            nn.Dropout(p=0.4),
+        )
+
         self.main = nn.Sequential(
             # input is 320x6x6, output is 160x6x6
             nn.ConvTranspose2d(320, 160, 3, 1, 1, bias=False),
@@ -237,41 +242,25 @@ class DRGANOriginal_decoder(nn.Module):
             The encoded ID for the minibatch
         """
         decoder_input = torch.cat((z, y, f), 1)
+        decoder_input = decoder_input.squeeze()
 
-        # linear transform to build a hypercube as input to deconv layers
-        #
-        # input is noise_dim + conditional_dim + latent_dim
-        # output is (latent_dim x 6 x 6)
-        #
-        # Dropout + BatchNorm + ELU are applied to the cube
-
-        # specify the ops
-        lin = nn.Linear((self.noise_dim + self.conditional_dim + self.latent_dim), (self.latent_dim * 6 * 6))
-        dropout = nn.Dropout(p=0.4)
+        # used in the "projection layer"
         bn = nn.BatchNorm2d(320)
         elu = nn.ELU(inplace=True)
         if torch.cuda.is_available():
-            lin = lin.cuda()
-            dropout = dropout.cuda()
             bn = bn.cuda()
             elu = elu.cuda()
 
-        # squeeze, apply linear layer, and unsqueeze
-        decoder_input = torch.squeeze(decoder_input)
-        projected = lin(decoder_input)
-        projected = projected.unsqueeze(2)
-        projected = projected.unsqueeze(3)
-        dropped = dropout(projected)
-        reshaped = dropped.view(-1, self.latent_dim, 6, 6)
-        hypercube = elu(bn(reshaped))
-
-        # deconv layers
+        # (projection + BN + ELU) + deconvolution
         if isinstance(decoder_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            projected = nn.parallel.data_parallel(self.project, decoder_input, range(self.ngpu))
+            hypercube = projected.view(-1, self.latent_dim, 6, 6)
+            hypercube = elu(bn(hypercube))
             output = nn.parallel.data_parallel(self.main, hypercube, range(self.ngpu))
         else:
+            projected = self.project(decoder_input)
+            hypercube = projected.view(-1, self.latent_dim, 6, 6)
+            hypercube = elu(bn(hypercube))
             output = self.main(hypercube)
 
         return output
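For context, here is a minimal runnable sketch of the projection-first layout this commit introduces. The dimension values (noise_dim=50, conditional_dim=13, latent_dim=320) and the single stand-in ConvTranspose2d are assumptions chosen for illustration; in the real class the dimensions come from the constructor and self.main is the full deconvolution stack.

import torch
import torch.nn as nn

class ProjectionDecoderSketch(nn.Module):
    # assumed dims for illustration only; the real decoder reads them
    # from its constructor arguments
    def __init__(self, noise_dim=50, conditional_dim=13, latent_dim=320):
        super(ProjectionDecoderSketch, self).__init__()
        self.latent_dim = latent_dim
        # "projection layer": map the concatenated (noise, condition, ID)
        # vector to a latent_dim x 6 x 6 hypercube, as in the commit
        self.project = nn.Sequential(
            nn.Linear(noise_dim + conditional_dim + latent_dim, latent_dim * 6 * 6),
            nn.Dropout(p=0.4),
        )
        self.bn = nn.BatchNorm2d(latent_dim)
        self.elu = nn.ELU(inplace=True)
        # stand-in for the full deconvolution stack (self.main in the real class)
        self.main = nn.ConvTranspose2d(latent_dim, 160, 3, 1, 1, bias=False)

    def forward(self, z, y, f):
        # z: noise, y: condition, f: encoded ID, each of shape (N, C, 1, 1)
        decoder_input = torch.cat((z, y, f), 1)
        # flatten explicitly; squeeze() would also drop the batch dim when N == 1
        decoder_input = decoder_input.view(decoder_input.size(0), -1)
        projected = self.project(decoder_input)
        # reshape the projected vector into the latent_dim x 6 x 6 hypercube
        hypercube = self.elu(self.bn(projected.view(-1, self.latent_dim, 6, 6)))
        return self.main(hypercube)

decoder = ProjectionDecoderSketch()
z = torch.randn(4, 50, 1, 1)
y = torch.randn(4, 13, 1, 1)
f = torch.randn(4, 320, 1, 1)
print(decoder(z, y, f).shape)  # torch.Size([4, 160, 6, 6])

Two deliberate departures from the committed code: the sketch flattens with view(N, -1) rather than squeeze(), for the batch-size-1 reason noted in the comment, and it registers BatchNorm/ELU as submodules instead of constructing them inside forward(), so that .cuda()/.to(device) moves them with the decoder and the BatchNorm running statistics persist across calls.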