Skip to content
Snippets Groups Projects
Commit 578563d9 authored by Guillaume CLIVAZ
Browse files

Parse ground-truth csv file with pandas

parent 4e5a70d0
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -5,7 +5,7 @@ ANNOTATIONS_TEMP_DIR = "/tmp/sub_dir/output/train/annotations/" ...@@ -5,7 +5,7 @@ ANNOTATIONS_TEMP_DIR = "/tmp/sub_dir/output/train/annotations/"
N_FRAMES = 50 N_FRAMES = 50
GT_PATH = '/tmp/sub_dir/gt.csv' GT_PATH = '/tmp/sub_dir/gt.csv'
IDIAP_DATA_GT_PATH = '/tmp/idiap_extracted_data/gt_idiap.csv' IDIAP_DATA_GT_PATH = '/tmp/config/gt_idiap.csv'
GT_CONFIG = dict(path=0, any_pa=1, face=2) GT_CONFIG = dict(path=0, any_pa=1, face=2)
IDIAP_GT_CONFIG = dict(path=0, type_id=1, pai_id=2, low_level_group=4) IDIAP_GT_CONFIG = dict(path=0, type_id=1, pai_id=2, low_level_group=4)
......
...@@ -10,6 +10,7 @@ import tables ...@@ -10,6 +10,7 @@ import tables
import pkg_resources import pkg_resources
import collections import collections
from pandas import read_csv
from collections import defaultdict from collections import defaultdict
from bob.pad.base.database import PadDatabase, PadFile from bob.pad.base.database import PadDatabase, PadFile
...@@ -83,7 +84,7 @@ class BatlDockerPadFile(PadFile): ...@@ -83,7 +84,7 @@ class BatlDockerPadFile(PadFile):
""" """
self.f = f self.f = f
if f['type_id'] == 1: if f['type_id'] >= 1:
attack_type = 'attack' attack_type = 'attack'
else: else:
attack_type = None attack_type = None
...@@ -234,50 +235,41 @@ class BatlDockerPadDatabase(PadDatabase): ...@@ -234,50 +235,41 @@ class BatlDockerPadDatabase(PadDatabase):
base class constructor. base class constructor.
""" """
file_id = 0 def read_ground_truth(file_path, gt_type = None):
self.gt_dict = dict() gt_csv = None
try:
gt_csv = read_csv(file_path)
except Exception as e:
print(e)
if gt_type != 'Idiap':
gt_csv.rename(columns={'h5_file': 'path', 'face_15': 'type_id'}, inplace=True)
gt_csv = gt_csv[['path','any_pa','type_id']]
gt_csv['group'] = "train"
gt_csv['pai_id'] = None
# remove ".h5" in file_path
gt_csv['path'] = gt_csv['path'].apply(lambda x: os.path.splitext(x)[0])
gt_csv['client_id'] = gt_csv['path'].apply(lambda x: x.strip().split('/')[-1])
return gt_csv
# Place input data from govt ground-truth in train set, given without any pai_id # Place input data from govt ground-truth in train set, given without any pai_id
with open(ground_truth['govt']['path'],'r') as gt: gt_dataframe = read_ground_truth(ground_truth['govt']['path'])
gt_config = ground_truth['govt']['config']
for data in gt:
fields = data.strip().split(',')
if fields[gt_config['face']] in ["0","1"]:
file_id = file_id + 1
# remove ".h5" in file_path
(path,ext) = os.path.splitext(fields[gt_config['path']])
fields[gt_config['path']] = path
self.gt_dict[(fields[gt_config['path']],
int(fields[gt_config['face']]),
None,
"train",
int(file_id))
] = dict(path=fields[gt_config['path']])
# If retraining, place data from idiap ground-truth in train/eval/dev set # If retraining, place data from idiap ground-truth in train/eval/dev set
if retrain: if retrain:
with open(ground_truth['idiap']['path'],'r') as gt: gt_dataframe_idiap = read_ground_truth(ground_truth['idiap']['path'],'Idiap')
gt_config = ground_truth['idiap']['config'] gt_dataframe = gt_dataframe.append(gt_dataframe_idiap, ignore_index=True, sort = False)
for data in gt:
file_id = file_id + 1 # Sort in defaultdict with group as keys and list of dict (client_id, path,...) as values
fields = data.strip().split(',') gt_dataframe.index += 1
self.gt_dict[(fields[gt_config['path']], gt_dataframe['file_id'] = gt_dataframe.index
int(fields[gt_config['type_id']]), gt_dataframe = gt_dataframe.set_index('group')[['client_id','path','type_id','pai_id','file_id']]
int(fields[gt_config['pai_id']]), gt_dataframe['group'] = gt_dataframe.index
fields[gt_config['low_level_group']], self.gt_list = gt_dataframe.apply(dict,1)\
int(file_id)) .groupby(level=0)\
] = dict(path=fields[gt_config['path']]) .agg(lambda x: list(x.values))\
.to_dict(into=defaultdict(list))
def split_by(file_list, field):
splitted_list = defaultdict(list)
for data, paths in file_list.items():
splitted_list[data[3]]\
.append(dict(client_id=data[0], type_id=data[1], pai_id=data[2], group=data[3], file_id=data[4], path=paths['path']))
return splitted_list
self.gt_list = split_by(self.gt_dict, "low_level_group")
self.low_level_group_names = ( self.low_level_group_names = (
'train', 'validation', 'train', 'validation',
...@@ -430,7 +422,7 @@ class BatlDockerPadDatabase(PadDatabase): ...@@ -430,7 +422,7 @@ class BatlDockerPadDatabase(PadDatabase):
if f['type_id']== 0: if f['type_id']== 0:
files.append(f) files.append(f)
elif purposes == 'attack': elif purposes == 'attack':
if f['type_id'] == 1: if f['type_id'] >= 1:
files.append(f) files.append(f)
elif purposes == ['real','attack']: elif purposes == ['real','attack']:
files.append(f) files.append(f)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment