Skip to content
Snippets Groups Projects
Commit 4b219490 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[test] unit tests for Conditional GAN

parent 78a92c7e
No related branches found
No related tags found
1 merge request!4Resolve "Add GANs"
......@@ -47,7 +47,28 @@ def test_architectures():
generator = DCGAN_generator(1)
output = generator.forward(t)
assert output.shape == torch.Size([1, 3, 64, 64])
# Conditional GAN
d = numpy.random.rand(1, 3, 64, 64).astype("float32")
t = torch.from_numpy(d)
cfm = numpy.zeros((1, 13, 64, 64), dtype="float32")
cfm[:, 0, :, :] = 1
cfmt = torch.from_numpy(cfm)
from ..architectures import ConditionalGAN_discriminator
discriminator = ConditionalGAN_discriminator(13)
output = discriminator.forward(t, cfmt)
assert output.shape == torch.Size([1])
g = numpy.random.rand(1, 100, 1, 1).astype("float32")
t = torch.from_numpy(g)
oh = numpy.zeros((1, 13, 1, 1), dtype="float32")
oh[0] = 1
oht = torch.from_numpy(oh)
from ..architectures import ConditionalGAN_generator
discriminator = ConditionalGAN_generator(100, 13)
output = discriminator.forward(t, oht)
assert output.shape == torch.Size([1, 3, 64, 64])
def test_transforms():
......@@ -149,3 +170,33 @@ def test_DCGANtrainer():
os.remove('netD_epoch_0.pth')
os.remove('netG_epoch_0.pth')
class DummyDataSetConditionalGAN(Dataset):
def __init__(self):
pass
def __len__(self):
return 100
def __getitem__(self, idx):
data = numpy.random.rand(3, 64, 64).astype("float32")
sample = {'image': torch.from_numpy(data), 'pose': numpy.random.randint(0, 13)}
return sample
def test_ConditionalGANTrainer():
from ..architectures import ConditionalGAN_generator
from ..architectures import ConditionalGAN_discriminator
g = ConditionalGAN_generator(100, 13)
d = ConditionalGAN_discriminator(13)
dataloader = torch.utils.data.DataLoader(DummyDataSetConditionalGAN(), batch_size=32, shuffle=True)
from ..trainers import ConditionalGANTrainer
trainer = ConditionalGANTrainer(g, d, [3, 64, 64], batch_size=32, noise_dim=100, conditional_dim=13)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('fake_samples_epoch_000.png')
assert os.path.isfile('netD_epoch_0.pth')
assert os.path.isfile('netG_epoch_0.pth')
os.remove('fake_samples_epoch_000.png')
os.remove('netD_epoch_0.pth')
os.remove('netG_epoch_0.pth')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment