diff --git a/bob/pad/face/config/batl_docker_configs/databases/gt_config.py b/bob/pad/face/config/batl_docker_configs/databases/gt_config.py
index 9b15f1abbebdf039257afdddbd3e3f0d02f113b8..4cebdb192cbf8fe867085b7571aa2ed966dab321 100644
--- a/bob/pad/face/config/batl_docker_configs/databases/gt_config.py
+++ b/bob/pad/face/config/batl_docker_configs/databases/gt_config.py
@@ -5,7 +5,7 @@
 ANNOTATIONS_TEMP_DIR = "/tmp/sub_dir/output/train/annotations/"
 N_FRAMES = 50
 GT_PATH = '/tmp/sub_dir/gt.csv'
-IDIAP_DATA_GT_PATH = '/tmp/idiap_extracted_data/gt_idiap.csv'
+IDIAP_DATA_GT_PATH = '/tmp/config/gt_idiap.csv'
 
 GT_CONFIG = dict(path=0, any_pa=1, face=2)
 IDIAP_GT_CONFIG = dict(path=0, type_id=1, pai_id=2, low_level_group=4)
diff --git a/bob/pad/face/database/batldocker.py b/bob/pad/face/database/batldocker.py
index 3ff006a64abfe9f5234b28b013f12ff71341dfaf..82dcc223e3b35807d3793464c46838ad861a9bd2 100644
--- a/bob/pad/face/database/batldocker.py
+++ b/bob/pad/face/database/batldocker.py
@@ -10,6 +10,7 @@
 import tables
 import pkg_resources
 import collections
+from pandas import concat, read_csv
 from collections import defaultdict
 
 from bob.pad.base.database import PadDatabase, PadFile
@@ -83,7 +84,7 @@ class BatlDockerPadFile(PadFile):
         """
 
         self.f = f
-        if f['type_id'] == 1:
+        if f['type_id'] >= 1:
             attack_type = 'attack'
         else:
             attack_type = None
@@ -234,50 +235,48 @@ class BatlDockerPadDatabase(PadDatabase):
             base class constructor.
         """
 
-        file_id = 0
-        self.gt_dict = dict()
+        def read_ground_truth(file_path, gt_type=None):
+            """Load a ground-truth CSV as a DataFrame of file records.
+
+            Unless ``gt_type`` is 'Idiap', the ``h5_file`` and ``face_15``
+            columns are renamed to ``path`` and ``type_id``, and every row is
+            put in the "train" group with ``pai_id`` set to None.
+            """
+            # Let read/parse errors propagate: carrying on with ``None`` would
+            # only fail below with an unrelated AttributeError/TypeError.
+            gt_csv = read_csv(file_path)
+
+            if gt_type != 'Idiap':
+                gt_csv = gt_csv.rename(columns={'h5_file': 'path', 'face_15': 'type_id'})
+                # Work on an explicit copy so the column assignments below
+                # never target a view of the parsed CSV.
+                gt_csv = gt_csv[['path', 'any_pa', 'type_id']].copy()
+                gt_csv['group'] = "train"
+                gt_csv['pai_id'] = None
+
+            # remove ".h5" in file_path
+            gt_csv['path'] = gt_csv['path'].apply(lambda x: os.path.splitext(x)[0])
+            gt_csv['client_id'] = gt_csv['path'].apply(lambda x: x.strip().split('/')[-1])
+            return gt_csv
 
         # Place input data from govt ground-truth in train set, given without any pai_id
-        with open(ground_truth['govt']['path'],'r') as gt:
-            gt_config = ground_truth['govt']['config']
-            for data in gt:
-                fields = data.strip().split(',')
-                if fields[gt_config['face']] in ["0","1"]:
-                    file_id = file_id + 1
-
-                    # remove ".h5" in file_path
-                    (path,ext) = os.path.splitext(fields[gt_config['path']])
-                    fields[gt_config['path']] = path
-
-                    self.gt_dict[(fields[gt_config['path']],
-                                  int(fields[gt_config['face']]),
-                                  None,
-                                  "train",
-                                  int(file_id))
-                                 ] = dict(path=fields[gt_config['path']])
+        gt_dataframe = read_ground_truth(ground_truth['govt']['path'])
 
         # If retraining, place data from idiap ground-truth in train/eval/dev set
         if retrain:
-            with open(ground_truth['idiap']['path'],'r') as gt:
-                gt_config = ground_truth['idiap']['config']
-                for data in gt:
-                    file_id = file_id + 1
-                    fields = data.strip().split(',')
-                    self.gt_dict[(fields[gt_config['path']],
-                                  int(fields[gt_config['type_id']]),
-                                  int(fields[gt_config['pai_id']]),
-                                  fields[gt_config['low_level_group']],
-                                  int(file_id))
-                                 ] = dict(path=fields[gt_config['path']])
-
-        def split_by(file_list, field):
-            splitted_list = defaultdict(list)
-            for data, paths in file_list.items():
-                splitted_list[data[3]]\
-                    .append(dict(client_id=data[0], type_id=data[1], pai_id=data[2], group=data[3], file_id=data[4], path=paths['path']))
-            return splitted_list
-
-        self.gt_list = split_by(self.gt_dict, "low_level_group")
+            gt_dataframe_idiap = read_ground_truth(ground_truth['idiap']['path'], 'Idiap')
+            # DataFrame.append was removed in pandas 2.0; concat is the equivalent call.
+            gt_dataframe = concat([gt_dataframe, gt_dataframe_idiap], ignore_index=True, sort=False)
+
+        # Sort in defaultdict with group as keys and list of dict (client_id, path,...) as values
+        gt_dataframe.index += 1
+        gt_dataframe['file_id'] = gt_dataframe.index
+        gt_dataframe = gt_dataframe.set_index('group')[['client_id','path','type_id','pai_id','file_id']]
+        gt_dataframe['group'] = gt_dataframe.index
+        self.gt_list = gt_dataframe.apply(dict,1)\
+                                   .groupby(level=0)\
+                                   .agg(lambda x: list(x.values))\
+                                   .to_dict(into=defaultdict(list))
         self.low_level_group_names = (
             'train',
             'validation',
@@ -430,7 +429,7 @@ class BatlDockerPadDatabase(PadDatabase):
                 if f['type_id']== 0:
                     files.append(f)
             elif purposes == 'attack':
-                if f['type_id'] == 1:
+                if f['type_id'] >= 1:
                     files.append(f)
             elif purposes == ['real','attack']:
                 files.append(f)