Skip to content
Snippets Groups Projects
Commit a0aa556d authored by Amir MOHAMMADI
Browse files

Merge branch 'db-interface' into 'master'

Improve db interface

Closes #4

See merge request !11
parents 101f35b0 7a4a3961
Branches
Tags
1 merge request: !11 Improve db interface
Pipeline #30732 passed
...@@ -13,11 +13,57 @@ import numpy ...@@ -13,11 +13,57 @@ import numpy
import bob.io.base import bob.io.base
import bob.io.video import bob.io.video
import bob.core import bob.core
from bob.db.base.annotations import read_annotation_file
from bob.db.base import File as BaseFile
from bob.io.video import reader
logger = bob.core.log.setup('bob.db.replaymobile') logger = bob.core.log.setup('bob.db.replaymobile')
Base = declarative_base() Base = declarative_base()
# Shape returned by the ``frame_shape`` property of ``File``.
# NOTE(review): presumably (channels, height, width) -- confirm.
REPLAYMOBILE_FRAME_SHAPE = (3, 1280, 720)
# Path fragments of the tablet videos that need an extra flip: a file matches
# when one of these strings occurs in its ``path``.
flip_file_list = ['client008_session02_authenticate_tablet_adverse', 'client008_session02_authenticate_tablet_controlled']
# Client ids for which ``flip_file_list`` is consulted at all.
flip_client_list = [8]
def replaymobile_annotations(lowlevelfile, original_directory):
    """Convert the original face bounding boxes into per-frame annotations.

    Parameters
    ----------
    lowlevelfile
        Low-level file object. Only ``lowlevelfile.bbx(directory=...)`` is
        called; it must return one bounding-box row per video frame.
    original_directory
        Forwarded to ``bbx`` as its ``directory`` argument.

    Returns
    -------
    dict
        Maps the frame index as a string (``'0'``, ``'1'``, ...) to
        ``{'topleft': (b[1], b[0]),
        'bottomright': (b[1] + b[3], b[0] + b[2])}`` where ``b`` is the
        bounding-box row of that frame. An empty dictionary is returned
        when there are no bounding boxes at all.
    """
    boxes = numpy.asarray(lowlevelfile.bbx(directory=original_directory))

    # No detections at all: nothing to index, so return an empty mapping
    # instead of failing on the row look-ups below.
    if boxes.size == 0:
        return {}

    # ``bbx`` relies on ``numpy.loadtxt``, which hands back a 1-D array for
    # a single-row file. Promote it to one row so that the loop below
    # iterates over frames rather than over the numbers of that single row.
    boxes = numpy.atleast_2d(boxes)

    annotations = {}  # dictionary to return
    for fn, box in enumerate(boxes):
        # NOTE(review): columns are presumably (x, y, width, height), so
        # the tuples below are in (y, x) order -- confirm against ``bbx``.
        annotations[str(fn)] = {
            'topleft': (box[1], box[0]),
            'bottomright': (box[1] + box[3], box[0] + box[2]),
        }
    return annotations
def replaymobile_frames(lowlevelfile, original_directory):
    """Yield the frames of a video one at a time, flipped where required.

    Parameters
    ----------
    lowlevelfile
        Low-level file object; ``make_path``, ``is_tablet``, ``client.id``
        and ``path`` are used.
    original_directory
        Directory handed to ``make_path`` to locate the ``.mov`` file.

    Yields
    ------
    Each frame produced by ``reader`` after ``numpy.rollaxis(frame, 2, 1)``.
    The frame is additionally reversed along its second axis when the file
    is not a tablet recording, or when it is a tablet recording of a client
    in ``flip_client_list`` whose path contains an entry of
    ``flip_file_list``.
    """
    video_path = lowlevelfile.make_path(
        directory=original_directory,
        extension='.mov')

    if lowlevelfile.is_tablet():
        # Tablet videos are left alone, except for the listed recordings.
        flip = (lowlevelfile.client.id in flip_client_list and
                any(name in lowlevelfile.path for name in flip_file_list))
    else:
        flip = True

    for raw in reader(video_path):
        swapped = numpy.rollaxis(raw, 2, 1)
        yield swapped[:, ::-1, :] if flip else swapped
class Client(Base): class Client(Base):
"""Database clients, marked by an integer identifier and the set they belong """Database clients, marked by an integer identifier and the set they belong
...@@ -42,7 +88,7 @@ class Client(Base): ...@@ -42,7 +88,7 @@ class Client(Base):
return "Client('%s', '%s')" % (self.id, self.set) return "Client('%s', '%s')" % (self.id, self.set)
class File(Base): class File(Base, BaseFile):
"""Generic file container""" """Generic file container"""
__tablename__ = 'file' __tablename__ = 'file'
...@@ -81,29 +127,6 @@ class File(Base): ...@@ -81,29 +127,6 @@ class File(Base):
def __repr__(self): def __repr__(self):
return "File('%s')" % self.path return "File('%s')" % self.path
def make_path(self, directory=None, extension=None):
"""Wraps the current path so that a complete path is formed
Keyword parameters:
directory
An optional directory name that will be prefixed to the returned result.
extension
An optional extension that will be suffixed to the returned filename. The
extension normally includes the leading ``.`` character as in ``.jpg`` or
``.hdf5``.
Returns a string containing the newly generated file path.
"""
if not directory:
directory = ''
if not extension:
extension = ''
return str(os.path.join(directory, self.path + extension))
def videofile(self, directory=None): def videofile(self, directory=None):
"""Returns the path to the database video file for this object """Returns the path to the database video file for this object
...@@ -159,7 +182,16 @@ class File(Base): ...@@ -159,7 +182,16 @@ class File(Base):
Note that **not** all the frames may contain detected faces. Note that **not** all the frames may contain detected faces.
""" """
return numpy.loadtxt(self.facefile(directory), dtype=int) bbx = numpy.loadtxt(self.facefile(directory), dtype=int)
if self.client.id in flip_client_list:
if self.is_tablet():
logger.debug(self.path)
for mfn in flip_file_list:
if mfn in self.path:
logger.debug('flipping bbx')
for i in range(bbx.shape[0]):
bbx[i][1] = 1280 - (bbx[i][1] + bbx[i][3]) # correct the y-coord. of the top-left corner of bbx in this frame.
return bbx
def is_real(self): def is_real(self):
"""Returns True if this file belongs to a real access, False otherwise""" """Returns True if this file belongs to a real access, False otherwise"""
...@@ -192,8 +224,7 @@ class File(Base): ...@@ -192,8 +224,7 @@ class File(Base):
raise RuntimeError("%s is not an attack" % self) raise RuntimeError("%s is not an attack" % self)
return self.attack[0] return self.attack[0]
# def load(self, directory=None, extension='.hdf5'): def load(self, directory=None, extension='.mov'):
def load(self, directory=None, extension=None):
"""Loads the data at the specified location and using the given extension. """Loads the data at the specified location and using the given extension.
Keyword parameters: Keyword parameters:
...@@ -210,52 +241,58 @@ class File(Base): ...@@ -210,52 +241,58 @@ class File(Base):
output and the codec for saving the input blob. output and the codec for saving the input blob.
""" """
logger.debug('video file extension: {}'.format(extension)) logger.debug('video file extension: {}'.format(extension))
if extension is None:
extension = '.mov' directory = directory or self.original_directory
# if self.get_quality() == 'laptop': extension = extension or self.original_extension
# extension = '.mov'
# else:
# extension = '.mp4'
if extension == '.mov' or extension == '.mp4': if extension == '.mov' or extension == '.mp4':
vfilename = self.make_path(directory, extension) vfilename = self.make_path(directory, extension)
video = bob.io.video.reader(vfilename) video = bob.io.video.reader(vfilename)
vin = video.load() vin = video.load()
else: else:
vin = bob.io.base.load(self.make_path(directory, extension)) vin = bob.io.base.load(self.make_path(directory, extension))
vin = numpy.rollaxis(vin, 3, 2) vin = numpy.rollaxis(vin, 3, 2)
if not self.is_tablet(): if not self.is_tablet():
logger.debug('flipping mobile video') logger.debug('flipping mobile video')
vin = vin[:, :, ::-1, :] vin = vin[:, :, ::-1, :]
else:
# if self.is_rotated(): if self.client.id in flip_client_list:
# vin = vin[:, :, ::-1,:] for mfn in flip_file_list:
if mfn in self.path:
logger.debug('flipping tablet video')
vin = vin[:, :, ::-1, :]
return vin return vin
# return bob.io.base.load(self.make_path(directory, extension))
def save(self, data, directory=None, extension='.hdf5'):
"""Saves the input data at the specified location and using the given
extension.
Keyword parameters:
data
The data blob to be saved (normally a :py:class:`numpy.ndarray`).
directory
[optional] If not empty or None, this directory is prefixed to the final
file destination
extension
[optional] The extension of the filename - this will control the type of
output and the codec for saving the input blob.
"""
path = self.make_path(directory, extension) @property
bob.io.base.create_directories_safe(os.path.dirname(path)) def annotations(self):
bob.io.base.save(data, path) if hasattr(self, 'annotation_directory') and \
self.annotation_directory is not None:
# return the external annotations
annotations = read_annotation_file(
os.path.join(self.annotation_directory,
self.path + self.annotation_extension),
self.annotation_type)
return annotations
# return original annotations
return replaymobile_annotations(self, self.original_directory)
@property
def frames(self):
    """Frames of this file, as produced by ``replaymobile_frames``.

    The video is looked up under ``self.original_directory``.
    """
    source_dir = self.original_directory
    return replaymobile_frames(self, source_dir)
@property
def number_of_frames(self):
    """Frame count reported by ``reader`` for this file's ``.mov`` video.

    The video path is built from ``self.original_directory``.
    """
    video_path = self.make_path(
        directory=self.original_directory, extension='.mov')
    video = reader(video_path)
    return video.number_of_frames
@property
def frame_shape(self):
    """Return the module-level ``REPLAYMOBILE_FRAME_SHAPE`` constant."""
    shape = REPLAYMOBILE_FRAME_SHAPE
    return shape
# Intermediate mapping from RealAccess's to Protocol's # Intermediate mapping from RealAccess's to Protocol's
......
...@@ -6,7 +6,8 @@ replay mobile database in the most obvious ways. ...@@ -6,7 +6,8 @@ replay mobile database in the most obvious ways.
""" """
import os import os
from bob.db.base import utils, Database from bob.db.base import utils, SQLiteDatabase
from bob.extension import rc
from .models import * from .models import *
from .driver import Interface from .driver import Interface
...@@ -15,47 +16,26 @@ INFO = Interface() ...@@ -15,47 +16,26 @@ INFO = Interface()
SQLITE_FILE = INFO.files()[0] SQLITE_FILE = INFO.files()[0]
class Database(Database): class Database(SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database. """The dataset class opens and maintains a connection opened to the Database.
It provides many different ways to probe for the characteristics of the data It provides many different ways to probe for the characteristics of the data
and for the data itself inside the database. and for the data itself inside the database.
""" """
def __init__(self, original_directory=None, original_extension=None): def __init__(self,
original_directory=rc['bob.db.replaymobile.directory'],
original_extension='.mov',
annotation_directory=None,
annotation_extension='.json',
annotation_type='json',
):
# opens a session to the database - keep it open until the end # opens a session to the database - keep it open until the end
self.connect() super(Database, self).__init__(
super(Database, self).__init__(original_directory, original_extension) SQLITE_FILE, File, original_directory, original_extension)
self.annotation_directory = annotation_directory
def __del__(self): self.annotation_extension = annotation_extension
"""Releases the opened file descriptor""" self.annotation_type = annotation_type
if self.session:
try:
# Since the dispose function re-creates a pool
# which might fail in some conditions, e.g., when this destructor is called during the exit of the python interpreter
self.session.close()
self.session.bind.dispose()
except Exception:
pass
def connect(self):
"""Tries connecting or re-connecting to the database"""
if not os.path.exists(SQLITE_FILE):
self.session = None
else:
self.session = utils.session_try_readonly(INFO.type(), SQLITE_FILE)
def is_valid(self):
"""Returns if a valid session has been opened for reading the database"""
return self.session is not None
def assert_validity(self):
"""Raise a RuntimeError if the database backend is not available"""
if not self.is_valid():
raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))
def objects(self, support=Attack.attack_support_choices, def objects(self, support=Attack.attack_support_choices,
protocol='grandtest', groups=Client.set_choices, cls=('attack', 'real'), protocol='grandtest', groups=Client.set_choices, cls=('attack', 'real'),
...@@ -157,10 +137,9 @@ class Database(Database): ...@@ -157,10 +137,9 @@ class Database(Database):
# now query the database # now query the database
retval = [] retval = []
from sqlalchemy.sql.expression import or_
# real-accesses are simpler to query # real-accesses are simpler to query
if 'enroll' in cls: if 'enroll' in cls:
q = self.session.query(File).join(RealAccess).join(Client) q = self.m_session.query(File).join(RealAccess).join(Client)
if groups: if groups:
q = q.filter(Client.set.in_(groups)) q = q.filter(Client.set.in_(groups))
if clients: if clients:
...@@ -175,7 +154,7 @@ class Database(Database): ...@@ -175,7 +154,7 @@ class Database(Database):
# real-accesses are simpler to query # real-accesses are simpler to query
if 'real' in cls: if 'real' in cls:
q = self.session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client) q = self.m_session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client)
if groups: if groups:
q = q.filter(Client.set.in_(groups)) q = q.filter(Client.set.in_(groups))
if clients: if clients:
...@@ -190,7 +169,7 @@ class Database(Database): ...@@ -190,7 +169,7 @@ class Database(Database):
# attacks will have to be filtered a little bit more # attacks will have to be filtered a little bit more
if 'attack' in cls: if 'attack' in cls:
q = self.session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client) q = self.m_session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client)
if groups: if groups:
q = q.filter(Client.set.in_(groups)) q = q.filter(Client.set.in_(groups))
if clients: if clients:
...@@ -207,6 +186,12 @@ class Database(Database): ...@@ -207,6 +186,12 @@ class Database(Database):
q = q.order_by(Client.id) q = q.order_by(Client.id)
retval += list(q) retval += list(q)
for f in retval:
f.original_directory = self.original_directory
f.original_extension = self.original_extension
f.annotation_directory = self.annotation_directory
f.annotation_extension = self.annotation_extension
f.annotation_type = self.annotation_type
return retval return retval
def files(self, directory=None, extension=None, **object_query): def files(self, directory=None, extension=None, **object_query):
...@@ -244,33 +229,33 @@ class Database(Database): ...@@ -244,33 +229,33 @@ class Database(Database):
"""Returns an iterable with all known clients""" """Returns an iterable with all known clients"""
self.assert_validity() self.assert_validity()
return list(self.session.query(Client)) return list(self.m_session.query(Client))
def has_client_id(self, id): def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier""" """Returns True if we have a client with a certain integer identifier"""
self.assert_validity() self.assert_validity()
return self.session.query(Client).filter(Client.id == id).count() != 0 return self.m_session.query(Client).filter(Client.id == id).count() != 0
def protocols(self): def protocols(self):
"""Returns all protocol objects. """Returns all protocol objects.
""" """
self.assert_validity() self.assert_validity()
return list(self.session.query(Protocol)) return list(self.m_session.query(Protocol))
def has_protocol(self, name): def has_protocol(self, name):
"""Tells if a certain protocol is available""" """Tells if a certain protocol is available"""
self.assert_validity() self.assert_validity()
return self.session.query(Protocol).filter(Protocol.name == name).count() != 0 return self.m_session.query(Protocol).filter(Protocol.name == name).count() != 0
def protocol(self, name): def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises """Returns the protocol object in the database given a certain name. Raises
an error if that does not exist.""" an error if that does not exist."""
self.assert_validity() self.assert_validity()
return self.session.query(Protocol).filter(Protocol.name == name).one() return self.m_session.query(Protocol).filter(Protocol.name == name).one()
def groups(self): def groups(self):
"""Returns the names of all registered groups""" """Returns the names of all registered groups"""
...@@ -330,7 +315,7 @@ class Database(Database): ...@@ -330,7 +315,7 @@ class Database(Database):
self.assert_validity() self.assert_validity()
fobj = self.session.query(File).filter(File.id.in_(ids)) fobj = self.m_session.query(File).filter(File.id.in_(ids))
retval = [] retval = []
for p in ids: for p in ids:
retval.extend([k.make_path(prefix, suffix) for k in fobj if k.id == p]) retval.extend([k.make_path(prefix, suffix) for k in fobj if k.id == p])
...@@ -350,7 +335,7 @@ class Database(Database): ...@@ -350,7 +335,7 @@ class Database(Database):
self.assert_validity() self.assert_validity()
fobj = self.session.query(File).filter(File.path.in_(paths)) fobj = self.m_session.query(File).filter(File.path.in_(paths))
for p in paths: for p in paths:
retval.extend([k.id for k in fobj if k.path == p]) retval.extend([k.id for k in fobj if k.path == p])
return retval return retval
...@@ -387,7 +372,7 @@ class Database(Database): ...@@ -387,7 +372,7 @@ class Database(Database):
self.assert_validity() self.assert_validity()
fobj = self.session.query(File).filter_by(id=id).one() fobj = self.m_session.query(File).filter_by(id=id).one()
fullpath = os.path.join(directory, str(fobj.path) + extension) fullpath = os.path.join(directory, str(fobj.path) + extension)
fulldir = os.path.dirname(fullpath) fulldir = os.path.dirname(fullpath)
......
...@@ -39,6 +39,7 @@ class File(BaseFile): ...@@ -39,6 +39,7 @@ class File(BaseFile):
def __init__(self, f, framen=None): def __init__(self, f, framen=None):
self._f = f self._f = f
self.framen = framen self.framen = framen
self.original_path = f.path
self.path = '{}_{:03d}'.format(f.path, framen) self.path = '{}_{:03d}'.format(f.path, framen)
self.client_id = f.client_id self.client_id = f.client_id
self.file_id = '{}_{}'.format(f.id, framen) self.file_id = '{}_{}'.format(f.id, framen)
...@@ -56,6 +57,10 @@ class File(BaseFile): ...@@ -56,6 +57,10 @@ class File(BaseFile):
else: else:
return super(File, self).load(directory, extension) return super(File, self).load(directory, extension)
@property
def annotations(self):
    """Annotations of the single frame this object refers to.

    Looks up ``str(self.framen)`` in the annotations of the wrapped file
    ``self._f``.
    """
    per_frame = self._f.annotations
    frame_key = str(self.framen)
    return per_frame[frame_key]
class Database(BaseDatabase): class Database(BaseDatabase):
""" """
...@@ -66,11 +71,25 @@ class Database(BaseDatabase): ...@@ -66,11 +71,25 @@ class Database(BaseDatabase):
""" """
__doc__ = __doc__ __doc__ = __doc__
def __init__(self, max_number_of_frames=None, original_directory=None, original_extension=None): def __init__(self,
super(Database, self).__init__(original_directory, original_extension) max_number_of_frames=None,
original_directory=None,
original_extension=None,
annotation_directory=None,
annotation_extension='.json',
annotation_type='json',
):
# call base class constructors to open a session to the database # call base class constructors to open a session to the database
self._db = LDatabase() self._db = LDatabase(
original_directory=original_directory,
original_extension=original_extension,
annotation_directory=annotation_directory,
annotation_extension=annotation_extension,
annotation_type=annotation_type,
)
super(Database, self).__init__(original_directory, original_extension)
self.max_number_of_frames = max_number_of_frames or 10 self.max_number_of_frames = max_number_of_frames or 10
# 240 is the guaranteed number of frames in replay mobile videos # 240 is the guaranteed number of frames in replay mobile videos
...@@ -78,6 +97,46 @@ class Database(BaseDatabase): ...@@ -78,6 +97,46 @@ class Database(BaseDatabase):
self.low_level_group_names = ('train', 'devel', 'test') self.low_level_group_names = ('train', 'devel', 'test')
self.high_level_group_names = ('world', 'dev', 'eval') self.high_level_group_names = ('world', 'dev', 'eval')
@property
def original_directory(self):
    """Read ``original_directory`` from the wrapped database ``self._db``."""
    backend = self._db
    return backend.original_directory

@original_directory.setter
def original_directory(self, value):
    """Write ``value`` to ``original_directory`` of ``self._db``."""
    backend = self._db
    backend.original_directory = value
@property
def original_extension(self):
    """Read ``original_extension`` from the wrapped database ``self._db``."""
    backend = self._db
    return backend.original_extension

@original_extension.setter
def original_extension(self, value):
    """Write ``value`` to ``original_extension`` of ``self._db``."""
    backend = self._db
    backend.original_extension = value
@property
def annotation_directory(self):
    """Read ``annotation_directory`` from the wrapped database ``self._db``."""
    backend = self._db
    return backend.annotation_directory

@annotation_directory.setter
def annotation_directory(self, value):
    """Write ``value`` to ``annotation_directory`` of ``self._db``."""
    backend = self._db
    backend.annotation_directory = value
@property
def annotation_extension(self):
    """Read ``annotation_extension`` from the wrapped database ``self._db``."""
    backend = self._db
    return backend.annotation_extension

@annotation_extension.setter
def annotation_extension(self, value):
    """Write ``value`` to ``annotation_extension`` of ``self._db``."""
    backend = self._db
    backend.annotation_extension = value
@property
def annotation_type(self):
    """Read ``annotation_type`` from the wrapped database ``self._db``."""
    backend = self._db
    return backend.annotation_type

@annotation_type.setter
def annotation_type(self, value):
    """Write ``value`` to ``annotation_type`` of ``self._db``."""
    backend = self._db
    backend.annotation_type = value
def protocol_names(self): def protocol_names(self):
"""Returns all registered protocol names """Returns all registered protocol names
Here I am going to hack and double the number of protocols Here I am going to hack and double the number of protocols
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment