Skip to content
Snippets Groups Projects
Commit a2026c5f authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[script] fixed the DR-GAN training script - added the trainer

parent cdfc7e07
No related branches found
No related tags found
No related merge requests found
...@@ -65,6 +65,8 @@ from bob.learn.pytorch.architectures import DRGAN_encoder ...@@ -65,6 +65,8 @@ from bob.learn.pytorch.architectures import DRGAN_encoder
from bob.learn.pytorch.architectures import DRGAN_decoder from bob.learn.pytorch.architectures import DRGAN_decoder
from bob.learn.pytorch.architectures import DRGAN_discriminator from bob.learn.pytorch.architectures import DRGAN_discriminator
from bob.learn.pytorch.trainers import DRGANTrainer
def main(user_input=None): def main(user_input=None):
...@@ -114,7 +116,7 @@ def main(user_input=None): ...@@ -114,7 +116,7 @@ def main(user_input=None):
# WARNING with the transforms ... act on labels too, at some point, I may have to write my own # WARNING with the transforms ... act on labels too, at some point, I may have to write my own
# Also, in 'ToTensor', there is a reshape performed from: HxWxC to CxHxW # Also, in 'ToTensor', there is a reshape performed from: HxWxC to CxHxW
face_dataset = MultiPIEDataset(root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64', face_dataset = MultiPIEDataset(root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64',
frontal_only=True, frontal_only=False,
transform=transforms.Compose([ transform=transforms.Compose([
RollChannels(), # bob to skimage: RollChannels(), # bob to skimage:
ToTensor(), ToTensor(),
...@@ -152,5 +154,6 @@ def main(user_input=None): ...@@ -152,5 +154,6 @@ def main(user_input=None):
# =============== # ===============
# === TRAINER === # === TRAINER ===
# =============== # ===============
#trainer = DRGANTrainer(encoder, decoder, discriminator, batch_size=batch_size, latent_dim=latent_dim, noise_dim=noise_dim, use_gpu=use_gpu, verbosity_level=verbosity_level) trainer = DRGANTrainer(encoder, decoder, discriminator, image_size, batch_size=batch_size,
#trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir) noise_dim=noise_dim, conditional_dim=conditional_dim, latent_dim=latent_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment