Commit a0aa556d authored by Amir MOHAMMADI

Merge branch 'db-interface' into 'master'

Improve db interface

Closes #4

See merge request !11
parents 101f35b0 7a4a3961
Pipeline #30732 passed with stages
in 29 minutes and 11 seconds
......@@ -13,11 +13,57 @@ import numpy
import bob.core
from bob.db.base.annotations import read_annotation_file
from bob.db.base import File as BaseFile
from import reader
logger = bob.core.log.setup('bob.db.replaymobile')
Base = declarative_base()
flip_file_list = ['client008_session02_authenticate_tablet_adverse', 'client008_session02_authenticate_tablet_controlled']
flip_client_list = [8]
def replaymobile_annotations(lowlevelfile, original_directory):
    """Build per-frame face bounding-box annotations for one video.

    Parameters:

    lowlevelfile
      Low-level database file object. Only its ``bbx(directory=...)`` method
      is used here.

    original_directory
      Directory forwarded unchanged to ``lowlevelfile.bbx``.

    Returns a dictionary keyed by the frame index as a string (``'0'``,
    ``'1'``, ...). Each value is a dictionary with two entries:
    ``'topleft'`` = ``(row[1], row[0])`` and
    ``'bottomright'`` = ``(row[1] + row[3], row[0] + row[2])``,
    where ``row`` is the bounding-box record of that frame.
    """
    # numpy array containing the face bounding box data for each video
    # frame, returned data format described in the f.bbx() method of the
    # low level interface
    annots = lowlevelfile.bbx(directory=original_directory)

    annotations = {}  # dictionary to return
    for fn, frame_annots in enumerate(annots):
        # NOTE(review): presumably each row is (x, y, width, height), which
        # would make the tuples below (y, x) corners -- confirm against bbx().
        topleft = (frame_annots[1], frame_annots[0])
        bottomright = (
            frame_annots[1] + frame_annots[3],
            frame_annots[0] + frame_annots[2],
        )
        annotations[str(fn)] = {
            'topleft': topleft,
            'bottomright': bottomright,
        }
    return annotations
def replaymobile_frames(lowlevelfile, original_directory):
vfilename = lowlevelfile.make_path(
should_flip = not lowlevelfile.is_tablet()
if not should_flip:
if in flip_client_list:
for mfn in flip_file_list:
if mfn in lowlevelfile.path:
should_flip = True
for frame in reader(vfilename):
frame = numpy.rollaxis(frame, 2, 1)
if should_flip:
frame = frame[:, ::-1, :]
yield frame
class Client(Base):
"""Database clients, marked by an integer identifier and the set they belong
......@@ -42,7 +88,7 @@ class Client(Base):
return "Client('%s', '%s')" % (, self.set)
class File(Base):
class File(Base, BaseFile):
"""Generic file container"""
__tablename__ = 'file'
......@@ -81,29 +127,6 @@ class File(Base):
def __repr__(self):
return "File('%s')" % self.path
def make_path(self, directory=None, extension=None):
    """Wraps the current path so that a complete path is formed

    Keyword parameters:

    directory
      An optional directory name that will be prefixed to the returned result.

    extension
      An optional extension that will be suffixed to the returned filename. The
      extension normally includes the leading ``.`` character as in ``.jpg``.

    Returns a string containing the newly generated file path.
    """
    # Falsy values (None or '') mean "no prefix" / "no suffix".
    prefix = directory or ''
    suffix = extension or ''
    return str(os.path.join(prefix, self.path + suffix))
def videofile(self, directory=None):
"""Returns the path to the database video file for this object
......@@ -159,7 +182,16 @@ class File(Base):
Note that **not** all the frames may contain detected faces.
return numpy.loadtxt(self.facefile(directory), dtype=int)
bbx = numpy.loadtxt(self.facefile(directory), dtype=int)
if in flip_client_list:
if self.is_tablet():
for mfn in flip_file_list:
if mfn in self.path:
logger.debug('flipping bbx')
for i in range(bbx.shape[0]):
bbx[i][1] = 1280 - (bbx[i][1] + bbx[i][3]) # correct the y-coord. of the top-left corner of bbx in this frame.
return bbx
def is_real(self):
"""Returns True if this file belongs to a real access, False otherwise"""
......@@ -192,8 +224,7 @@ class File(Base):
raise RuntimeError("%s is not an attack" % self)
return self.attack[0]
# def load(self, directory=None, extension='.hdf5'):
def load(self, directory=None, extension=None):
def load(self, directory=None, extension='.mov'):
"""Loads the data at the specified location and using the given extension.
Keyword parameters:
......@@ -210,52 +241,58 @@ class File(Base):
output and the codec for saving the input blob.
logger.debug('video file extension: {}'.format(extension))
if extension is None:
extension = '.mov'
# if self.get_quality() == 'laptop':
# extension = '.mov'
# else:
# extension = '.mp4'
directory = directory or self.original_directory
extension = extension or self.original_extension
if extension == '.mov' or extension == '.mp4':
vfilename = self.make_path(directory, extension)
video =
vin = video.load()
vfilename = self.make_path(directory, extension)
video =
vin = video.load()
vin =, extension))
vin =, extension))
vin = numpy.rollaxis(vin, 3, 2)
if not self.is_tablet():
logger.debug('flipping mobile video')
vin = vin[:, :, ::-1, :]
# if self.is_rotated():
# vin = vin[:, :, ::-1,:]
logger.debug('flipping mobile video')
vin = vin[:, :, ::-1, :]
if in flip_client_list:
for mfn in flip_file_list:
if mfn in self.path:
logger.debug('flipping tablet video')
vin = vin[:, :, ::-1, :]
return vin
# return, extension))
def save(self, data, directory=None, extension='.hdf5'):
"""Saves the input data at the specified location and using the given
Keyword parameters:
The data blob to be saved (normally a :py:class:`numpy.ndarray`).
[optional] If not empty or None, this directory is prefixed to the final
file destination
[optional] The extension of the filename - this will control the type of
output and the codec for saving the input blob.
path = self.make_path(directory, extension), path)
def annotations(self):
if hasattr(self, 'annotation_directory') and \
self.annotation_directory is not None:
# return the external annotations
annotations = read_annotation_file(
self.path + self.annotation_extension),
return annotations
# return original annotations
return replaymobile_annotations(self, self.original_directory)
def frames(self):
    """Return the frames of this file's video.

    Delegates to ``replaymobile_frames``, passing this object and its
    ``original_directory`` attribute, and returns whatever that call returns.
    """
    return replaymobile_frames(self, self.original_directory)
def number_of_frames(self):
vfilename = self.make_path(
return reader(vfilename).number_of_frames
def frame_shape(self):
# Intermediate mapping from RealAccess's to Protocol's
......@@ -6,7 +6,8 @@ replay mobile database in the most obvious ways.
import os
from bob.db.base import utils, Database
from bob.db.base import utils, SQLiteDatabase
from bob.extension import rc
from .models import *
from .driver import Interface
......@@ -15,47 +16,26 @@ INFO = Interface()
SQLITE_FILE = INFO.files()[0]
class Database(Database):
class Database(SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
It provides many different ways to probe for the characteristics of the data
and for the data itself inside the database.
def __init__(self, original_directory=None, original_extension=None):
def __init__(self,
# opens a session to the database - keep it open until the end
super(Database, self).__init__(original_directory, original_extension)
def __del__(self):
"""Releases the opened file descriptor"""
if self.session:
# Since the dispose function re-creates a pool
# which might fail in some conditions, e.g., when this destructor is called during the exit of the python interpreter
except Exception:
def connect(self):
"""Tries connecting or re-connecting to the database"""
if not os.path.exists(SQLITE_FILE):
self.session = None
self.session = utils.session_try_readonly(INFO.type(), SQLITE_FILE)
def is_valid(self):
    """Returns if a valid session has been opened for reading the database"""
    # ``session`` is None when no database session is available.
    return self.session is not None
def assert_validity(self):
"""Raise a RuntimeError if the database backend is not available"""
if not self.is_valid():
raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (, SQLITE_FILE))
super(Database, self).__init__(
SQLITE_FILE, File, original_directory, original_extension)
self.annotation_directory = annotation_directory
self.annotation_extension = annotation_extension
self.annotation_type = annotation_type
def objects(self, support=Attack.attack_support_choices,
protocol='grandtest', groups=Client.set_choices, cls=('attack', 'real'),
......@@ -157,10 +137,9 @@ class Database(Database):
# now query the database
retval = []
from sqlalchemy.sql.expression import or_
# real-accesses are simpler to query
if 'enroll' in cls:
q = self.session.query(File).join(RealAccess).join(Client)
q = self.m_session.query(File).join(RealAccess).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -175,7 +154,7 @@ class Database(Database):
# real-accesses are simpler to query
if 'real' in cls:
q = self.session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client)
q = self.m_session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -190,7 +169,7 @@ class Database(Database):
# attacks will have to be filtered a little bit more
if 'attack' in cls:
q = self.session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client)
q = self.m_session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -207,6 +186,12 @@ class Database(Database):
q = q.order_by(
retval += list(q)
for f in retval:
f.original_directory = self.original_directory
f.original_extension = self.original_extension
f.annotation_directory = self.annotation_directory
f.annotation_extension = self.annotation_extension
f.annotation_type = self.annotation_type
return retval
def files(self, directory=None, extension=None, **object_query):
......@@ -244,33 +229,33 @@ class Database(Database):
"""Returns an iterable with all known clients"""
return list(self.session.query(Client))
return list(self.m_session.query(Client))
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
return self.session.query(Client).filter( == id).count() != 0
return self.m_session.query(Client).filter( == id).count() != 0
def protocols(self):
"""Returns all protocol objects.
return list(self.session.query(Protocol))
return list(self.m_session.query(Protocol))
def has_protocol(self, name):
"""Tells if a certain protocol is available"""
return self.session.query(Protocol).filter( == name).count() != 0
return self.m_session.query(Protocol).filter( == name).count() != 0
def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
return self.session.query(Protocol).filter( == name).one()
return self.m_session.query(Protocol).filter( == name).one()
def groups(self):
"""Returns the names of all registered groups"""
......@@ -330,7 +315,7 @@ class Database(Database):
fobj = self.session.query(File).filter(
fobj = self.m_session.query(File).filter(
retval = []
for p in ids:
retval.extend([k.make_path(prefix, suffix) for k in fobj if == p])
......@@ -350,7 +335,7 @@ class Database(Database):
fobj = self.session.query(File).filter(File.path.in_(paths))
fobj = self.m_session.query(File).filter(File.path.in_(paths))
for p in paths:
retval.extend([ for k in fobj if k.path == p])
return retval
......@@ -387,7 +372,7 @@ class Database(Database):
fobj = self.session.query(File).filter_by(id=id).one()
fobj = self.m_session.query(File).filter_by(id=id).one()
fullpath = os.path.join(directory, str(fobj.path) + extension)
fulldir = os.path.dirname(fullpath)
......@@ -39,6 +39,7 @@ class File(BaseFile):
def __init__(self, f, framen=None):
self._f = f
self.framen = framen
self.original_path = f.path
self.path = '{}_{:03d}'.format(f.path, framen)
self.client_id = f.client_id
self.file_id = '{}_{}'.format(, framen)
......@@ -56,6 +57,10 @@ class File(BaseFile):
return super(File, self).load(directory, extension)
def annotations(self):
return self._f.annotations[str(self.framen)]
class Database(BaseDatabase):
......@@ -66,11 +71,25 @@ class Database(BaseDatabase):
__doc__ = __doc__
def __init__(self, max_number_of_frames=None, original_directory=None, original_extension=None):
super(Database, self).__init__(original_directory, original_extension)
def __init__(self,
# call base class constructors to open a session to the database
self._db = LDatabase()
self._db = LDatabase(
super(Database, self).__init__(original_directory, original_extension)
self.max_number_of_frames = max_number_of_frames or 10
# 240 is the guaranteed number of frames in replay mobile videos
......@@ -78,6 +97,46 @@ class Database(BaseDatabase):
self.low_level_group_names = ('train', 'devel', 'test')
self.high_level_group_names = ('world', 'dev', 'eval')
@property
def original_directory(self):
    """Proxy for ``original_directory`` on the wrapped ``self._db`` object."""
    return self._db.original_directory

@original_directory.setter
def original_directory(self, value):
    # Writes go to the wrapped database so both objects stay in sync.
    self._db.original_directory = value
@property
def original_extension(self):
    """Proxy for ``original_extension`` on the wrapped ``self._db`` object."""
    return self._db.original_extension

@original_extension.setter
def original_extension(self, value):
    # Writes go to the wrapped database so both objects stay in sync.
    self._db.original_extension = value
@property
def annotation_directory(self):
    """Proxy for ``annotation_directory`` on the wrapped ``self._db`` object."""
    return self._db.annotation_directory

@annotation_directory.setter
def annotation_directory(self, value):
    # Writes go to the wrapped database so both objects stay in sync.
    self._db.annotation_directory = value
@property
def annotation_extension(self):
    """Proxy for ``annotation_extension`` on the wrapped ``self._db`` object."""
    return self._db.annotation_extension

@annotation_extension.setter
def annotation_extension(self, value):
    # Writes go to the wrapped database so both objects stay in sync.
    self._db.annotation_extension = value
@property
def annotation_type(self):
    """Proxy for ``annotation_type`` on the wrapped ``self._db`` object."""
    return self._db.annotation_type

@annotation_type.setter
def annotation_type(self, value):
    # Writes go to the wrapped database so both objects stay in sync.
    self._db.annotation_type = value
def protocol_names(self):
"""Returns all registered protocol names
Here I am going to hack and double the number of protocols
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment