Commit a613225a authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

remove the parallel distribution of a network input in case it is run on multiple GPU

parent 00aa77d6
Pipeline #26064 failed with stage
in 28 minutes and 36 seconds
......@@ -80,10 +80,13 @@ class ConditionalGAN_generator(nn.Module):
"""
generator_input = torch.cat((z, y), 1)
if isinstance(generator_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, generator_input, range(self.ngpu))
else:
output = self.main(generator_input)
#if isinstance(generator_input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
# output = nn.parallel.data_parallel(self.main, generator_input, range(self.ngpu))
#else:
# output = self.main(generator_input)
# let's assume that we will never face the case where more than a GPU is used ...
output = self.main(generator_input)
return output
......@@ -159,8 +162,11 @@ class ConditionalGAN_discriminator(nn.Module):
the output of the discriminator
"""
input_discriminator = torch.cat((images, y), 1)
if isinstance(input_discriminator.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input_discriminator, range(self.ngpu))
else:
output = self.main(input_discriminator)
#if isinstance(input_discriminator.data, torch.cuda.FloatTensor) and self.ngpu > 1:
# output = nn.parallel.data_parallel(self.main, input_discriminator, range(self.ngpu))
#else:
# output = self.main(input_discriminator)
# let's assume that we will never face the case where more than a GPU is used ...
output = self.main(input_discriminator)
return output.view(-1, 1).squeeze(1)
......@@ -74,10 +74,13 @@ class DCGAN_generator(nn.Module):
the output of the generator (i.e. an image)
"""
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
#if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
# output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
#else:
# output = self.main(input)
# let's assume that we will never face the case where more than a GPU is used ...
output = self.main(input)
return output
......@@ -148,9 +151,12 @@ class DCGAN_discriminator(nn.Module):
the output of the generator (i.e. an image)
"""
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
#if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
# output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
#else:
# output = self.main(input)
# let's assume that we will never face the case where more than a GPU is used ...
output = self.main(input)
return output.view(-1, 1).squeeze(1)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment