diff --git a/bob/db/replaymobile/models.py b/bob/db/replaymobile/models.py index c623b815b71679ba80e9cc6d6eac96a829929b44..c36038018de59533469ad1d49d182f462c3b2437 100644 --- a/bob/db/replaymobile/models.py +++ b/bob/db/replaymobile/models.py @@ -13,11 +13,57 @@ import numpy import bob.io.base import bob.io.video import bob.core +from bob.db.base.annotations import read_annotation_file +from bob.db.base import File as BaseFile +from bob.io.video import reader logger = bob.core.log.setup('bob.db.replaymobile') Base = declarative_base() +REPLAYMOBILE_FRAME_SHAPE = (3, 1280, 720) +flip_file_list = ['client008_session02_authenticate_tablet_adverse', 'client008_session02_authenticate_tablet_controlled'] +flip_client_list = [8] + + +def replaymobile_annotations(lowlevelfile, original_directory): + # numpy array containing the face bounding box data for each video + # frame, returned data format described in the f.bbx() method of the + # low level interface + annots = lowlevelfile.bbx(directory=original_directory) + + annotations = {} # dictionary to return + + for fn, frame_annots in enumerate(annots): + + topleft = (frame_annots[1], frame_annots[0]) + bottomright = (frame_annots[1] + frame_annots[3], + frame_annots[0] + frame_annots[2]) + + annotations[str(fn)] = { + 'topleft': topleft, + 'bottomright': bottomright + } + + return annotations + + +def replaymobile_frames(lowlevelfile, original_directory): + vfilename = lowlevelfile.make_path( + directory=original_directory, + extension='.mov') + should_flip = not lowlevelfile.is_tablet() + if not should_flip: + if lowlevelfile.client.id in flip_client_list: + for mfn in flip_file_list: + if mfn in lowlevelfile.path: + should_flip = True + for frame in reader(vfilename): + frame = numpy.rollaxis(frame, 2, 1) + if should_flip: + frame = frame[:, ::-1, :] + yield frame + class Client(Base): """Database clients, marked by an integer identifier and the set they belong @@ -42,7 +88,7 @@ class Client(Base): return "Client('%s', '%s')" % (self.id, self.set) -class File(Base): +class File(Base, BaseFile): """Generic file container""" __tablename__ = 'file' @@ -81,29 +127,6 @@ class File(Base): def __repr__(self): return "File('%s')" % self.path - def make_path(self, directory=None, extension=None): - """Wraps the current path so that a complete path is formed - - Keyword parameters: - - directory - An optional directory name that will be prefixed to the returned result. - - extension - An optional extension that will be suffixed to the returned filename. The - extension normally includes the leading ``.`` character as in ``.jpg`` or - ``.hdf5``. - - Returns a string containing the newly generated file path. - """ - - if not directory: - directory = '' - if not extension: - extension = '' - - return str(os.path.join(directory, self.path + extension)) - def videofile(self, directory=None): """Returns the path to the database video file for this object @@ -159,7 +182,16 @@ class File(Base): Note that **not** all the frames may contain detected faces. """ - return numpy.loadtxt(self.facefile(directory), dtype=int) + bbx = numpy.loadtxt(self.facefile(directory), dtype=int) + if self.client.id in flip_client_list: + if self.is_tablet(): + logger.debug(self.path) + for mfn in flip_file_list: + if mfn in self.path: + logger.debug('flipping bbx') + for i in range(bbx.shape[0]): + bbx[i][1] = 1280 - (bbx[i][1] + bbx[i][3]) # correct the y-coord. of the top-left corner of bbx in this frame. + return bbx def is_real(self): """Returns True if this file belongs to a real access, False otherwise""" @@ -192,8 +224,7 @@ class File(Base): raise RuntimeError("%s is not an attack" % self) return self.attack[0] - # def load(self, directory=None, extension='.hdf5'): - def load(self, directory=None, extension=None): + def load(self, directory=None, extension='.mov'): """Loads the data at the specified location and using the given extension. Keyword parameters: @@ -210,52 +241,58 @@ class File(Base): output and the codec for saving the input blob. """ logger.debug('video file extension: {}'.format(extension)) - if extension is None: - extension = '.mov' - # if self.get_quality() == 'laptop': - # extension = '.mov' - # else: - # extension = '.mp4' + + directory = directory or self.original_directory + extension = extension or self.original_extension if extension == '.mov' or extension == '.mp4': - vfilename = self.make_path(directory, extension) - video = bob.io.video.reader(vfilename) - vin = video.load() + vfilename = self.make_path(directory, extension) + video = bob.io.video.reader(vfilename) + vin = video.load() else: - vin = bob.io.base.load(self.make_path(directory, extension)) + vin = bob.io.base.load(self.make_path(directory, extension)) vin = numpy.rollaxis(vin, 3, 2) if not self.is_tablet(): - logger.debug('flipping mobile video') - vin = vin[:, :, ::-1, :] - - # if self.is_rotated(): - # vin = vin[:, :, ::-1,:] + logger.debug('flipping mobile video') + vin = vin[:, :, ::-1, :] + else: + if self.client.id in flip_client_list: + for mfn in flip_file_list: + if mfn in self.path: + logger.debug('flipping tablet video') + vin = vin[:, :, ::-1, :] return vin - # return bob.io.base.load(self.make_path(directory, extension)) - - def save(self, data, directory=None, extension='.hdf5'): - """Saves the input data at the specified location and using the given - extension. - - Keyword parameters: - - data - The data blob to be saved (normally a :py:class:`numpy.ndarray`). - - directory - [optional] If not empty or None, this directory is prefixed to the final - file destination - - extension - [optional] The extension of the filename - this will control the type of - output and the codec for saving the input blob. - """ - path = self.make_path(directory, extension) - bob.io.base.create_directories_safe(os.path.dirname(path)) - bob.io.base.save(data, path) + @property + def annotations(self): + if hasattr(self, 'annotation_directory') and \ + self.annotation_directory is not None: + # return the external annotations + annotations = read_annotation_file( + os.path.join(self.annotation_directory, + self.path + self.annotation_extension), + self.annotation_type) + return annotations + + # return original annotations + return replaymobile_annotations(self, self.original_directory) + + @property + def frames(self): + return replaymobile_frames(self, self.original_directory) + + @property + def number_of_frames(self): + vfilename = self.make_path( + directory=self.original_directory, + extension='.mov') + return reader(vfilename).number_of_frames + + @property + def frame_shape(self): + return REPLAYMOBILE_FRAME_SHAPE # Intermediate mapping from RealAccess's to Protocol's diff --git a/bob/db/replaymobile/query.py b/bob/db/replaymobile/query.py index 96c75443ad0ba179a51bfdfa4a43dbab3e7c0d42..2aa0e7baf3008601d4fc503122beb1771ab50738 100644 --- a/bob/db/replaymobile/query.py +++ b/bob/db/replaymobile/query.py @@ -6,7 +6,8 @@ replay mobile database in the most obvious ways. """ import os -from bob.db.base import utils, Database +from bob.db.base import utils, SQLiteDatabase +from bob.extension import rc from .models import * from .driver import Interface @@ -15,47 +16,26 @@ INFO = Interface() SQLITE_FILE = INFO.files()[0] -class Database(Database): +class Database(SQLiteDatabase): """The dataset class opens and maintains a connection opened to the Database. It provides many different ways to probe for the characteristics of the data and for the data itself inside the database. """ - def __init__(self, original_directory=None, original_extension=None): + def __init__(self, + original_directory=rc['bob.db.replaymobile.directory'], + original_extension='.mov', + annotation_directory=None, + annotation_extension='.json', + annotation_type='json', + ): # opens a session to the database - keep it open until the end - self.connect() - super(Database, self).__init__(original_directory, original_extension) - - def __del__(self): - """Releases the opened file descriptor""" - if self.session: - try: - # Since the dispose function re-creates a pool - # which might fail in some conditions, e.g., when this destructor is called during the exit of the python interpreter - self.session.close() - self.session.bind.dispose() - except Exception: - pass - - def connect(self): - """Tries connecting or re-connecting to the database""" - if not os.path.exists(SQLITE_FILE): - self.session = None - - else: - self.session = utils.session_try_readonly(INFO.type(), SQLITE_FILE) - - def is_valid(self): - """Returns if a valid session has been opened for reading the database""" - - return self.session is not None - - def assert_validity(self): - """Raise a RuntimeError if the database backend is not available""" - - if not self.is_valid(): - raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE)) + super(Database, self).__init__( + SQLITE_FILE, File, original_directory, original_extension) + self.annotation_directory = annotation_directory + self.annotation_extension = annotation_extension + self.annotation_type = annotation_type def objects(self, support=Attack.attack_support_choices, protocol='grandtest', groups=Client.set_choices, cls=('attack', 'real'), @@ -157,10 +137,9 @@ class Database(Database): # now query the database retval = [] - from sqlalchemy.sql.expression import or_ # real-accesses are simpler to query if 'enroll' in cls: - q = self.session.query(File).join(RealAccess).join(Client) + q = self.m_session.query(File).join(RealAccess).join(Client) if groups: q = q.filter(Client.set.in_(groups)) if clients: @@ -175,7 +154,7 @@ class Database(Database): # real-accesses are simpler to query if 'real' in cls: - q = self.session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client) + q = self.m_session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client) if groups: q = q.filter(Client.set.in_(groups)) if clients: @@ -190,7 +169,7 @@ class Database(Database): # attacks will have to be filtered a little bit more if 'attack' in cls: - q = self.session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client) + q = self.m_session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client) if groups: q = q.filter(Client.set.in_(groups)) if clients: @@ -207,6 +186,12 @@ class Database(Database): q = q.order_by(Client.id) retval += list(q) + for f in retval: + f.original_directory = self.original_directory + f.original_extension = self.original_extension + f.annotation_directory = self.annotation_directory + f.annotation_extension = self.annotation_extension + f.annotation_type = self.annotation_type return retval def files(self, directory=None, extension=None, **object_query): @@ -244,33 +229,33 @@ class Database(Database): """Returns an iterable with all known clients""" self.assert_validity() - return list(self.session.query(Client)) + return list(self.m_session.query(Client)) def has_client_id(self, id): """Returns True if we have a client with a certain integer identifier""" self.assert_validity() - return self.session.query(Client).filter(Client.id == id).count() != 0 + return self.m_session.query(Client).filter(Client.id == id).count() != 0 def protocols(self): """Returns all protocol objects. """ self.assert_validity() - return list(self.session.query(Protocol)) + return list(self.m_session.query(Protocol)) def has_protocol(self, name): """Tells if a certain protocol is available""" self.assert_validity() - return self.session.query(Protocol).filter(Protocol.name == name).count() != 0 + return self.m_session.query(Protocol).filter(Protocol.name == name).count() != 0 def protocol(self, name): """Returns the protocol object in the database given a certain name. Raises an error if that does not exist.""" self.assert_validity() - return self.session.query(Protocol).filter(Protocol.name == name).one() + return self.m_session.query(Protocol).filter(Protocol.name == name).one() def groups(self): """Returns the names of all registered groups""" @@ -330,7 +315,7 @@ class Database(Database): self.assert_validity() - fobj = self.session.query(File).filter(File.id.in_(ids)) + fobj = self.m_session.query(File).filter(File.id.in_(ids)) retval = [] for p in ids: retval.extend([k.make_path(prefix, suffix) for k in fobj if k.id == p]) @@ -350,7 +335,7 @@ class Database(Database): self.assert_validity() - fobj = self.session.query(File).filter(File.path.in_(paths)) + fobj = self.m_session.query(File).filter(File.path.in_(paths)) for p in paths: retval.extend([k.id for k in fobj if k.path == p]) return retval @@ -387,7 +372,7 @@ class Database(Database): self.assert_validity() - fobj = self.session.query(File).filter_by(id=id).one() + fobj = self.m_session.query(File).filter_by(id=id).one() fullpath = os.path.join(directory, str(fobj.path) + extension) fulldir = os.path.dirname(fullpath) diff --git a/bob/db/replaymobile/verificationprotocol.py b/bob/db/replaymobile/verificationprotocol.py index 979c8ebbc92f1d300439cb7e58c810d54fce96d8..327c7112ca1feb73a17fb216aafe0d3d2a825418 100644 --- a/bob/db/replaymobile/verificationprotocol.py +++ b/bob/db/replaymobile/verificationprotocol.py @@ -39,6 +39,7 @@ class File(BaseFile): def __init__(self, f, framen=None): self._f = f self.framen = framen + self.original_path = f.path self.path = '{}_{:03d}'.format(f.path, framen) self.client_id = f.client_id self.file_id = '{}_{}'.format(f.id, framen) @@ -56,6 +57,10 @@ class File(BaseFile): else: return super(File, self).load(directory, extension) + @property + def annotations(self): + return self._f.annotations[str(self.framen)] + class Database(BaseDatabase): """ @@ -66,11 +71,25 @@ class Database(BaseDatabase): """ __doc__ = __doc__ - def __init__(self, max_number_of_frames=None, original_directory=None, original_extension=None): - super(Database, self).__init__(original_directory, original_extension) + def __init__(self, + max_number_of_frames=None, + original_directory=None, + original_extension=None, + annotation_directory=None, + annotation_extension='.json', + annotation_type='json', + ): # call base class constructors to open a session to the database - self._db = LDatabase() + self._db = LDatabase( + original_directory=original_directory, + original_extension=original_extension, + annotation_directory=annotation_directory, + annotation_extension=annotation_extension, + annotation_type=annotation_type, + ) + + super(Database, self).__init__(original_directory, original_extension) self.max_number_of_frames = max_number_of_frames or 10 # 240 is the guaranteed number of frames in replay mobile videos @@ -78,6 +97,46 @@ class Database(BaseDatabase): self.low_level_group_names = ('train', 'devel', 'test') self.high_level_group_names = ('world', 'dev', 'eval') + @property + def original_directory(self): + return self._db.original_directory + + @original_directory.setter + def original_directory(self, value): + self._db.original_directory = value + + @property + def original_extension(self): + return self._db.original_extension + + @original_extension.setter + def original_extension(self, value): + self._db.original_extension = value + + @property + def annotation_directory(self): + return self._db.annotation_directory + + @annotation_directory.setter + def annotation_directory(self, value): + self._db.annotation_directory = value + + @property + def annotation_extension(self): + return self._db.annotation_extension + + @annotation_extension.setter + def annotation_extension(self, value): + self._db.annotation_extension = value + + @property + def annotation_type(self): + return self._db.annotation_type + + @annotation_type.setter + def annotation_type(self, value): + self._db.annotation_type = value + def protocol_names(self): """Returns all registered protocol names Here I am going to hack and double the number of protocols