diff --git a/bob/paper/ijcb2021_synthetic_dataset/generate/db_generator.py b/bob/paper/ijcb2021_synthetic_dataset/generate/db_generator.py
index 0ed5bf4e9c47b564185e9b0437b43cd8d0d95d43..1aaa5e2283f6b591780ae37a8b8920abe3c4eabb 100644
--- a/bob/paper/ijcb2021_synthetic_dataset/generate/db_generator.py
+++ b/bob/paper/ijcb2021_synthetic_dataset/generate/db_generator.py
@@ -66,7 +66,7 @@ class DBGenerator(object):
         seed : int
             Seed of the random generator for reproducibility
         covariates_analysis_path: str
-            Path to pickled covariaties analysis results (that contains the latent directions)
+            Path to pickled covariates analysis results (that contains the latent directions)
         """
 
         self.generator = generator
@@ -74,8 +74,10 @@ class DBGenerator(object):
         self.cropper = cropper
         self.extractor = extractor
         self.seed = seed
-        with open(covariates_analysis_path, "rb") as f:
+        with open(rc["bob.synface.latent_directions"], "rb") as f:
             self.covariates_analysis = pkl.load(f)
+        with open(rc["bob.synface.latent_directions_ethnicity"], "rb") as f:
+            self.covariates_analysis_ethnicity = pkl.load(f)
 
         self.ict = ict
         self.covariates_scaling = covariates_scaling
@@ -190,6 +192,41 @@ class DBGenerator(object):
         ]
         return faces
 
+    def create_reference_from_w_latent(self, identity, w_latent):
+        """ Create the reference Face for the required identity by using its existing projected w_latent,
+        and computes the face image as well as the face embedding.
+
+        Parameters
+        ----------
+        identity : int
+            Tag of the currently generated identity
+        w_latent : np.ndarray
+            W-space latent vector corresponding to the projection. Shape [latent_dim]
+
+        Returns
+        -------
+        :synthface.generate.face_wrapper.Face:
+    
+        """
+        ref_w = editing.latent_neutralisation(w_latent, self.covariates_analysis)
+
+        # Compute the image and embeddings
+        raw_images = self.generator.run_from_W(ref_w)
+        ref_img = self.image_postprocessor(raw_images[0])
+        ref_embedding = self.get_embedding(ref_img)
+
+        ref_face = Face(
+            z_latent=None,
+            w_latent=ref_w,
+            image=ref_img,
+            embedding=ref_embedding,
+            identity=identity,
+            sample="reference",
+        )
+
+        self.save_faces([ref_face], self.references_hdf5_path)
+        return ref_face
+
     def create_reference(self, identity, compared_embeddings):
         """ Create the reference Face for the required identity by randomly sampling a new latent vector,
         ans computing the face image as well as the face embedding.
@@ -255,7 +292,7 @@ class DBGenerator(object):
         )
         return ref_face, num_candidates
 
-    def augment_identity(self, reference_face):
+    def augment_identity(self, reference_face, ethnicity_labels = None):
         """ Generate all semantic augmentations for the provided reference.
 
         Parameters
@@ -269,9 +306,14 @@ class DBGenerator(object):
             Faces for each semantic augmentations of the reference
         """
         # Compute semantic augmentations
-        w_augmented, labels = editing.latent_augmentation(
-            reference_face.w_latent, self.covariates_analysis, self.covariates_scaling
-        )
+        if ethnicity_labels is not None:
+            w_augmented, labels = editing.latent_augmentation_ethnicity(
+                reference_face.w_latent, self.covariates_analysis_ethnicity, ethnicity_labels, ethnicity_scaling = 1.2
+            )
+        else:
+            w_augmented, labels = editing.latent_augmentation(
+                reference_face.w_latent, self.covariates_analysis, self.covariates_scaling
+            )
 
         # Optionally, also add color variation
         if self.color_variations == 0:
@@ -352,7 +394,7 @@ class DBGenerator(object):
         else:
             return []
 
-    def create_database(self, identities):
+    def create_database(self, identities, ethnicity_labels = None):
         """ Create a full database (reference + augmentations) for the provided identity tags
 
         Parameters
@@ -361,7 +403,7 @@ class DBGenerator(object):
             List of unique tags for each identity that must be created
         """
         self.create_references(identities)
-        self.augment_identities(identities)
+        self.augment_identities(identities, ethnicity_labels)
 
     def create_references(self, identities):
         """ Create all references for the provided identity tags
@@ -402,7 +444,7 @@ class DBGenerator(object):
             self.save_faces([ref_face], self.references_hdf5_path)
             compared_embeddings.append(ref_face.embedding)
 
-    def augment_identities(self, identities):
+    def augment_identities(self, identities, ethnicity_labels = None):
         """ Create all augmentations for the provided identity tags.
         This method assumes the references for those identities have already been 
         created.
@@ -420,7 +462,7 @@ class DBGenerator(object):
 
             # Augment identity
             start_time = time.time()
-            faces = self.augment_identity(ref_face)
+            faces = self.augment_identity(ref_face, ethnicity_labels)
             runtime = time.time() - start_time
 
             # Store stats
@@ -430,3 +472,4 @@ class DBGenerator(object):
                 f.write("{} {:.2f}\n".format(identity, runtime))
 
             self.save_faces(faces, self.augmentations_hdf5_path)
+
diff --git a/bob/paper/ijcb2021_synthetic_dataset/latent/analysis.py b/bob/paper/ijcb2021_synthetic_dataset/latent/analysis.py
index 632311cb829263ddfe8f10d099a08788d94c79b0..ebb5dc845e8b61f0bcf172398a6f22998e48c1fc 100644
--- a/bob/paper/ijcb2021_synthetic_dataset/latent/analysis.py
+++ b/bob/paper/ijcb2021_synthetic_dataset/latent/analysis.py
@@ -6,19 +6,29 @@ from sklearn.svm import LinearSVC
 from sklearn.preprocessing import LabelEncoder
 import numpy as np
 
+# additional import
+import h5py
+
 MAIN_DIR = rc['bob.synface.multipie_projections']
 LAT_DIR = os.path.join(MAIN_DIR, "w_latents")
 FAILURE_FILE = os.path.join(MAIN_DIR, "failure.dat")
 MULTIPIE_DIR = rc["bob.db.multipie.directory"]
 
-VALID_COVARIATES = ['illumination','expression','pose']
+ETHNICITY_KEY_DIR = rc['bob.db.keys']
+
+PROTOCOLS = {
+    "expression": "E_lit",
+    "pose": "P_lit",
+    "illumination": "U"
+}
+
+VALID_CAMERAS = ["08_0", "13_0", "14_0", "05_1", "05_0", "04_1", "19_0"]
 
 POS_TO_CAM = {
     "frontal": ["05_1"],
     "left": ["11_0", "12_0", "09_0", "08_0", "13_0", "14_0"],
     "right": ["05_0", "04_1", "19_0", "20_0", "01_0", "24_0"],
 }
-
 CAM_TO_POS = {cam: pos for pos, cam_list in POS_TO_CAM.items() for cam in cam_list}
 
 POS_TO_FLASH = {
@@ -32,19 +42,17 @@ FLASH_TO_POS = {
     flash: pos for pos, flash_list in POS_TO_FLASH.items() for flash in flash_list
 }
 
-EXPRESSION_TO_RECORDING = { # Recording = SessionNumber_RecordingNumber
-    'neutral': ["01_01", "02_01", "03_01", "04_01", "04_02"],
-    'smile': ["01_02", "03_02"],
-    'surprise': ["02_02"],
-    'squint': ["02_03"],
-    'disgust': ["03_03"],
-    'scream': ["04_03"]
+VPOS_TO_FLASH = {
+    "no_flash": [0],
+    "middle": list(range(1, 14)),
+    "top": list(range(14, 19)),
 }
 
-RECORDING_TO_EXPRESSION = {
-    recording : expr for expr, recording_list in EXPRESSION_TO_RECORDING.items() for recording in recording_list
+FLASH_TO_VPOS = {
+    flash: vpos for vpos, flash_list in VPOS_TO_FLASH.items() for flash in flash_list
 }
 
+
 def filter_failure_cases(file_list):
     with open(FAILURE_FILE, "r") as f:
         failures = [item.rstrip() for item in f.readlines()]
@@ -54,39 +62,25 @@ def filter_failure_cases(file_list):
 
 def get_covariate(file, covariate):
     if covariate == "expression":
-        recording = '_'.join(file.path.split('/')[-1].split('_')[1:3])
-        return RECORDING_TO_EXPRESSION[recording]
+        return file.expression.name
     elif covariate == "pose":
-        camera = file.path.split('/')[4]
-        return CAM_TO_POS[camera]
+        return CAM_TO_POS[file.file_multiview.camera.name]
     elif covariate == "illumination":
-        shot = file.path.split('/')[-1].split('_')[-1]
-        return FLASH_TO_POS[int(shot)]
+        return FLASH_TO_POS[file.file_multiview.shot_id]
     else:
         raise ValueError(
-            "Unknown covariate {} not in {}".format(covariate, VALID_COVARIATES)
+            "Unknown covariate {} not in {}".format(covariate, PROTOCOLS.keys())
         )
 
-def get_file_list(covariate, group='world'):
-    if covariate == 'illumination':
-        db = Multipie(original_directory=MULTIPIE_DIR)
-        raw_files = db.objects(protocol='U', groups=[group])
-    elif covariate == 'expression':
-        from ..config.project.multipie_E import database as db 
-        raw_files = db.objects(groups=[group])
-    elif covariate == 'pose':
-        from ..config.project.multipie_P import database as db
-        raw_files = db.objects(groups=[group])
-    else:
-        raise ValueError('Unknown covariate {}'.format(covariate))
-
-    return filter_failure_cases(raw_files)
 
 def load_latents(covariate, group="world"):
-    assert covariate in VALID_COVARIATES, "Unknown `covariate` {} not in {}".format(
-        covariate, VALID_COVARIATES
+    assert covariate in PROTOCOLS.keys(), "Unknown `covariate` {} not in {}".format(
+        covariate, PROTOCOLS.keys()
+    )
+    db = Multipie(original_directory=MULTIPIE_DIR)
+    files = filter_failure_cases(
+        db.objects(protocol=PROTOCOLS[covariate], groups=[group], cameras=VALID_CAMERAS)
     )
-    files = get_file_list(covariate, group)
     df = pd.DataFrame()
     df["file"] = [f.path for f in files]
     df["latent"] = [f.load(LAT_DIR, ".h5") for f in files]
@@ -95,11 +89,41 @@ def load_latents(covariate, group="world"):
 
     return df
 
+def load_latents_ethnicity(labels):
+    latents = []
+    ethnicities = []
 
-def binary_analysis(train_df, covariate, target_labels, seed=None, **kwargs):
+    df = pd.DataFrame()
+
+    for label in labels:
+        filename = ETHNICITY_KEY_DIR + '/' + label + '_keys.txt'
+        with open(filename) as key_file:
+            keys = key_file.read().splitlines()
+
+        for key in keys:
+            file_path = LAT_DIR + key + '.h5'
+            try:
+                with h5py.File(file_path, "r") as h5_file:
+                    # The latent is read from the first top-level entry of the file.
+                    first_key = list(h5_file.keys())[0]
+                    w_latent = np.array(list(h5_file[first_key]))
+            except Exception:
+                # Best effort: an unreadable projection is reported and skipped,
+                # while KeyboardInterrupt / SystemExit still propagate.
+                print("Could not retrieve latent projection for: {} ".format(file_path))
+                continue
+
+            latents.append(w_latent)
+            ethnicities.append(label)
+
+    df["latent"] = latents
+    df["ethnicity"] = ethnicities
+
+    return df
+
+def binary_analysis(train_df, covariate, target_labels, **kwargs):
     train_df = train_df[train_df[covariate].isin(target_labels)]
-    print(seed)
-    svm = LinearSVC(fit_intercept=False, random_state=seed, **kwargs)
+
+    svm = LinearSVC(fit_intercept=False, **kwargs)
     train_latents = np.stack(train_df["latent"])
     train_labels = np.stack(train_df[covariate])
 
@@ -144,35 +168,45 @@ def multiclass_analysis(train_df, covariate, neutral_label, other_labels, **kwar
         for label in other_labels
     }
 
+def ethnicity_analysis(train_df, ethnicity):
+    if ethnicity == "MEDS":
+        return binary_analysis(train_df, covariate="ethnicity", target_labels=["W", "B"])
+    else:
+        return multiclass_analysis(
+        train_df,
+        covariate="ethnicity",
+        neutral_label="W",
+        other_labels=["B", "A", "H"],
+    )
 
-def pose_analysis(train_df, seed=None):
-    return binary_analysis(train_df, covariate="pose", target_labels=["left", "right"], seed=seed)
+def pose_analysis(train_df):
+    return binary_analysis(train_df, covariate="pose", target_labels=["left", "right"])
 
 
-def illumination_analysis(train_df, seed=None):
+def illumination_analysis(train_df):
     return binary_analysis(
-        train_df, covariate="illumination", target_labels=["left", "right"], seed=seed
+        train_df, covariate="illumination", target_labels=["left", "right"]
     )
 
 
-def expression_analysis(train_df, seed=None):
+def expression_analysis(train_df):
     return multiclass_analysis(
         train_df,
         covariate="expression",
         neutral_label="neutral",
         other_labels=["smile", "scream", "disgust", "squint", "surprise"],
-        seed=seed
     )
 
 
-def covariate_analysis(covariate, train_df, seed=None):
+def covariate_analysis(covariate, train_df):
     if covariate == "expression":
-        return expression_analysis(train_df, seed)
+        return expression_analysis(train_df)
     elif covariate == "pose":
-        return pose_analysis(train_df, seed)
+        return pose_analysis(train_df)
     elif covariate == "illumination":
-        return illumination_analysis(train_df, seed)
+        return illumination_analysis(train_df)
     else:
         raise ValueError(
-            "Unknown covariate {} not in {}".format(covariate, VALID_COVARIATES)
+            "Unknown covariate {} not in {}".format(covariate, PROTOCOLS.keys())
         )
+
diff --git a/bob/paper/ijcb2021_synthetic_dataset/latent/editing.py b/bob/paper/ijcb2021_synthetic_dataset/latent/editing.py
index 9edb81439dc6fd3a55c5df4dc833e27238a7b9c4..80b426df0080cd087ba6d5e44e13df78b3d09432 100644
--- a/bob/paper/ijcb2021_synthetic_dataset/latent/editing.py
+++ b/bob/paper/ijcb2021_synthetic_dataset/latent/editing.py
@@ -27,6 +27,12 @@ def latent_neutralisation(w, covariates_analysis):
 
     return w
 
+def ethnicity_neutralisation(w,covariates_analysis):
+    ethnicity_normal = covariates_analysis["ethnicity"]["normal"]
+    w -= w.dot(ethnicity_normal.T) * ethnicity_normal
+
+    return w
+
 
 def binary_augmentation(w, analysis, num_latents, scaling):
     neg_mean = analysis["neg_stats"]["mean"]
@@ -37,7 +43,9 @@ def binary_augmentation(w, analysis, num_latents, scaling):
     normal = analysis["normal"]
     scales = scaling * np.linspace(-extremum, extremum, num_latents)[:, None]
     return w + scales * normal, scales
-    
+
+def ethnicity_augmentation(w, ethnicity_analysis, num_latents, scaling=3 / 4):
+    return binary_augmentation(w, ethnicity_analysis, num_latents, scaling)
 
 def pose_augmentation(w, pose_analysis, num_latents, scaling=3 / 4):
     return binary_augmentation(w, pose_analysis, num_latents, scaling)
@@ -61,6 +69,52 @@ def expression_augmentation(w, expression_analysis, scaling=3 / 4):
 
     return np.concatenate(new_latents), np.stack(new_covariates)
 
+def latent_augmentation_ethnicity(w, covariates_analysis, labels, ethnicity_scaling = 1):
+    if len(labels) == 2:
+        return latent_augmentation_ethnicity_binary(w, covariates_analysis, ethnicity_scaling)
+    else:
+        return expression_augmentation_ethnicity_multi(w, covariates_analysis, labels, ethnicity_scaling)
+
+def latent_augmentation_ethnicity_binary(w, covariates_analysis, ethnicity_scaling = 1):
+
+    w_ethnicity, ethnicity_labels = ethnicity_augmentation(
+        w,
+        covariates_analysis["ethnicity"],
+        num_latents=8,
+        scaling=ethnicity_scaling,
+    )
+    ethnicity_labels = ["ethnicity_{}".format(item) for item in range(len(ethnicity_labels))]
+
+    latents = np.concatenate([w, w_ethnicity])
+    labels = ["original"] + ethnicity_labels
+
+    return latents, labels
+
+def expression_augmentation_ethnicity_multi(w, covariates_analysis, labels, ethnicity_scaling=1):
+    new_latents = []
+    new_covariates = []
+
+    dic = covariates_analysis["ethnicity"].items()
+    dic_iterator = iter(dic)
+    direction, analysis = next(dic_iterator)
+    new_covariates.append(direction[0])
+
+    normal = analysis["normal"]
+    w_neutralized = w - w.dot(normal.T) * normal
+    w_neutralized += ethnicity_scaling * analysis["neg_stats"]["mean"] * normal
+    new_latents.append(w_neutralized)
+
+    for direction, analysis in covariates_analysis["ethnicity"].items():
+        new_covariates.append(direction[1])
+        normal = analysis["normal"]
+        w_augmented = w_neutralized - w_neutralized.dot(normal.T) * normal
+        w_augmented += ethnicity_scaling * analysis["pos_stats"]["mean"] * normal
+        new_latents.append(w_augmented)
+
+    latents = np.concatenate([w, np.concatenate(new_latents)])
+    labels = np.stack(["original"]+new_covariates)
+    return latents, labels
+
 
 def latent_augmentation(w, covariates_analysis, covariates_scaling):
     # Pose augmentation
@@ -93,3 +147,4 @@ def latent_augmentation(w, covariates_analysis, covariates_scaling):
     labels = ["original"] + pose_labels + illumination_labels + expression_labels
 
     return latents, labels
+
diff --git a/bob/paper/ijcb2021_synthetic_dataset/script/compute_latent_directions.py b/bob/paper/ijcb2021_synthetic_dataset/script/compute_latent_directions.py
index b5d4f30ddb178811512498cb470596a107115762..8d4206a171f46aaef44ab7be6c613f809919d397 100644
--- a/bob/paper/ijcb2021_synthetic_dataset/script/compute_latent_directions.py
+++ b/bob/paper/ijcb2021_synthetic_dataset/script/compute_latent_directions.py
@@ -8,6 +8,11 @@ from bob.extension.scripts.click_helper import (
 import pickle
 from ..latent import analysis
 
+ETHNICITY_LABELS = {
+    "Morph": ["A", "B", "H", "W"],
+    "MEDS": ["H", "W"]
+}
+
 @click.command(
     cls=ConfigCommand,
     help="Compute and save latent directions starting from precomputed latent projections of MultiPIE",
@@ -36,27 +41,39 @@ from ..latent import analysis
     help="`Activate flag to overwrite computed directions if the file already exist",
 )
 @click.option(
-    "--seed",
-    "-s",
-    type=int,
+    "--ethnicity",
+    "-e",
+    type=click.Choice(["Morph", "MEDS"]),
+    help="Database used for ethnicity analysis",
     cls=ResourceOption,
-    help="Seed to control stochasticity during the SVM fitting.",
 )
+
+
 def compute_latents(projections_dir=rc['bob.synface.multipie_projections'],
                     output_path=rc['bob.synface.latent_directions'],
                     force=False,
-                    seed=None,
+                    ethnicity=None,
                     **kwargs):
 
     if (not os.path.exists(output_path)) or force:
         out = {}
-        for covariate in ['illumination','expression','pose']:
-            print('Analyzing {} covariate ...'.format(covariate))
+        if ethnicity is not None:
+            labels = ETHNICITY_LABELS[ethnicity]
             print('    Loading projected latents ...')
-            df = analysis.load_latents(covariate)
+            df = analysis.load_latents_ethnicity(labels)
             print('    SVM fitting ...')
-            result = analysis.covariate_analysis(covariate, df, seed)
-            out[covariate] = result
+            result = analysis.ethnicity_analysis(df, ethnicity)
+            out["ethnicity"] = result
+
+        else:
+
+            for covariate in ['illumination','expression','pose']:
+                print('Analyzing {} covariate ...'.format(covariate))
+                print('    Loading projected latents ...')
+                df = analysis.load_latents(covariate)
+                print('    SVM fitting ...')
+                result = analysis.covariate_analysis(covariate, df)
+                out[covariate] = result
 
         with open(output_path, 'wb') as f:
             pickle.dump(out, f)
@@ -65,4 +82,4 @@ def compute_latents(projections_dir=rc['bob.synface.multipie_projections'],
         print('Computed directions are already found under {}. Use the --force flag to overwrite them.'.format(output_path))
 
 if __name__ == "__main__":
-    compute_latents()
\ No newline at end of file
+    compute_latents()
diff --git a/bob/paper/ijcb2021_synthetic_dataset/script/generate_db.py b/bob/paper/ijcb2021_synthetic_dataset/script/generate_db.py
index 9ae0e01ac2869064e120a40dc264ad2041194820..c6b96d8a124f946f95b9e489e95c1cfd4fba2da0 100644
--- a/bob/paper/ijcb2021_synthetic_dataset/script/generate_db.py
+++ b/bob/paper/ijcb2021_synthetic_dataset/script/generate_db.py
@@ -4,7 +4,9 @@ from bob.bio.face_ongoing.configs.baselines.msceleb.inception_resnet_v2.centerlo
 from ..generate.db_generator import DBGenerator
 from ..stylegan2.generator import StyleGAN2Generator
 from ..stylegan2.dnnlib import tflib
+from ..latent import analysis, editing
 from ..utils import get_task
+import numpy as np
 
 import pickle as pkl
 from scipy.spatial.distance import cdist
@@ -15,6 +17,10 @@ import time
 import click
 from bob.extension.scripts.click_helper import ConfigCommand, ResourceOption
 
+ETHNICITY_LABELS = {
+    "Morph": ["A", "B", "H", "W"],
+    "MEDS": ["H", "W"]
+}
 
 def get_postprocessing_fn():
     SG2_REYE_POS = (480, 380)
@@ -59,11 +65,10 @@ def initialization(
     task_id,
     seed,
 ):
-    # Fix randomness
     tflib.init_tf({"rnd.np_random_seed": seed, "rnd.tf_random_seed": "auto"})
 
     cropper = get_cropper()
-    generator = StyleGAN2Generator(randomize_noise=False, batch_size=4)
+    generator = StyleGAN2Generator(randomize_noise=True, batch_size=4)
     image_postprocessor = get_postprocessing_fn()
 
     image_dir = os.path.join(output_dir, "image")
@@ -101,7 +106,6 @@ def initialization(
 
 @click.command(
     cls=ConfigCommand,
-    entry_point_group='generation_config',
     help="Generate a synthetic database using semantic augmentation in the latent space.",
 )
 @click.option(
@@ -153,7 +157,6 @@ def initialization(
     "-o",
     type=str,
     required=True,
-    default=rc['bob.synface.synthetic_datasets'],
     help="Root of the output directory tree",
     cls=ResourceOption,
 )
@@ -171,6 +174,14 @@ def initialization(
     help="Subtask to execute",
     cls=ResourceOption,
 )
+@click.option(
+    "--ethnicity",
+    "-e",
+    type=click.Choice(["Morph", "MEDS"]),
+    help="Database used for ethnicity analysis",
+    cls=ResourceOption,
+)
+
 def db_gen(
     num_identities,
     ict,
@@ -181,6 +192,7 @@ def db_gen(
     illumination_scaling=1.0,
     expression_scaling=1.0,
     color_variations=0,
+    ethnicity="Morph",
     **kwargs
 ):
     current_task, num_tasks = get_task()
@@ -197,13 +209,17 @@ def db_gen(
 
     identities = list(range(num_identities))[current_task::num_tasks]
 
+    ethnicity_labels = ETHNICITY_LABELS[ethnicity]  if ethnicity is not None else None
+    print("Generating labels for : {}".format(ethnicity_labels))
+    
     if task == "references":
         db_generator.create_references(identities)
     elif task == "augmentations":
-        db_generator.augment_identities(identities)
+        db_generator.augment_identities(identities, ethnicity_labels)
     else:
-        db_generator.create_database(identities)
+        db_generator.create_database(identities, ethnicity_labels)
 
 
 if __name__ == "__main__":
     db_gen()
+
diff --git a/bob/paper/ijcb2021_synthetic_dataset/script/project_db.py b/bob/paper/ijcb2021_synthetic_dataset/script/project_db.py
index e550cc64c6fc95e6e4edf1c2251ecbcdc84627ad..97dbadee7d9384aa9d1b1b4adebc60e0d5adad10 100644
--- a/bob/paper/ijcb2021_synthetic_dataset/script/project_db.py
+++ b/bob/paper/ijcb2021_synthetic_dataset/script/project_db.py
@@ -17,10 +17,20 @@ from bob.extension.scripts.click_helper import (
 
 from ..utils import fix_randomness
 
+
+# additional imports
+from PIL import Image
+import numpy as np
+
 CROP_DIR = "img_aligned"
 PROJ_DIR = "projected"
 LAT_DIR = "w_latents"
 
+ETHNICITY_LABELS = {
+    "Morph": ["A", "B", "H", "W"],
+    "MEDS": ["H", "W"]
+}
+
 
 def is_processed(f, output_dir):
     return os.path.exists(f.make_path(os.path.join(output_dir, LAT_DIR), ".h5"))
@@ -28,16 +38,14 @@ def is_processed(f, output_dir):
 
 @click.command(
     cls=ConfigCommand,
-    entry_point_group='projection_config',
     help="Project a bob.bio.base.database.BioDatabase into the StyleGAN2 latent space.",
 )
 @click.option(
     "--database",
     "-d",
-    entry_point_group='projected_db',
-    required=True,
+    entry_point_group="bob.bio.database",
     cls=ResourceOption,
-    help="Which bob.bio.database to project (available : 'multipie_U', 'multipie_E', 'multipie_P')"
+    help="BioDatabase of face images that should be projected",
 )
 @click.option(
     "--output-dir",
@@ -82,14 +90,23 @@ def is_processed(f, output_dir):
     cls=ResourceOption,
     help="Seed to control stochasticity during projection.",
 )
+@click.option(
+    "--ethnicity",
+    "-e",
+    type=click.Choice(['Morph', 'MEDS']),
+    help="Database used for ethnicity analysis",
+    cls=ResourceOption,
+)
+
 def project(
-    database,
-    output_dir,
+    output_dir= rc['bob.synface.multipie_projections'],
     group="world",
     num_steps=1000,
+    database=None,
     checkpoint=False,
     force=False,
     seed=None,
+    ethnicity=None,
     **kwargs
 ):
     failure_file_path = os.path.join(output_dir, "failure.dat")
@@ -102,57 +119,114 @@ def project(
 
     task_id, num_tasks = get_task()
 
-    if group == "world":
-        files = database.training_files()
-    else:
-        files = database.test_files(group)
-
-    if not force:
-        files = [f for f in files if not is_processed(f, output_dir)]
-
-        if os.path.exists(failure_file_path):
-            with open(failure_file_path, "r") as failure_file:
-                fail_cases = [item.rstrip() for item in failure_file]
+    #database_path = "/idiap/resource/database/MEDS/MEDS_II/data" # rc it, chose according to input
+    if ethnicity is not None:
+        database_path = rc['bob.db.' + ethnicity]
+        labels = ETHNICITY_LABELS[ethnicity]
+        for label in labels:
+            filename = rc['bob.db.keys'] + '/' + label + '_keys.txt'
+            output_dir_label = output_dir + "/" + label
+            with open(filename) as f:
+                subfiles = f.read().splitlines()
+
+            for i, f in enumerate(subfiles):  # one image key per line of the label's key file
+                print("{} : {}".format(i, f))
+
+                image = Image.open(database_path + '/' + f + ".JPG")
+                basewidth = 1024  # target width in pixels; height follows the aspect ratio
+                wpercent = (basewidth/float(image.size[0]))
+                hsize = int((float(image.size[1])*float(wpercent)))
+                image = image.resize((basewidth,hsize), Image.LANCZOS)  # ANTIALIAS alias was removed in Pillow 10
+                image = np.array(image)
+                image = np.moveaxis(image, -1, 0)  # last axis moved to the front before cropping
+
+                # A crop failure is recorded in the failure file and only this image is skipped.
+                try:
+                    cropped = cropper(image)
+                    print("Crop ok {} !".format(f))
+                except Exception:
+                    with open(failure_file_path, "a") as failure_file:
+                        failure_file.write(f + "\n")
+                    print("Failure to crop {} !".format(f))
+                    continue
+
+                print(os.path.join(output_dir, CROP_DIR) + f +".png")
+
+                if checkpoint:
+                    bob.io.base.save(
+                        cropped,
+                        os.path.join(output_dir, CROP_DIR) + f +".png",
+                        create_directories=True
+                    )
+
+                projected = projector(cropped)
+
+                if checkpoint:
+                    bob.io.base.save(
+                        projected.image,
+                        os.path.join(output_dir, PROJ_DIR)+ f + ".png",
+                        create_directories=True
+                    )
+
+                bob.io.base.save(
+                    projected.w_latent,
+                    os.path.join(output_dir, LAT_DIR)+ f+ ".h5",
+                    create_directories=True
+                )
 
-            files = [f for f in files if f.path not in fail_cases]
-
-    subfiles = files[task_id :: num_tasks]
-    print("{} files remaining. Handling {} of them".format(len(files), len(subfiles)))
-
-    for i, f in enumerate(subfiles):
-        print("{} : {}".format(i, f))
-        image = f.load(database.original_directory, database.original_extension)
-
-        try:
-            cropped = cropper(image)
-        except:
-            with open(failure_file_path, "a") as failure_file:
-                failure_file.write(f.path + "\n")
-            print("Failure to crop {} !".format(f))
-            continue
-
-        if checkpoint:
-            bob.io.base.save(
-                cropped,
-                f.make_path(os.path.join(output_dir, CROP_DIR), ".png"),
-                create_directories=True,
-            )
-
-        projected = projector(cropped)
-
-        if checkpoint:
-            bob.io.base.save(
-                projected.image,
-                f.make_path(os.path.join(output_dir, PROJ_DIR), ".png"),
-                create_directories=True,
-            )
-
-        bob.io.base.save(
-            projected.w_latent,
-            f.make_path(os.path.join(output_dir, LAT_DIR), ".h5"),
-            create_directories=True,
-        )
+    else:
+        if group == "world":
+            files = database.training_files()
+        else:
+            files = database.test_files(group)
+
+        if not force:
+            files = [f for f in files if not is_processed(f, output_dir)]
+
+            if os.path.exists(failure_file_path):
+                with open(failure_file_path, "r") as failure_file:
+                    fail_cases = [item.rstrip() for item in failure_file]
+
+                files = [f for f in files if f.path not in fail_cases]
+
+        subfiles = files[task_id :: num_tasks]
+        print("{} files remaining. Handling {} of them".format(len(files), len(subfiles)))
+
+        for i, f in enumerate(subfiles):
+            print("{} : {}".format(i, f))
+            image = f.load(database.original_directory, database.original_extension)
+
+            try:
+                cropped = cropper(image)
+            except Exception:
+                # Record the failing file and move on to the next one.
+                with open(failure_file_path, "a") as failure_file:
+                    failure_file.write(f.path + "\n")
+                print("Failure to crop {} !".format(f))
+                continue
+            if checkpoint:
+                bob.io.base.save(
+                    cropped,
+                    f.make_path(os.path.join(output_dir, CROP_DIR), ".png"),
+                    create_directories=True,
+                )
+
+            projected = projector(cropped)
+
+            if checkpoint:
+                bob.io.base.save(
+                    projected.image,
+                    f.make_path(os.path.join(output_dir, PROJ_DIR), ".png"),
+                    create_directories=True,
+                )
+
+            bob.io.base.save(
+                projected.w_latent,
+                f.make_path(os.path.join(output_dir, LAT_DIR), ".h5"),
+                create_directories=True,
+            )
 
 
 if __name__ == "__main__":
     project()
+