Commit 8d68a694 authored by Amir Mohammadi's avatar Amir Mohammadi
Browse files

Add a protocol for verification

parent 0d9f19da
Pipeline #4652 canceled with stages
in 3 minutes
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
""" The Replay-Mobile Database for face spoofing interface. It is an
extension of an SQL-based database interface, which directly talks to Replay-
Mobile database, for verification experiments (good to use in bob.bio.base
framework). It also implements a kind of hack so that you can run
vulnerability analysis with it. """
from bob.db.base import File as BaseFile
from .query import Database as LDatabase
def selected_indices(total_number_of_indices, desired_number_of_indices=None):
"""Returns a list of indices that will contain exactly the number of desired indices (or the number of total items in the list, if this is smaller).
These indices are selected such that they are evenly spread over the whole sequence."""
if desired_number_of_indices is None or desired_number_of_indices >= total_number_of_indices or desired_number_of_indices < 0:
return range(total_number_of_indices)
increase = float(total_number_of_indices) / float(desired_number_of_indices)
# generate a regular quasi-random index list
return [int((i + .5) * increase) for i in range(desired_number_of_indices)]
class File(BaseFile):
"""Replay Mobile low-level file used for vulnerability analysis in face recognition"""
def __init__(self, f, framen=None):
self._f = f
self.framen = framen
self.path = '{}_{:03d}'.format(f.path, framen)
self.client_id = f.client_id
self.file_id = '{}_{}'.format(f.id, framen)
super(File, self).__init__(path=self.path, file_id=self.file_id)
def load(self, directory=None, extension=None):
if extension in (None, '.mov'):
video = self._f.load(directory, extension)
# just return the required frame.
return video[self.framen]
else:
return super(File, self).load(directory, extension)
class Database(LDatabase):
"""
Implements verification API for querying Replay Mobile database.
This database loads max_number_of_frames from the video files as
separate samples. This is different from what bob.bio.video does
currently.
"""
__doc__ = __doc__
def __init__(self, max_number_of_frames=None):
# call base class constructors to open a session to the database
super(Database, self).__init__()
self.max_number_of_frames = max_number_of_frames or 10
# 300 is the number of frames in replay mobile videos
self.indices = selected_indices(300, max_number_of_frames)
self.low_level_group_names = ('train', 'devel', 'test')
self.high_level_group_names = ('world', 'dev', 'eval')
def protocol_names(self):
"""Returns all registered protocol names
Here I am going to hack and double the number of protocols
with -licit and -spoof. This is done for running vulnerability
analysis"""
names = [p.name + '-licit' for p in super(Database, self).protocols()]
names += [p.name + '-spoof' for p in super(Database, self).protocols()]
return names
def groups(self):
return self.convert_names_to_highlevel(
super(Database, self).groups(), self.low_level_group_names, self.high_level_group_names)
def annotations(self, myfile):
"""Will return the bounding box annotation of nth frame of the video."""
fn = myfile.framen # 10th frame number
annots = myfile._f.bbx(directory=self.original_directory)
# bob uses the (y, x) format
topleft = (annots[fn][2], annots[fn][1])
bottomright = (annots[fn][2] + annots[fn][4], annots[fn][1] + annots[fn][3])
annotations = {'topleft': topleft, 'bottomright': bottomright}
return annotations
def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
# since the low-level API does not support verification straight-forward-ly, we improvise.
files = self.objects(groups=groups, protocol=protocol, purposes='enroll', **kwargs)
return sorted(set(f.client_id for f in files))
def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
if protocol == '.':
protocol = None
protocol = self.check_parameter_for_validity(protocol, "protocol", self.protocol_names(), 'grandtest-licit')
groups = self.check_parameters_for_validity(groups, "group", self.groups(), self.groups())
purposes = self.check_parameters_for_validity(purposes, "purpose", ('enroll', 'probe'), ('enroll', 'probe'))
purposes = list(purposes)
groups = self.convert_names_to_lowlevel(
groups, self.low_level_group_names, self.high_level_group_names)
# protocol licit is not defined in the low level API
# so do a hack here.
if '-licit' in protocol:
# for licit we return the grandtest protocol
protocol = protocol.replace('-licit', '')
# The low-level API has only "attack", "real", "enroll" and "probe"
# should translate to "real" or "attack" depending on the protocol.
# enroll does not to change.
if 'probe' in purposes:
purposes.remove('probe')
purposes.append('real')
if len(purposes) == 1:
# making the model_ids to None will return all clients which make
# the impostor data also available.
model_ids = None
elif model_ids:
raise NotImplementedError(
'Currently returning both enroll and probe for specific '
'client(s) in the licit protocol is not supported. '
'Please specify one purpose only.')
elif '-spoof' in protocol:
protocol = protocol.replace('-spoof', '')
# you need to replace probe with attack and real for the spoof protocols.
# You can add the real here also to create positives scores also
# but usually you get these scores when you run the licit protocol
if 'probe' in purposes:
purposes.remove('probe')
purposes.append('attack')
# now, query the actual Replay database
objects = super(Database, self).objects(groups=groups, protocol=protocol, cls=purposes, clients=model_ids, **kwargs)
# make sure to return File representation of a file, not the database one
# also make sure you replace client ids with attack
retval = []
for f in objects:
for i in self.indices:
if f.is_real():
retval.append(File(f, i))
else:
temp = File(f, i)
temp.client_id = 'attack'
retval.append(temp)
return retval
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment