Commit e1cd9520 authored by Guillaume HEUSCH

[architecture] fixed dimension issue (output padding, squeeze)

parent 5e866340
@@ -117,7 +117,6 @@ class DRGANOriginal_encoder(nn.Module):
     else:
       output = self.main(x)
-    print "Encoder output: {output}".format(output)
     return output
@@ -159,9 +158,10 @@ class DRGANOriginal_decoder(nn.Module):
       nn.ConvTranspose2d(160, 256, 3, 1, 1, bias=False),
       nn.BatchNorm2d(256),
       nn.ELU(inplace=True),
+      # size OK
       # ------------------------------------------
       # input is 256x6x6, output is 256x12x12
-      nn.ConvTranspose2d(256, 256, 3, 2, 1, bias=False),
+      nn.ConvTranspose2d(256, 256, 3, 2, 1, output_padding=1, bias=False),
       nn.BatchNorm2d(256),
       nn.ELU(inplace=True),
       # input is 256x12x12, output is 128x12x12
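Each of the stride-2 deconvolutions touched in this commit gets the same output_padding=1 fix. For nn.ConvTranspose2d the spatial output size is (in - 1)*stride - 2*padding + kernel_size + output_padding, so a 6x6 input with kernel 3, stride 2 and padding 1 yields 11x11 without output padding and the intended 12x12 with it. A minimal sketch verifying the shapes:

import torch
import torch.nn as nn

x = torch.randn(1, 256, 6, 6)
# (6-1)*2 - 2*1 + 3 = 11: one pixel short of the intended 12x12
print(nn.ConvTranspose2d(256, 256, 3, 2, 1, bias=False)(x).shape)
# torch.Size([1, 256, 11, 11])
# output_padding=1 adds the missing row and column: 12x12
print(nn.ConvTranspose2d(256, 256, 3, 2, 1, output_padding=1, bias=False)(x).shape)
# torch.Size([1, 256, 12, 12])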
@@ -174,7 +174,7 @@ class DRGANOriginal_decoder(nn.Module):
       nn.ELU(inplace=True),
       # ------------------------------------------
       # input is 192x12x12, output is 192x24x24
-      nn.ConvTranspose2d(192, 192, 3, 2, 1, bias=False),
+      nn.ConvTranspose2d(192, 192, 3, 2, 1, output_padding=1, bias=False),
       nn.BatchNorm2d(192),
       nn.ELU(inplace=True),
       # input is 192x24x24, output is 96x24x24
@@ -187,7 +187,7 @@ class DRGANOriginal_decoder(nn.Module):
       nn.ELU(inplace=True),
       # ------------------------------------------
       # input is 128x12x12, output is 128x48x48
-      nn.ConvTranspose2d(128, 128, 3, 2, 1, bias=False),
+      nn.ConvTranspose2d(128, 128, 3, 2, 1, output_padding=1, bias=False),
       nn.BatchNorm2d(128),
       nn.ELU(inplace=True),
       # input is 128x48x48, output is 64x48x48
@@ -200,7 +200,7 @@ class DRGANOriginal_decoder(nn.Module):
       nn.ELU(inplace=True),
       # ------------------------------------------
       # input is 64x48x48, output is 64x96x96
-      nn.ConvTranspose2d(64, 64, 3, 2, 1, bias=False),
+      nn.ConvTranspose2d(64, 64, 3, 2, 1, output_padding=1, bias=False),
       nn.BatchNorm2d(64),
       nn.ELU(inplace=True),
       # input is 64x96x96, output is 32x96x96
@@ -208,8 +208,8 @@ class DRGANOriginal_decoder(nn.Module):
       nn.BatchNorm2d(32),
       nn.ELU(inplace=True),
       # input is 32x96x96, output is 3x96x96
-      nn.ConvTranspose2d(96, 128, 3, 1, 1, bias=False),
-      nn.BatchNorm2d(192),
+      nn.ConvTranspose2d(32, 3, 3, 1, 1, bias=False),
+      nn.BatchNorm2d(3),
       nn.Tanh(),
       # ------------------------------------------
     )
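Besides mapping the 32-channel feature map to a 3-channel image as the comment announces, the replacement keeps nn.BatchNorm2d's num_features equal to the preceding layer's out_channels; a mismatch fails at run time. A minimal sketch of the corrected output block:

import torch
import torch.nn as nn

block = nn.Sequential(
  nn.ConvTranspose2d(32, 3, 3, 1, 1, bias=False),
  nn.BatchNorm2d(3),  # num_features must equal out_channels above
  nn.Tanh(),          # squashes the generated image into [-1, 1]
)
print(block(torch.randn(4, 32, 96, 96)).shape)  # torch.Size([4, 3, 96, 96])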
@@ -234,7 +234,7 @@ class DRGANOriginal_decoder(nn.Module):
       output = nn.parallel.data_parallel(self.main, decoder_input, range(self.ngpu))
     else:
       output = self.main(decoder_input)
     return output


class DRGANOriginal_discriminator(nn.Module):
@@ -252,11 +252,13 @@ class DRGANOriginal_discriminator(nn.Module):
   conditional_dim: int
     The dimension of the conditioning variable
   """
-  def __init__(self, image_size, number_of_ids, conditional_dim):
+  def __init__(self, image_size, number_of_ids, conditional_dim, latent_dim):
     super(DRGANOriginal_discriminator, self).__init__()
     self.number_of_ids = number_of_ids
     self.conditional_dim = conditional_dim
+    self.latent_dim = latent_dim
     self.ngpu = 1

     self.main = nn.Sequential(
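Call sites have to be updated for the extra constructor argument. A hypothetical instantiation: number_of_ids and conditional_dim below are placeholders, while latent_dim=320 matches the 320 previously hard-coded in the fully connected layer (visible in the next hunk).

netD = DRGANOriginal_discriminator(image_size=96, number_of_ids=100,
                                   conditional_dim=13, latent_dim=320)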
@@ -321,10 +323,11 @@ class DRGANOriginal_discriminator(nn.Module):
       # ------------------------------------------
       # --- average pool
       # input is (latent_dim)x6x6, output is latent_dimx1x1
       nn.AvgPool2d(6, stride=1),
       # --- fully connected
-      nn.Linear(320, (number_of_ids + conditional_dim + 1))
+      #nn.Linear(320, (number_of_ids + conditional_dim + 1))
     )
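With the Linear commented out, self.main now ends at the average pool: nn.AvgPool2d(6, stride=1) collapses the (latent_dim)x6x6 map to (latent_dim)x1x1, and an nn.Linear placed directly after it in the Sequential would see a 4-D tensor whose last dimension is 1 rather than latent_dim. A quick shape check, assuming latent_dim is 320:

import torch
import torch.nn as nn

pool = nn.AvgPool2d(6, stride=1)
print(pool(torch.randn(4, 320, 6, 6)).shape)  # torch.Size([4, 320, 1, 1])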
@@ -338,8 +341,15 @@ class DRGANOriginal_discriminator(nn.Module):
       The minibatch of images to process.
     """
     if isinstance(x.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-      output = nn.parallel.data_parallel(self.main, x, range(self.ngpu))
+      output_avgpool = nn.parallel.data_parallel(self.main, x, range(self.ngpu))
     else:
-      output = self.main(x)
-    return output.squeeze()
+      output_avgpool = self.main(x)
+
+    input_linear = output_avgpool.squeeze()
+    classifier = nn.Sequential(
+      nn.Linear(self.latent_dim, (self.number_of_ids + self.conditional_dim + 1)),
+      nn.Sigmoid()
+    )
+    output = classifier(input_linear)
+    return output
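Two caveats for anyone building on this forward(): Tensor.squeeze() turns the (N, latent_dim, 1, 1) pool output into (N, latent_dim) but also drops the batch dimension when N == 1, and an nn.Sequential constructed inside forward() is re-created with fresh, untrained weights on every call, so the classifier head never learns. A sketch of the conventional alternative, with the head registered once in __init__ (class name and sizes are illustrative, not from the commit):

import torch
import torch.nn as nn

class DiscriminatorHead(nn.Module):
  def __init__(self, latent_dim, number_of_ids, conditional_dim):
    super(DiscriminatorHead, self).__init__()
    # Registered as a submodule, so the Linear weights are part of the
    # model's parameters and get updated by the optimizer.
    self.classifier = nn.Sequential(
      nn.Linear(latent_dim, number_of_ids + conditional_dim + 1),
      nn.Sigmoid()
    )

  def forward(self, output_avgpool):
    # (N, latent_dim, 1, 1) -> (N, latent_dim); view() keeps the batch
    # dimension even when N == 1, unlike squeeze().
    input_linear = output_avgpool.view(output_avgpool.size(0), -1)
    return self.classifier(input_linear)

head = DiscriminatorHead(320, 100, 13)
print(head(torch.randn(1, 320, 1, 1)).shape)  # torch.Size([1, 114])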