diff --git a/bob/paper/ijcb2021_synthetic_dataset/generate/db_generator.py b/bob/paper/ijcb2021_synthetic_dataset/generate/db_generator.py index 0ed5bf4e9c47b564185e9b0437b43cd8d0d95d43..1aaa5e2283f6b591780ae37a8b8920abe3c4eabb 100644 --- a/bob/paper/ijcb2021_synthetic_dataset/generate/db_generator.py +++ b/bob/paper/ijcb2021_synthetic_dataset/generate/db_generator.py @@ -66,7 +66,7 @@ class DBGenerator(object): seed : int Seed of the random generator for reproducibility covariates_analysis_path: str - Path to pickled covariaties analysis results (that contains the latent directions) + Path to pickled covariates analysis results (that contains the latent directions) """ self.generator = generator @@ -74,8 +74,10 @@ class DBGenerator(object): self.cropper = cropper self.extractor = extractor self.seed = seed - with open(covariates_analysis_path, "rb") as f: + with open(rc["bob.synface.latent_directions"], "rb") as f: self.covariates_analysis = pkl.load(f) + with open(rc["bob.synface.latent_directions_ethnicity"], "rb") as f: + self.covariates_analysis_ethnicity = pkl.load(f) self.ict = ict self.covariates_scaling = covariates_scaling @@ -190,6 +192,41 @@ class DBGenerator(object): ] return faces + def create_reference_from_w_latent(self, identity, w_latent): + """ Create the reference Face for the required identity by using it's exisisting projected w_latent, + and computes the face image as well as the face embedding. + + Parameters + ---------- + identity : int + Tag of the currently generated identity + w_latent : np.arrays + W-space latent vector corresponding to the projection. Shape [latent_dim] + + Returns + ------- + :synthface.generate.face_wrapper.Face: + + """ + ref_w = editing.latent_neutralisation(w_latent, self.covariates_analysis) + + # Compute the image and embeddings + raw_images = self.generator.run_from_W(ref_w) + ref_img = self.image_postprocessor(raw_images[0]) + ref_embedding = self.get_embedding(ref_img) + + ref_face = Face( + z_latent=None, + w_latent=ref_w, + image=ref_img, + embedding=ref_embedding, + identity=identity, + sample="reference", + ) + + self.save_faces([ref_face], self.references_hdf5_path) + return ref_face + def create_reference(self, identity, compared_embeddings): """ Create the reference Face for the required identity by randomly sampling a new latent vector, ans computing the face image as well as the face embedding. @@ -255,7 +292,7 @@ class DBGenerator(object): ) return ref_face, num_candidates - def augment_identity(self, reference_face): + def augment_identity(self, reference_face, ethnicity_labels = None): """ Generate all semantic augmentations for the provided reference. Parameters @@ -269,9 +306,14 @@ class DBGenerator(object): Faces for each semantic augmentations of the reference """ # Compute semantic augmentations - w_augmented, labels = editing.latent_augmentation( - reference_face.w_latent, self.covariates_analysis, self.covariates_scaling - ) + if ethnicity_labels is not None: + w_augmented, labels = editing.latent_augmentation_ethnicity( + reference_face.w_latent, self.covariates_analysis_ethnicity, ethnicity_labels, ethnicity_scaling = 1.2 + ) + else: + w_augmented, labels = editing.latent_augmentation( + reference_face.w_latent, self.covariates_analysis, self.covariates_scaling + ) # Optionally, also add color variation if self.color_variations == 0: @@ -352,7 +394,7 @@ class DBGenerator(object): else: return [] - def create_database(self, identities): + def create_database(self, identities, ethnicity_labels = None): """ Create a full database (reference + augmentations) for the provided identity tags Parameters @@ -361,7 +403,7 @@ class DBGenerator(object): List of unique tags for each identity that must be created """ self.create_references(identities) - self.augment_identities(identities) + self.augment_identities(identities, ethnicity_labels) def create_references(self, identities): """ Create all references for the provided identity tags @@ -402,7 +444,7 @@ class DBGenerator(object): self.save_faces([ref_face], self.references_hdf5_path) compared_embeddings.append(ref_face.embedding) - def augment_identities(self, identities): + def augment_identities(self, identities, ethnicity_labels = None): """ Create all augmentations for the provided identity tags. This method assumes the references for those identities have already been created. @@ -420,7 +462,7 @@ class DBGenerator(object): # Augment identity start_time = time.time() - faces = self.augment_identity(ref_face) + faces = self.augment_identity(ref_face, ethnicity_labels) runtime = time.time() - start_time # Store stats @@ -430,3 +472,4 @@ class DBGenerator(object): f.write("{} {:.2f}\n".format(identity, runtime)) self.save_faces(faces, self.augmentations_hdf5_path) + diff --git a/bob/paper/ijcb2021_synthetic_dataset/latent/analysis.py b/bob/paper/ijcb2021_synthetic_dataset/latent/analysis.py index 632311cb829263ddfe8f10d099a08788d94c79b0..ebb5dc845e8b61f0bcf172398a6f22998e48c1fc 100644 --- a/bob/paper/ijcb2021_synthetic_dataset/latent/analysis.py +++ b/bob/paper/ijcb2021_synthetic_dataset/latent/analysis.py @@ -6,19 +6,29 @@ from sklearn.svm import LinearSVC from sklearn.preprocessing import LabelEncoder import numpy as np +# additional import +import h5py + MAIN_DIR = rc['bob.synface.multipie_projections'] LAT_DIR = os.path.join(MAIN_DIR, "w_latents") FAILURE_FILE = os.path.join(MAIN_DIR, "failure.dat") MULTIPIE_DIR = rc["bob.db.multipie.directory"] -VALID_COVARIATES = ['illumination','expression','pose'] +ETHNICITY_KEY_DIR = rc['bob.db.keys'] + +PROTOCOLS = { + "expression": "E_lit", + "pose": "P_lit", + "illumination": "U" +} + +VALID_CAMERAS = ["08_0", "13_0", "14_0", "05_1", "05_0", "04_1", "19_0"] POS_TO_CAM = { "frontal": ["05_1"], "left": ["11_0", "12_0", "09_0", "08_0", "13_0", "14_0"], "right": ["05_0", "04_1", "19_0", "20_0", "01_0", "24_0"], } - CAM_TO_POS = {cam: pos for pos, cam_list in POS_TO_CAM.items() for cam in cam_list} POS_TO_FLASH = { @@ -32,19 +42,17 @@ FLASH_TO_POS = { flash: pos for pos, flash_list in POS_TO_FLASH.items() for flash in flash_list } -EXPRESSION_TO_RECORDING = { # Recording = SessionNumber_RecordingNumber - 'neutral': ["01_01", "02_01", "03_01", "04_01", "04_02"], - 'smile': ["01_02", "03_02"], - 'surprise': ["02_02"], - 'squint': ["02_03"], - 'disgust': ["03_03"], - 'scream': ["04_03"] +VPOS_TO_FLASH = { + "no_flash": [0], + "middle": list(range(1, 14)), + "top": list(range(14, 19)), } -RECORDING_TO_EXPRESSION = { - recording : expr for expr, recording_list in EXPRESSION_TO_RECORDING.items() for recording in recording_list +FLASH_TO_VPOS = { + flash: vpos for vpos, flash_list in VPOS_TO_FLASH.items() for flash in flash_list } + def filter_failure_cases(file_list): with open(FAILURE_FILE, "r") as f: failures = [item.rstrip() for item in f.readlines()] @@ -54,39 +62,25 @@ def filter_failure_cases(file_list): def get_covariate(file, covariate): if covariate == "expression": - recording = '_'.join(file.path.split('/')[-1].split('_')[1:3]) - return RECORDING_TO_EXPRESSION[recording] + return file.expression.name elif covariate == "pose": - camera = file.path.split('/')[4] - return CAM_TO_POS[camera] + return CAM_TO_POS[file.file_multiview.camera.name] elif covariate == "illumination": - shot = file.path.split('/')[-1].split('_')[-1] - return FLASH_TO_POS[int(shot)] + return FLASH_TO_POS[file.file_multiview.shot_id] else: raise ValueError( - "Unknown covariate {} not in {}".format(covariate, VALID_COVARIATES) + "Unknown covariate {} not in {}".format(covariate, PROTOCOLS.keys()) ) -def get_file_list(covariate, group='world'): - if covariate == 'illumination': - db = Multipie(original_directory=MULTIPIE_DIR) - raw_files = db.objects(protocol='U', groups=[group]) - elif covariate == 'expression': - from ..config.project.multipie_E import database as db - raw_files = db.objects(groups=[group]) - elif covariate == 'pose': - from ..config.project.multipie_P import database as db - raw_files = db.objects(groups=[group]) - else: - raise ValueError('Unknown covariate {}'.format(covariate)) - - return filter_failure_cases(raw_files) def load_latents(covariate, group="world"): - assert covariate in VALID_COVARIATES, "Unknown `covariate` {} not in {}".format( - covariate, VALID_COVARIATES + assert covariate in PROTOCOLS.keys(), "Unknown `covariate` {} not in {}".format( + covariate, PROTOCOLS.keys() + ) + db = Multipie(original_directory=MULTIPIE_DIR) + files = filter_failure_cases( + db.objects(protocol=PROTOCOLS[covariate], groups=[group], cameras=VALID_CAMERAS) ) - files = get_file_list(covariate, group) df = pd.DataFrame() df["file"] = [f.path for f in files] df["latent"] = [f.load(LAT_DIR, ".h5") for f in files] @@ -95,11 +89,41 @@ def load_latents(covariate, group="world"): return df +def load_latents_ethnicity(labels): + latents = [] + ethnicities = [] -def binary_analysis(train_df, covariate, target_labels, seed=None, **kwargs): + df = pd.DataFrame() + + for label in labels: + filename = ETHNICITY_KEY_DIR + '/' + label + '_keys.txt' + with open(filename) as f: + files = f.read().splitlines() + + for f in files: + file_path = LAT_DIR + f + '.h5' + try: + with h5py.File(file_path, "r") as f: + a_group_key = list(f.keys())[0] + + # Get the data + data = list(f[a_group_key]) + w_latent = np.array(data) + latents.append(w_latent) + ethnicities.append(label) + + except: + print("Could not retrieve latent projection for: {} ".format(file_path)) + + df["latent"] = latents + df["ethnicity"] = ethnicities + + return df + +def binary_analysis(train_df, covariate, target_labels, **kwargs): train_df = train_df[train_df[covariate].isin(target_labels)] - print(seed) - svm = LinearSVC(fit_intercept=False, random_state=seed, **kwargs) + + svm = LinearSVC(fit_intercept=False, **kwargs) train_latents = np.stack(train_df["latent"]) train_labels = np.stack(train_df[covariate]) @@ -144,35 +168,45 @@ def multiclass_analysis(train_df, covariate, neutral_label, other_labels, **kwar for label in other_labels } +def ethnicity_analysis(train_df, ethnicity): + if ethnicity == "MEDS": + return binary_analysis(train_df, covariate="ethnicity", target_labels=["W", "B"]) + else: + return multiclass_analysis( + train_df, + covariate="ethnicity", + neutral_label="W", + other_labels=["B", "A", "H"], + ) -def pose_analysis(train_df, seed=None): - return binary_analysis(train_df, covariate="pose", target_labels=["left", "right"], seed=seed) +def pose_analysis(train_df): + return binary_analysis(train_df, covariate="pose", target_labels=["left", "right"]) -def illumination_analysis(train_df, seed=None): +def illumination_analysis(train_df): return binary_analysis( - train_df, covariate="illumination", target_labels=["left", "right"], seed=seed + train_df, covariate="illumination", target_labels=["left", "right"] ) -def expression_analysis(train_df, seed=None): +def expression_analysis(train_df): return multiclass_analysis( train_df, covariate="expression", neutral_label="neutral", other_labels=["smile", "scream", "disgust", "squint", "surprise"], - seed=seed ) -def covariate_analysis(covariate, train_df, seed=None): +def covariate_analysis(covariate, train_df): if covariate == "expression": - return expression_analysis(train_df, seed) + return expression_analysis(train_df) elif covariate == "pose": - return pose_analysis(train_df, seed) + return pose_analysis(train_df) elif covariate == "illumination": - return illumination_analysis(train_df, seed) + return illumination_analysis(train_df) else: raise ValueError( - "Unknown covariate {} not in {}".format(covariate, VALID_COVARIATES) + "Unknown covariate {} not in {}".format(covariate, PROTOCOLS.keys()) ) + diff --git a/bob/paper/ijcb2021_synthetic_dataset/latent/editing.py b/bob/paper/ijcb2021_synthetic_dataset/latent/editing.py index 9edb81439dc6fd3a55c5df4dc833e27238a7b9c4..80b426df0080cd087ba6d5e44e13df78b3d09432 100644 --- a/bob/paper/ijcb2021_synthetic_dataset/latent/editing.py +++ b/bob/paper/ijcb2021_synthetic_dataset/latent/editing.py @@ -27,6 +27,12 @@ def latent_neutralisation(w, covariates_analysis): return w +def ethnicity_neutralisation(w,covariates_analysis): + ethnicity_normal = covariates_analysis["ethnicity"]["normal"] + w -= w.dot(ethnicity_normal.T) * ethnicity_normal + + return w + def binary_augmentation(w, analysis, num_latents, scaling): neg_mean = analysis["neg_stats"]["mean"] @@ -37,7 +43,9 @@ def binary_augmentation(w, analysis, num_latents, scaling): normal = analysis["normal"] scales = scaling * np.linspace(-extremum, extremum, num_latents)[:, None] return w + scales * normal, scales - + +def ethnicity_augmentation(w, ethnicity_analysis, num_latents, scaling=3 / 4): + return binary_augmentation(w, ethnicity_analysis, num_latents, scaling) def pose_augmentation(w, pose_analysis, num_latents, scaling=3 / 4): return binary_augmentation(w, pose_analysis, num_latents, scaling) @@ -61,6 +69,52 @@ def expression_augmentation(w, expression_analysis, scaling=3 / 4): return np.concatenate(new_latents), np.stack(new_covariates) +def latent_augmentation_ethnicity(w, covariates_analysis, labels, ethnicity_scaling = 1): + if len(labels) == 2: + return latent_augmentation_ethnicity_binary(w, covariates_analysis, ethnicity_scaling) + else: + return expression_augmentation_ethnicity_multi(w, covariates_analysis, labels, ethnicity_scaling) + +def latent_augmentation_ethnicity_binary(w, covariates_analysis, ethnicity_scaling = 1): + + w_ethnicity, ethnicity_labels = ethnicity_augmentation( + w, + covariates_analysis["ethnicity"], + num_latents=8, + scaling=ethnicity_scaling, + ) + ethnicity_labels = ["ethnicity_{}".format(item) for item in range(len(ethnicity_labels))] + + latents = np.concatenate([w, w_ethnicity]) + labels = ["original"] + ethnicity_labels + + return latents, labels + +def expression_augmentation_ethnicity_multi(w, covariates_analysis, labels, ethnicity_scaling=1): + new_latents = [] + new_covariates = [] + + dic = covariates_analysis["ethnicity"].items() + dic_iterator = iter(dic) + direction, analysis = next(dic_iterator) + new_covariates.append(direction[0]) + + normal = analysis["normal"] + w_neutralized = w - w.dot(normal.T) * normal + w_neutralized += ethnicity_scaling * analysis["neg_stats"]["mean"] * normal + new_latents.append(w_neutralized) + + for direction, analysis in covariates_analysis["ethnicity"].items(): + new_covariates.append(direction[1]) + normal = analysis["normal"] + w_augmented = w_neutralized - w_neutralized.dot(normal.T) * normal + w_augmented += ethnicity_scaling * analysis["pos_stats"]["mean"] * normal + new_latents.append(w_augmented) + + latents = np.concatenate([w, np.concatenate(new_latents)]) + labels = np.stack(["original"]+new_covariates) + return latents, labels + def latent_augmentation(w, covariates_analysis, covariates_scaling): # Pose augmentation @@ -93,3 +147,4 @@ def latent_augmentation(w, covariates_analysis, covariates_scaling): labels = ["original"] + pose_labels + illumination_labels + expression_labels return latents, labels + diff --git a/bob/paper/ijcb2021_synthetic_dataset/script/compute_latent_directions.py b/bob/paper/ijcb2021_synthetic_dataset/script/compute_latent_directions.py index b5d4f30ddb178811512498cb470596a107115762..8d4206a171f46aaef44ab7be6c613f809919d397 100644 --- a/bob/paper/ijcb2021_synthetic_dataset/script/compute_latent_directions.py +++ b/bob/paper/ijcb2021_synthetic_dataset/script/compute_latent_directions.py @@ -8,6 +8,11 @@ from bob.extension.scripts.click_helper import ( import pickle from ..latent import analysis +ETHNICITY_LABELS = { + "Morph": ["A", "B", "H", "W"], + "MEDS": ["H", "W"] +} + @click.command( cls=ConfigCommand, help="Compute and save latent directions starting from precomputed latent projections of MultiPIE", @@ -36,27 +41,39 @@ from ..latent import analysis help="`Activate flag to overwrite computed directions if the file already exist", ) @click.option( - "--seed", - "-s", - type=int, + "--ethnicity", + "-e", + type=click.Choice(["Morph", "MEDS"]), + help="Database used for ethnicity analysis", cls=ResourceOption, - help="Seed to control stochasticity during the SVM fitting.", ) + + def compute_latents(projections_dir=rc['bob.synface.multipie_projections'], output_path=rc['bob.synface.latent_directions'], force=False, - seed=None, + ethnicity=None, **kwargs): if (not os.path.exists(output_path)) or force: out = {} - for covariate in ['illumination','expression','pose']: - print('Analyzing {} covariate ...'.format(covariate)) + if ethnicity is not None: + labels = ETHNICITY_LABELS[ethnicity] print(' Loading projected latents ...') - df = analysis.load_latents(covariate) + df = analysis.load_latents_ethnicity(labels) print(' SVM fitting ...') - result = analysis.covariate_analysis(covariate, df, seed) - out[covariate] = result + result = analysis.ethnicity_analysis(df, ethnicity) + out["ethnicity"] = result + + else: + + for covariate in ['illumination','expression','pose']: + print('Analyzing {} covariate ...'.format(covariate)) + print(' Loading projected latents ...') + df = analysis.load_latents(covariate) + print(' SVM fitting ...') + result = analysis.covariate_analysis(covariate, df) + out[covariate] = result with open(output_path, 'wb') as f: pickle.dump(out, f) @@ -65,4 +82,4 @@ def compute_latents(projections_dir=rc['bob.synface.multipie_projections'], print('Computed directions are already found under {}. Use the --force flag to overwrite them.'.format(output_path)) if __name__ == "__main__": - compute_latents() \ No newline at end of file + compute_latents() diff --git a/bob/paper/ijcb2021_synthetic_dataset/script/generate_db.py b/bob/paper/ijcb2021_synthetic_dataset/script/generate_db.py index 9ae0e01ac2869064e120a40dc264ad2041194820..c6b96d8a124f946f95b9e489e95c1cfd4fba2da0 100644 --- a/bob/paper/ijcb2021_synthetic_dataset/script/generate_db.py +++ b/bob/paper/ijcb2021_synthetic_dataset/script/generate_db.py @@ -4,7 +4,9 @@ from bob.bio.face_ongoing.configs.baselines.msceleb.inception_resnet_v2.centerlo from ..generate.db_generator import DBGenerator from ..stylegan2.generator import StyleGAN2Generator from ..stylegan2.dnnlib import tflib +from ..latent import analysis, editing from ..utils import get_task +import numpy as np import pickle as pkl from scipy.spatial.distance import cdist @@ -15,6 +17,10 @@ import time import click from bob.extension.scripts.click_helper import ConfigCommand, ResourceOption +ETHNICITY_LABELS = { + "Morph": ["A", "B", "H", "W"], + "MEDS": ["H", "W"] +} def get_postprocessing_fn(): SG2_REYE_POS = (480, 380) @@ -59,11 +65,10 @@ def initialization( task_id, seed, ): - # Fix randomness tflib.init_tf({"rnd.np_random_seed": seed, "rnd.tf_random_seed": "auto"}) cropper = get_cropper() - generator = StyleGAN2Generator(randomize_noise=False, batch_size=4) + generator = StyleGAN2Generator(randomize_noise=True, batch_size=4) image_postprocessor = get_postprocessing_fn() image_dir = os.path.join(output_dir, "image") @@ -101,7 +106,6 @@ def initialization( @click.command( cls=ConfigCommand, - entry_point_group='generation_config', help="Generate a synthetic database using semantic augmentation in the latent space.", ) @click.option( @@ -153,7 +157,6 @@ def initialization( "-o", type=str, required=True, - default=rc['bob.synface.synthetic_datasets'], help="Root of the output directory tree", cls=ResourceOption, ) @@ -171,6 +174,14 @@ def initialization( help="Subtask to execute", cls=ResourceOption, ) +@click.option( + "--ethnicity", + "-e", + type=click.Choice(["Morph", "MEDS"]), + help="Database used for ethnicity analysis", + cls=ResourceOption, +) + def db_gen( num_identities, ict, @@ -181,6 +192,7 @@ def db_gen( illumination_scaling=1.0, expression_scaling=1.0, color_variations=0, + ethnicity="Morph", **kwargs ): current_task, num_tasks = get_task() @@ -197,13 +209,17 @@ def db_gen( identities = list(range(num_identities))[current_task::num_tasks] + ethnicity_labels = ETHNICITY_LABELS[ethnicity] if ethnicity is not None else None + print("Generating labels for : {}".format(ethnicity_labels)) + if task == "references": db_generator.create_references(identities) elif task == "augmentations": - db_generator.augment_identities(identities) + db_generator.augment_identities(identities, ethnicity_labels) else: - db_generator.create_database(identities) + db_generator.create_database(identities, ethnicity_labels) if __name__ == "__main__": db_gen() + diff --git a/bob/paper/ijcb2021_synthetic_dataset/script/project_db.py b/bob/paper/ijcb2021_synthetic_dataset/script/project_db.py index e550cc64c6fc95e6e4edf1c2251ecbcdc84627ad..97dbadee7d9384aa9d1b1b4adebc60e0d5adad10 100644 --- a/bob/paper/ijcb2021_synthetic_dataset/script/project_db.py +++ b/bob/paper/ijcb2021_synthetic_dataset/script/project_db.py @@ -17,10 +17,20 @@ from bob.extension.scripts.click_helper import ( from ..utils import fix_randomness + +# additional imports +from PIL import Image +import numpy as np + CROP_DIR = "img_aligned" PROJ_DIR = "projected" LAT_DIR = "w_latents" +ETHNICITY_LABELS = { + "Morph": ["A", "B", "H", "W"], + "MEDS": ["H", "W"] +} + def is_processed(f, output_dir): return os.path.exists(f.make_path(os.path.join(output_dir, LAT_DIR), ".h5")) @@ -28,16 +38,14 @@ def is_processed(f, output_dir): @click.command( cls=ConfigCommand, - entry_point_group='projection_config', help="Project a bob.bio.base.database.BioDatabase into the StyleGAN2 latent space.", ) @click.option( "--database", "-d", - entry_point_group='projected_db', - required=True, + entry_point_group="bob.bio.database", cls=ResourceOption, - help="Which bob.bio.database to project (available : 'multipie_U', 'multipie_E', 'multipie_P')" + help="BioDatabase of face images that should be projected", ) @click.option( "--output-dir", @@ -82,14 +90,23 @@ def is_processed(f, output_dir): cls=ResourceOption, help="Seed to control stochasticity during projection.", ) +@click.option( + "--ethnicity", + "-e", + type=click.Choice(['Morph', 'MEDS']), + help="Database used for ethnicity analysis", + cls=ResourceOption, +) + def project( - database, - output_dir, + output_dir= rc['bob.synface.multipie_projections'], group="world", num_steps=1000, + database=None, checkpoint=False, force=False, seed=None, + ethnicity=None, **kwargs ): failure_file_path = os.path.join(output_dir, "failure.dat") @@ -102,57 +119,114 @@ def project( task_id, num_tasks = get_task() - if group == "world": - files = database.training_files() - else: - files = database.test_files(group) - - if not force: - files = [f for f in files if not is_processed(f, output_dir)] - - if os.path.exists(failure_file_path): - with open(failure_file_path, "r") as failure_file: - fail_cases = [item.rstrip() for item in failure_file] + #database_path = "/idiap/resource/database/MEDS/MEDS_II/data" # rc it, chose according to input + if ethnicity is not None: + database_path = rc['bob.db.' + ethnicity] + labels = ETHNICITY_LABELS[ethnicity] + for label in labels: + filename = rc['bob.db.keys'] + '/' + label + '_keys.txt' + output_dir_label = output_dir + "/" + label + with open(filename) as f: + subfiles = f.read().splitlines() + + for i, f in enumerate(subfiles): # treat subfiles, generate one for each ethnciity + print("{} : {}".format(i, f)) + + image = Image.open(database_path + '/' + f + ".JPG") + basewidth = 1024 + wpercent = (basewidth/float(image.size[0])) + hsize = int((float(image.size[1])*float(wpercent))) + image = image.resize((basewidth,hsize), Image.ANTIALIAS) + image = np.array(image) + image = np.moveaxis(image, -1, 0) + + #project_image(image, f, output_dir, checkpoint) + try: + cropped = cropper(image) + print("Crop ok {} !".format(f)) + except: + with open(failure_file_path, "a") as failure_file: + failure_file.write(f + "\n") + print("Failure to crop {} !".format(f)) + return + + print(os.path.join(output_dir, CROP_DIR) + f +".png") + + if checkpoint: + bob.io.base.save( + cropped, + os.path.join(output_dir, CROP_DIR) + f +".png", + create_directories=True + ) + + projected = projector(cropped) + + if checkpoint: + bob.io.base.save( + projected.image, + os.path.join(output_dir, PROJ_DIR)+ f + ".png", + create_directories=True + ) + + bob.io.base.save( + projected.w_latent, + os.path.join(output_dir, LAT_DIR)+ f+ ".h5", + create_directories=True + ) - files = [f for f in files if f.path not in fail_cases] - - subfiles = files[task_id :: num_tasks] - print("{} files remaining. Handling {} of them".format(len(files), len(subfiles))) - - for i, f in enumerate(subfiles): - print("{} : {}".format(i, f)) - image = f.load(database.original_directory, database.original_extension) - - try: - cropped = cropper(image) - except: - with open(failure_file_path, "a") as failure_file: - failure_file.write(f.path + "\n") - print("Failure to crop {} !".format(f)) - continue - - if checkpoint: - bob.io.base.save( - cropped, - f.make_path(os.path.join(output_dir, CROP_DIR), ".png"), - create_directories=True, - ) - - projected = projector(cropped) - - if checkpoint: - bob.io.base.save( - projected.image, - f.make_path(os.path.join(output_dir, PROJ_DIR), ".png"), - create_directories=True, - ) - - bob.io.base.save( - projected.w_latent, - f.make_path(os.path.join(output_dir, LAT_DIR), ".h5"), - create_directories=True, - ) + else: + if group == "world": + files = database.training_files() + else: + files = database.test_files(group) + + if not force: + files = [f for f in files if not is_processed(f, output_dir)] + + if os.path.exists(failure_file_path): + with open(failure_file_path, "r") as failure_file: + fail_cases = [item.rstrip() for item in failure_file] + + files = [f for f in files if f.path not in fail_cases] + + subfiles = files[task_id :: num_tasks] + print("{} files remaining. Handling {} of them".format(len(files), len(subfiles))) + + for i, f in enumerate(subfiles): + print("{} : {}".format(i, f)) + image = f.load(database.original_directory, database.original_extension) + + try: + cropped = cropper(image) + except: + with open(failure_file_path, "a") as failure_file: + failure_file.write(f.path + "\n") + print("Failure to crop {} !".format(f)) + continue + + if checkpoint: + bob.io.base.save( + cropped, + f.make_path(os.path.join(output_dir, CROP_DIR), ".png"), + create_directories=True, + ) + + projected = projector(cropped) + + if checkpoint: + bob.io.base.save( + projected.image, + f.make_path(os.path.join(output_dir, PROJ_DIR), ".png"), + create_directories=True, + ) + + bob.io.base.save( + projected.w_latent, + f.make_path(os.path.join(output_dir, LAT_DIR), ".h5"), + create_directories=True, + ) if __name__ == "__main__": project() +