diff --git a/bob/learn/pytorch/datasets/__init__.py b/bob/learn/pytorch/datasets/__init__.py
index b03b5fb0212744b085f0ef6254fce7c76a18479b..5772fc337c6d9e06bd7c12f78afcf4fbf13cf8d8 100644
--- a/bob/learn/pytorch/datasets/__init__.py
+++ b/bob/learn/pytorch/datasets/__init__.py
@@ -1,4 +1,5 @@
 from .multipie import MultiPIEDataset
+from .casia_webface import CasiaDataset
 
 # transforms
 from .utils import RollChannels 
diff --git a/bob/learn/pytorch/datasets/casia_webface.py b/bob/learn/pytorch/datasets/casia_webface.py
new file mode 100644
index 0000000000000000000000000000000000000000..955e0e4e3424c7c009a8c8804a92d56930978b27
--- /dev/null
+++ b/bob/learn/pytorch/datasets/casia_webface.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+import os
+
+from torch.utils.data import Dataset
+
+import bob.io.base
+import bob.io.image  # imported for its side-effect: registers the image codecs used by bob.io.base.load
+
+from .utils import map_labels
+
+class CasiaDataset(Dataset):
+  """Casia WebFace dataset.
+  
+  Class representing the CASIA WebFace dataset
+
+  **Parameters**
+
+  root-dir: path
+    The path to the data
+
+  frontal_only: boolean
+    If you want to only use frontal faces 
+    
+  transform: torchvision.transforms
+    The transform(s) to apply to the face images
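+
+  **Example**
+
+  A minimal usage sketch; the path below is a placeholder for the actual
+  location of the pre-processed data::
+
+    >>> dataset = CasiaDataset('/path/to/casia-webface-pose-clusters')
+    >>> n_images = len(dataset)
+    >>> sample = dataset[0]  # dict with keys 'image', 'id' and 'pose'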
+  """
+
+  # TODO: Start from original data and annotations - Guillaume HEUSCH, 06-11-2017
+
+  def __init__(self, root_dir, frontal_only=False, transform=None):
+    self.root_dir = root_dir
+    self.transform = transform
+
+    dir_to_pose_label = {'l90': '0',
+                         'l75': '1',
+                         'l60': '2',
+                         'l45': '3',
+                         'l30': '4',
+                         'l15': '5',
+                         '0'  : '6',
+                         'r15': '7',
+                         'r30': '8',
+                         'r45': '9',
+                         'r60': '10',
+                         'r75': '11',
+                         'r90': '12',
+                        }
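+
+    # assumed directory layout (inferred from the mapping above):
+    #   <root_dir>/<pose-cluster>/<subject-id>/<image-file>
+    # where <pose-cluster> is one of the keys above and <subject-id> is an integer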
+
+    # get all the needed files, the pose labels, and the id labels
+    self.data_files = []
+    self.pose_labels = []
+    id_labels = []
+
+    for root, dirs, files in os.walk(self.root_dir):
+      for name in files:
+        path = root.split(os.sep)
+        subject = int(path[-1])
+        cluster = path[-2]
+        self.data_files.append(os.path.join(root, name))
+        self.pose_labels.append(dir_to_pose_label[cluster])
+        id_labels.append(subject)
+
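+    # relabel the original subject ids to a contiguous range starting at 0
+    # (map_labels is assumed to return labels in 0..n_identities-1)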
+    self.id_labels = map_labels(id_labels)
+
+  def __len__(self):
+    """Return the number of examples in the dataset."""
+    return len(self.data_files)
+
+  def __getitem__(self, idx):
+    """Return the sample at the given index.
+
+    The sample is a dictionary containing the image and its identity and pose labels.
+    """
+    image = bob.io.base.load(self.data_files[idx])
+    identity = self.id_labels[idx]
+    pose = self.pose_labels[idx]
+    sample = {'image': image, 'id': identity, 'pose': pose}
+
+    if self.transform:
+      sample = self.transform(sample)
+
+    return sample
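+
+
+if __name__ == '__main__':
+  # minimal smoke test; the directory below is a placeholder and must point
+  # to a tree laid out as <root_dir>/<pose-cluster>/<subject-id>/<image-file>
+  dataset = CasiaDataset(root_dir='/path/to/casia-webface-pose-clusters')
+  print('{} images from {} identities'.format(len(dataset), len(set(dataset.id_labels))))
+  sample = dataset[0]
+  print('first sample: image shape {}, id {}, pose {}'.format(sample['image'].shape, sample['id'], sample['pose']))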
diff --git a/bob/learn/pytorch/scripts/train_conditionalgan_casia.py b/bob/learn/pytorch/scripts/train_conditionalgan_casia.py
new file mode 100644
index 0000000000000000000000000000000000000000..84b28ee1559b79cf2fe758976d2e6d9424e3fb3a
--- /dev/null
+++ b/bob/learn/pytorch/scripts/train_conditionalgan_casia.py
@@ -0,0 +1,145 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+
+""" Train a Conditional GAN 
+
+Usage:
+  %(prog)s [--noise-dim=<int>] [--conditional-dim=<int>] 
+           [--batch-size=<int>] [--epochs=<int>] [--sample=<int>]
+           [--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...]
+
+Options:
+  -h, --help                    Show this screen.
+  -V, --version                 Show version.
+  -n, --noise-dim=<int>         The dimension of the noise [default: 100]
+  -c, --conditional-dim=<int>   The dimension of the conditional variable [default: 13]
+  -b, --batch-size=<int>        The size of your mini-batch [default: 64]
+  -e, --epochs=<int>            The number of training epochs [default: 100] 
+  -s, --sample=<int>            Save generated images every 'sample' batch iterations [default: 100000000000]
+  -o, --output-dir=<path>       Dir to save the logs, models and images [default: ./cgan-casia] 
+  -g, --use-gpu                 Use the GPU 
+  -S, --seed=<int>              The random seed [default: 3] 
+  -v, --verbose                 Increase the verbosity (may appear multiple times).
+
+Example:
+
+  To run the training process 
+
+    $ %(prog)s --batch-size 64 --epochs 25 --output-dir cgan-casia
+
+See '%(prog)s --help' for more information.
+
+"""
+
+import os, sys
+import pkg_resources
+
+import bob.core
+logger = bob.core.log.setup("bob.learn.pytorch")
+
+from docopt import docopt
+
+version = pkg_resources.require('bob.learn.pytorch')[0].version
+
+import bob.io.base
+
+# torch
+import torch
+import torchvision.transforms as transforms
+
+# data and architecture from the package
+from bob.learn.pytorch.datasets import CasiaDataset
+from bob.learn.pytorch.datasets import RollChannels
+from bob.learn.pytorch.datasets import ToTensor
+from bob.learn.pytorch.datasets import Normalize
+
+from bob.learn.pytorch.architectures import ConditionalGAN_generator
+from bob.learn.pytorch.architectures import ConditionalGAN_discriminator
+from bob.learn.pytorch.architectures import weights_init
+
+from bob.learn.pytorch.trainers import ConditionalGANTrainer
+
+
+def main(user_input=None):
+  
+  # Parse the command-line arguments
+  if user_input is not None:
+      arguments = user_input
+  else:
+      arguments = sys.argv[1:]
+
+  prog = os.path.basename(sys.argv[0])
+  completions = dict(prog=prog, version=version,)
+  args = docopt(__doc__ % completions, argv=arguments, version='Train Conditional GAN (%s)' % version)
+
+  # verbosity
+  verbosity_level = args['--verbose']
+  bob.core.log.set_verbosity_level(logger, verbosity_level)
+
+  # get the arguments
+  noise_dim = int(args['--noise-dim'])
+  conditional_dim = int(args['--conditional-dim'])
+  batch_size = int(args['--batch-size'])
+  epochs = int(args['--epochs'])
+  sample = int(args['--sample'])
+  output_dir = str(args['--output-dir'])
+  seed = int(args['--seed'])
+  use_gpu = bool(args['--use-gpu'])
+
+  images_dir = os.path.join(output_dir, 'samples')
+  log_dir = os.path.join(output_dir, 'logs')
+  model_dir = os.path.join(output_dir, 'models')
+
+  # process on the arguments / options
+  torch.manual_seed(seed)
+  if use_gpu:
+    torch.cuda.manual_seed_all(seed)
+  if torch.cuda.is_available() and not use_gpu:
+    logger.warning("You have a CUDA device, so you should probably run with --use-gpu")
+  bob.io.base.create_directories_safe(images_dir)
+  bob.io.base.create_directories_safe(log_dir)
+  bob.io.base.create_directories_safe(model_dir)
+
+  # ============
+  # === DATA ===
+  # ============
+  
+  # WARNING: the custom transforms below act on the complete sample (i.e. on the labels too);
+  # at some point, dedicated transforms might have to be written.
+  # Note also that 'ToTensor' reshapes the image from HxWxC to CxHxW
+  face_dataset = CasiaDataset(root_dir='/idiap/temp/heusch/data/casia-webface-cropped-64x64-pose-clusters',
+                              frontal_only=False,
+                              transform=transforms.Compose([
+                                RollChannels(),  # bob (CxHxW) to skimage (HxWxC)
+                                ToTensor(),
+                                Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+                              ])
+                             )
+
+  dataloader = torch.utils.data.DataLoader(face_dataset, batch_size=batch_size, shuffle=True)
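+  # each batch yielded by the dataloader is a dictionary with keys
+  # 'image', 'id' and 'pose' (see CasiaDataset.__getitem__)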
+  logger.info("There are {} training images".format(len(face_dataset))) 
+  
+  # ===============
+  # === NETWORK ===
+  # ===============
+
+  generator = ConditionalGAN_generator(noise_dim, conditional_dim)
+  generator.apply(weights_init)
+  logger.info("Generator architecture: {}".format(generator))
+
+  discriminator = ConditionalGAN_discriminator(conditional_dim)
+  discriminator.apply(weights_init)
+  logger.info("Discriminator architecture: {}".format(discriminator))
+
+
+  # ===============
+  # === TRAINER ===
+  # ===============
+  trainer = ConditionalGANTrainer(generator, discriminator, [3, 64, 64],
+                                  batch_size=batch_size, noise_dim=noise_dim,
+                                  conditional_dim=conditional_dim,
+                                  use_gpu=use_gpu, verbosity_level=verbosity_level)
+  trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
diff --git a/setup.py b/setup.py
index 616399401205ac70585c9918183f92e6b4589d6f..1430ca3bc9561143e2bcfe6753390afa585b5fce 100644
--- a/setup.py
+++ b/setup.py
@@ -76,6 +76,7 @@ setup(
       'console_scripts': [
         'train_dcgan_multipie.py = bob.learn.pytorch.scripts.train_dcgan_multipie:main', 
         'train_conditionalgan_multipie.py = bob.learn.pytorch.scripts.train_conditionalgan_multipie:main', 
+        'train_conditionalgan_casia.py = bob.learn.pytorch.scripts.train_conditionalgan_casia:main', 
       ],