Commit 5e3ef9e4 authored by Amir MOHAMMADI

Improve the database interface with frames and annotations

parent 101f35b0
Pipeline #28967 passed with stage
in 10 minutes and 4 seconds
......@@ -13,11 +13,50 @@ import numpy
import bob.io.base
import bob.io.video
import bob.core
from bob.db.base.annotations import read_annotation_file
from bob.db.base import File as BaseFile
from bob.io.video import reader
logger = bob.core.log.setup('bob.db.replaymobile')
Base = declarative_base()
REPLAYMOBILE_FRAME_SHAPE = (3, 1280, 720)
def replaymobile_annotations(lowlevelfile, original_directory):
    """Builds the per-frame face bounding-box annotations of one video.

    The bounding boxes are read with
    ``lowlevelfile.bbx(directory=original_directory)``; the format of the
    returned data is described in the ``bbx()`` method of the low level
    interface.

    Keyword parameters:

    lowlevelfile
      The low level file object; it must provide a ``bbx`` method.

    original_directory
      The directory handed over to ``bbx`` as its ``directory`` argument.

    Returns a dictionary whose keys are the frame indices as strings
    (``'0'``, ``'1'``, ...) and whose values are dictionaries with the keys
    ``'topleft'`` and ``'bottomright'``.
    """
    # one entry of bounding box data per video frame
    boxes = lowlevelfile.bbx(directory=original_directory)

    # NOTE(review): presumably each entry is (x, y, width, height), which
    # would make both corners (row, column) pairs -- confirm against bbx()
    return {
        str(index): {
            'topleft': (box[1], box[0]),
            'bottomright': (box[1] + box[3], box[0] + box[2]),
        }
        for index, box in enumerate(boxes)
    }
def replaymobile_frames(lowlevelfile, original_directory):
    """Yields the frames of one video, one at a time.

    The video path is built with ``lowlevelfile.make_path`` from
    ``original_directory`` and the ``.mov`` extension and is read frame by
    frame with ``reader``.

    Keyword parameters:

    lowlevelfile
      The low level file object; it must provide ``make_path`` and
      ``is_tablet``.

    original_directory
      The directory prefixed to the video path.

    Every yielded frame has its axis 2 moved in front of axis 1. When
    ``lowlevelfile.is_tablet()`` is false, the frame is additionally reversed
    along axis 1.
    """
    video_path = lowlevelfile.make_path(
        directory=original_directory, extension='.mov')
    # the device type does not change between frames, so ask only once
    needs_flip = not lowlevelfile.is_tablet()
    for raw_frame in reader(video_path):
        rotated = numpy.rollaxis(raw_frame, 2, 1)
        yield rotated[:, ::-1, :] if needs_flip else rotated
class Client(Base):
"""Database clients, marked by an integer identifier and the set they belong
......@@ -42,7 +81,7 @@ class Client(Base):
return "Client('%s', '%s')" % (self.id, self.set)
class File(Base):
class File(Base, BaseFile):
"""Generic file container"""
__tablename__ = 'file'
......@@ -81,29 +120,6 @@ class File(Base):
def __repr__(self):
return "File('%s')" % self.path
def make_path(self, directory=None, extension=None):
    """Forms a complete path out of this object's ``path``.

    Keyword parameters:

    directory
      An optional directory name that is prefixed to the returned result.

    extension
      An optional extension that is suffixed to the returned filename. The
      extension normally includes the leading ``.`` character as in ``.jpg``
      or ``.hdf5``.

    Returns a string containing the newly generated file path.
    """
    # any false value (``None`` or an empty string) means "nothing to add"
    prefix = directory if directory else ''
    suffix = extension if extension else ''
    return str(os.path.join(prefix, self.path + suffix))
def videofile(self, directory=None):
"""Returns the path to the database video file for this object
......@@ -192,8 +208,7 @@ class File(Base):
raise RuntimeError("%s is not an attack" % self)
return self.attack[0]
# def load(self, directory=None, extension='.hdf5'):
def load(self, directory=None, extension=None):
def load(self, directory=None, extension='.mov'):
"""Loads the data at the specified location and using the given extension.
Keyword parameters:
......@@ -210,52 +225,52 @@ class File(Base):
output and the codec for saving the input blob.
"""
logger.debug('video file extension: {}'.format(extension))
if extension is None:
extension = '.mov'
# if self.get_quality() == 'laptop':
# extension = '.mov'
# else:
# extension = '.mp4'
directory = directory or self.original_directory
extension = extension or self.original_extension
if extension == '.mov' or extension == '.mp4':
vfilename = self.make_path(directory, extension)
video = bob.io.video.reader(vfilename)
vin = video.load()
vfilename = self.make_path(directory, extension)
video = bob.io.video.reader(vfilename)
vin = video.load()
else:
vin = bob.io.base.load(self.make_path(directory, extension))
vin = bob.io.base.load(self.make_path(directory, extension))
vin = numpy.rollaxis(vin, 3, 2)
if not self.is_tablet():
logger.debug('flipping mobile video')
vin = vin[:, :, ::-1, :]
# if self.is_rotated():
# vin = vin[:, :, ::-1,:]
logger.debug('flipping mobile video')
vin = vin[:, :, ::-1, :]
return vin
# return bob.io.base.load(self.make_path(directory, extension))
def save(self, data, directory=None, extension='.hdf5'):
    """Writes ``data`` to the location derived from this object's path.

    Keyword parameters:

    data
      The data blob to be saved (normally a :py:class:`numpy.ndarray`).

    directory
      [optional] If not empty or None, this directory is prefixed to the
      final file destination

    extension
      [optional] The extension of the filename - this will control the type
      of output and the codec for saving the input blob.
    """
    destination = self.make_path(directory, extension)
    # the parent directory has to exist before the file can be written
    parent_directory = os.path.dirname(destination)
    bob.io.base.create_directories_safe(parent_directory)
    bob.io.base.save(data, destination)
@property
def annotations(self):
    """The annotations of this file.

    When the object has an ``annotation_directory`` attribute that is not
    ``None``, the annotations are read with ``read_annotation_file`` from
    ``annotation_directory`` joined with ``path + annotation_extension``,
    using ``annotation_type``. Otherwise the original annotations are built
    by ``replaymobile_annotations`` from ``original_directory``.
    """
    external_directory = getattr(self, 'annotation_directory', None)
    if external_directory is None:
        # no external annotations configured: return original annotations
        return replaymobile_annotations(self, self.original_directory)

    # return the external annotations
    annotation_path = os.path.join(
        external_directory, self.path + self.annotation_extension)
    return read_annotation_file(annotation_path, self.annotation_type)
@property
def frames(self):
    """The frames of this file, as produced by ``replaymobile_frames`` for
    this object and its ``original_directory``."""
    source_directory = self.original_directory
    return replaymobile_frames(self, source_directory)
@property
def number_of_frames(self):
    """The ``number_of_frames`` reported by ``reader`` for the ``.mov`` video
    of this file located under ``original_directory``."""
    video = reader(
        self.make_path(directory=self.original_directory, extension='.mov'))
    return video.number_of_frames
@property
def frame_shape(self):
    """The frame shape, taken from the module-level constant
    ``REPLAYMOBILE_FRAME_SHAPE``; the same value is returned for every file."""
    # NOTE(review): the constant presumably reads (channels, height, width)
    # -- confirm against the frames yielded by ``frames``
    shape = REPLAYMOBILE_FRAME_SHAPE
    return shape
# Intermediate mapping from RealAccess's to Protocol's
......
......@@ -6,7 +6,8 @@ replay mobile database in the most obvious ways.
"""
import os
from bob.db.base import utils, Database
from bob.db.base import utils, SQLiteDatabase
from bob.extension import rc
from .models import *
from .driver import Interface
......@@ -15,47 +16,26 @@ INFO = Interface()
SQLITE_FILE = INFO.files()[0]
class Database(Database):
class Database(SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
It provides many different ways to probe for the characteristics of the data
and for the data itself inside the database.
"""
def __init__(self, original_directory=None, original_extension=None):
def __init__(self,
original_directory=rc['bob.db.replaymobile.directory'],
original_extension='.mov',
annotation_directory=None,
annotation_extension='.json',
annotation_type='json',
):
# opens a session to the database - keep it open until the end
self.connect()
super(Database, self).__init__(original_directory, original_extension)
def __del__(self):
    """Releases the opened file descriptor"""
    if not self.session:
        return
    try:
        # Closing is best effort: the dispose function re-creates a pool,
        # which might fail in some conditions, e.g., when this destructor is
        # called during the exit of the python interpreter
        self.session.close()
        self.session.bind.dispose()
    except Exception:
        pass
def connect(self):
    """Tries connecting or re-connecting to the database.

    ``self.session`` is set to a read-only session on ``SQLITE_FILE`` when
    that file exists, and to ``None`` otherwise.
    """
    if os.path.exists(SQLITE_FILE):
        self.session = utils.session_try_readonly(INFO.type(), SQLITE_FILE)
    else:
        self.session = None
def is_valid(self):
    """Tells whether a session has been opened for reading the database,
    i.e. whether ``self.session`` is set to something other than ``None``."""
    session_missing = self.session is None
    return not session_missing
def assert_validity(self):
    """Makes sure the database backend is available.

    Raises a ``RuntimeError`` if ``is_valid()`` reports that it is not.
    """
    if self.is_valid():
        return
    message = "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE)
    raise RuntimeError(message)
super(Database, self).__init__(
SQLITE_FILE, File, original_directory, original_extension)
self.annotation_directory = annotation_directory
self.annotation_extension = annotation_extension
self.annotation_type = annotation_type
def objects(self, support=Attack.attack_support_choices,
protocol='grandtest', groups=Client.set_choices, cls=('attack', 'real'),
......@@ -157,10 +137,9 @@ class Database(Database):
# now query the database
retval = []
from sqlalchemy.sql.expression import or_
# real-accesses are simpler to query
if 'enroll' in cls:
q = self.session.query(File).join(RealAccess).join(Client)
q = self.m_session.query(File).join(RealAccess).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -175,7 +154,7 @@ class Database(Database):
# real-accesses are simpler to query
if 'real' in cls:
q = self.session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client)
q = self.m_session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -190,7 +169,7 @@ class Database(Database):
# attacks will have to be filtered a little bit more
if 'attack' in cls:
q = self.session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client)
q = self.m_session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -207,6 +186,12 @@ class Database(Database):
q = q.order_by(Client.id)
retval += list(q)
for f in retval:
f.original_directory = self.original_directory
f.original_extension = self.original_extension
f.annotation_directory = self.annotation_directory
f.annotation_extension = self.annotation_extension
f.annotation_type = self.annotation_type
return retval
def files(self, directory=None, extension=None, **object_query):
......@@ -244,33 +229,33 @@ class Database(Database):
"""Returns an iterable with all known clients"""
self.assert_validity()
return list(self.session.query(Client))
return list(self.m_session.query(Client))
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
self.assert_validity()
return self.session.query(Client).filter(Client.id == id).count() != 0
return self.m_session.query(Client).filter(Client.id == id).count() != 0
def protocols(self):
"""Returns all protocol objects.
"""
self.assert_validity()
return list(self.session.query(Protocol))
return list(self.m_session.query(Protocol))
def has_protocol(self, name):
"""Tells if a certain protocol is available"""
self.assert_validity()
return self.session.query(Protocol).filter(Protocol.name == name).count() != 0
return self.m_session.query(Protocol).filter(Protocol.name == name).count() != 0
def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
self.assert_validity()
return self.session.query(Protocol).filter(Protocol.name == name).one()
return self.m_session.query(Protocol).filter(Protocol.name == name).one()
def groups(self):
"""Returns the names of all registered groups"""
......@@ -330,7 +315,7 @@ class Database(Database):
self.assert_validity()
fobj = self.session.query(File).filter(File.id.in_(ids))
fobj = self.m_session.query(File).filter(File.id.in_(ids))
retval = []
for p in ids:
retval.extend([k.make_path(prefix, suffix) for k in fobj if k.id == p])
......@@ -350,7 +335,7 @@ class Database(Database):
self.assert_validity()
fobj = self.session.query(File).filter(File.path.in_(paths))
fobj = self.m_session.query(File).filter(File.path.in_(paths))
for p in paths:
retval.extend([k.id for k in fobj if k.path == p])
return retval
......@@ -387,7 +372,7 @@ class Database(Database):
self.assert_validity()
fobj = self.session.query(File).filter_by(id=id).one()
fobj = self.m_session.query(File).filter_by(id=id).one()
fullpath = os.path.join(directory, str(fobj.path) + extension)
fulldir = os.path.dirname(fullpath)
......
......@@ -39,6 +39,7 @@ class File(BaseFile):
def __init__(self, f, framen=None):
self._f = f
self.framen = framen
self.original_path = f.path
self.path = '{}_{:03d}'.format(f.path, framen)
self.client_id = f.client_id
self.file_id = '{}_{}'.format(f.id, framen)
......@@ -56,6 +57,10 @@ class File(BaseFile):
else:
return super(File, self).load(directory, extension)
@property
def annotations(self):
    """The annotations of this single frame: the entry stored under the key
    ``str(self.framen)`` in the annotations of the wrapped file ``self._f``."""
    per_frame = self._f.annotations
    frame_key = str(self.framen)
    return per_frame[frame_key]
class Database(BaseDatabase):
"""
......@@ -66,11 +71,25 @@ class Database(BaseDatabase):
"""
__doc__ = __doc__
def __init__(self, max_number_of_frames=None, original_directory=None, original_extension=None):
super(Database, self).__init__(original_directory, original_extension)
def __init__(self,
max_number_of_frames=None,
original_directory=None,
original_extension=None,
annotation_directory=None,
annotation_extension='.json',
annotation_type='json',
):
# call base class constructors to open a session to the database
self._db = LDatabase()
self._db = LDatabase(
original_directory=original_directory,
original_extension=original_extension,
annotation_directory=annotation_directory,
annotation_extension=annotation_extension,
annotation_type=annotation_type,
)
super(Database, self).__init__(original_directory, original_extension)
self.max_number_of_frames = max_number_of_frames or 10
# 240 is the guaranteed number of frames in replay mobile videos
......@@ -78,6 +97,46 @@ class Database(BaseDatabase):
self.low_level_group_names = ('train', 'devel', 'test')
self.high_level_group_names = ('world', 'dev', 'eval')
def _low_level_property(attribute):
    """Creates a property reading and writing ``attribute`` on ``self._db``."""

    def getter(self):
        return getattr(self._db, attribute)

    def setter(self, value):
        setattr(self._db, attribute, value)

    return property(getter, setter)


# These settings are not stored on this object: reading or assigning any of
# them is forwarded to the attribute of the same name on ``self._db``.
original_directory = _low_level_property('original_directory')
original_extension = _low_level_property('original_extension')
annotation_directory = _low_level_property('annotation_directory')
annotation_extension = _low_level_property('annotation_extension')
annotation_type = _low_level_property('annotation_type')

# the factory is only needed while the properties are being defined
del _low_level_property
def protocol_names(self):
"""Returns all registered protocol names
Here I am going to hack and double the number of protocols
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment