Commit 02d2e827 authored by Anjith GEORGE's avatar Anjith GEORGE

Added unseen attack protocols

parent 89c0b87c
Pipeline #39218 failed with stage
in 13 minutes and 4 seconds
......@@ -140,6 +140,55 @@ class HQWMCAPadDatabase(PadDatabase):
def original_directory(self, value):
self.db.original_directory = value
def unseen_attack_list_maker(self,files,unknown_attack,train=True):
"""
Selects and returns a list of files for Leave One Out (LOO) protocols.
This utilizes the original partitioning in the `grandtest` protocol and subselects
the file list such that the specified `unknown_attack` is removed from both `train` and `dev` sets.
The `test` set will consist of only the selected `unknown_attack` and `bonafide` files.
**Parameters:**
``files`` : pyclass::list
A list of files, db.objects()
``unknown_attack`` : str
The unknown attack protocol name example:'rigidmask' .
``train`` : bool
Denotes whether files are from training or development partition
**Returns:**
``mod_files`` : pyclass::list
A list of files selected for the protocol
"""
mod_files=[]
for file in files:
if file.is_attack():
attack_category = idiap_type_id_config[str(file.type_id)]
else:
attack_category ='bonafide'
if train:
if attack_category==unknown_attack:
pass
else:
mod_files.append(file) # everything except the attack specified is there
if not train:
if attack_category==unknown_attack or attack_category=='bonafide':
mod_files.append(file) # only the attack mentioned and bonafides in testing
else:
pass
return mod_files
def objects(self,
groups=None,
protocol=None,
......@@ -187,11 +236,58 @@ class HQWMCAPadDatabase(PadDatabase):
protocol=protocol.split('-')[0]
files = self.db.objects(protocol=protocol,
groups=groups,
purposes=purposes,
attacks=attack_types,
**kwargs)
unseen_attack=None
if 'LOO' in protocol:
unseen_attack=protocol.split('_')[-1]
self.use_curated_file_list=True
else:
files = self.db.objects(protocol=protocol,
groups=groups,
purposes=purposes,
attacks=attack_types,
**kwargs)
if unseen_attack is not None:
hqwmca_files=[]
if 'train' in groups:
t_hqwmca_files = self.db.objects(protocol='grand_test',
groups=['train'],
purposes=purposes, **kwargs)
t_hqwmca_files=self.unseen_attack_list_maker(t_hqwmca_files,unseen_attack,train=True)
hqwmca_files=hqwmca_files+t_hqwmca_files
if 'validation' in groups:
t_hqwmca_files = self.db.objects(protocol='grand_test',
groups=['validation'],
purposes=purposes, **kwargs)
t_hqwmca_files=self.unseen_attack_list_maker(t_hqwmca_files,unseen_attack,train=True)
hqwmca_files=hqwmca_files+t_hqwmca_files
if 'test' in groups:
t_hqwmca_files = self.db.objects(protocol='grand_test',
groups=['test'],
purposes=purposes, **kwargs)
t_hqwmca_files=self.unseen_attack_list_maker(t_hqwmca_files,unseen_attack,train=False)
hqwmca_files=hqwmca_files+t_hqwmca_files
files=hqwmca_files
......@@ -348,13 +444,13 @@ class HQWMCAPadDatabase(PadDatabase):
print('rep_annotations.keys', rep_annotations.keys(), rep_annotations)
r_file_path=file_path.replace('raw_annotations','rep_annotations')
# r_file_path=file_path.replace('raw_annotations','rep_annotations')
bob.io.base.create_directories_safe(directory=os.path.split(r_file_path)[0], dryrun=False)
# bob.io.base.create_directories_safe(directory=os.path.split(r_file_path)[0], dryrun=False)
with open(r_file_path, 'w+') as json_file:
# with open(r_file_path, 'w+') as json_file:
json_file.write(json.dumps(rep_annotations))
# json_file.write(json.dumps(rep_annotations))
if len(rep_annotations.keys()) < 2:
print('KEYS LESSSSSSSSSSSSSSSSSSSSSSSSSSSSSS...........')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment