Commit 29d9c92c authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Use the SQL class from bob.db.base to benefit from improvements

parent b2f19681
Pipeline #26229 passed with stage
in 15 minutes and 22 seconds
......@@ -9,6 +9,7 @@
import os
from sqlalchemy import Table, Column, Integer, String, ForeignKey
from bob.db.base.sqlalchemy_migration import Enum, relationship
import bob.db.base
import bob.db.base.utils
from sqlalchemy.orm import backref
from sqlalchemy.ext.declarative import declarative_base
......@@ -43,7 +44,7 @@ class Client(Base):
return "Client('%s', '%s')" % (, self.set)
class File(Base):
class File(Base, bob.db.base.File):
"""Generic file container"""
__tablename__ = 'file'
......@@ -71,33 +72,11 @@ class File(Base):
self.client = client
self.path = path
self.light = light
def __repr__(self):
return "File('%s')" % self.path
def make_path(self, directory=None, extension=None):
"""Wraps the current path so that a complete path is formed
Keyword parameters:
An optional directory name that will be prefixed to the returned result.
An optional extension that will be suffixed to the returned filename. The
extension normally includes the leading ``.`` character as in ``.jpg`` or
Returns a string containing the newly generated file path.
if not directory:
directory = ''
if not extension:
extension = ''
return str(os.path.join(directory, self.path + extension))
def videofile(self, directory=None):
"""Returns the path to the database video file for this object
......@@ -199,27 +178,6 @@ class File(Base):
return vin, extension))
def save(self, data, directory=None, extension='.hdf5'):
"""Saves the input data at the specified location and using the given
Keyword parameters:
The data blob to be saved (normally a :py:class:`numpy.ndarray`).
[optional] If not empty or None, this directory is prefixed to the final
file destination
[optional] The extension of the filename - this will control the type of
output and the codec for saving the input blob.
path = self.make_path(directory, extension), path)
# Intermediate mapping from RealAccess's to Protocol's
realaccesses_protocols = Table('realaccesses_protocols', Base.metadata,
......@@ -8,7 +8,7 @@ replay attack database in the most obvious ways.
import os
from bob.db.base import utils, Database
from bob.db.base import utils, SQLiteDatabase
from .models import *
from .driver import Interface
......@@ -17,50 +17,17 @@ INFO = Interface()
SQLITE_FILE = INFO.files()[0]
class Database(Database):
class Database(SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
It provides many different ways to probe for the characteristics of the data
and for the data itself inside the database.
def __init__(self, original_directory=None, original_extension=None):
# opens a session to the database - keep it open until the end
super(Database, self).__init__(original_directory, original_extension)
def __del__(self):
"""Releases the opened file descriptor"""
if self.session:
# Since the dispose function re-creates a pool
# which might fail in some conditions, e.g., when this destructor is called during the exit of the python interpreter
except TypeError:
# ... I can just ignore the according exception...
except AttributeError:
def connect(self):
"""Tries connecting or re-connecting to the database"""
if not os.path.exists(SQLITE_FILE):
self.session = None
self.session = utils.session_try_readonly(INFO.type(), SQLITE_FILE)
def is_valid(self):
"""Returns if a valid session has been opened for reading the database"""
return self.session is not None
def assert_validity(self):
"""Raise a RuntimeError if the database backend is not available"""
if not self.is_valid():
raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (, SQLITE_FILE))
def __init__(self, original_directory=None, original_extension=None,
super(Database, self).__init__(
SQLITE_FILE, File, original_directory, original_extension, **kwargs)
def objects(self, support=Attack.attack_support_choices,
protocol='grandtest', groups=Client.set_choices, cls=('attack', 'real'),
......@@ -113,7 +80,8 @@ class Database(Database):
return check_validity((l,), obj, valid, default)
for k in l:
if k not in valid:
raise RuntimeError('Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid))
raise RuntimeError(
'Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid))
return l
# check if groups set are valid
......@@ -132,7 +100,8 @@ class Database(Database):
if not protocol:
protocol = 'grandtest' # default
VALID_PROTOCOLS = [ for k in self.protocols()]
protocol = check_validity(protocol, "protocol", VALID_PROTOCOLS, ('grandtest',))
protocol = check_validity(
protocol, "protocol", VALID_PROTOCOLS, ('grandtest',))
# checks client identity validity
VALID_CLIENTS = [ for k in self.clients()]
......@@ -147,7 +116,7 @@ class Database(Database):
# real-accesses are simpler to query
if 'enroll' in cls:
q = self.session.query(File).join(RealAccess).join(Client)
q = self.m_session.query(File).join(RealAccess).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -160,7 +129,8 @@ class Database(Database):
# real-accesses are simpler to query
if 'real' in cls:
q = self.session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client)
q = self.m_session.query(File).join(RealAccess).join(
(Protocol, RealAccess.protocols)).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -173,7 +143,8 @@ class Database(Database):
# attacks will have to be filtered a little bit more
if 'attack' in cls:
q = self.session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client)
q = self.m_session.query(File).join(Attack).join(
(Protocol, Attack.protocols)).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -215,7 +186,8 @@ class Database(Database):
import warnings
warnings.warn("The method Database.files() is deprecated, use Database.objects() for more powerful object retrieval", DeprecationWarning)
"The method Database.files() is deprecated, use Database.objects() for more powerful object retrieval", DeprecationWarning)
return dict([(, k.make_path(directory, extension)) for k in self.objects(**object_query)])
......@@ -223,33 +195,33 @@ class Database(Database):
"""Returns an iterable with all known clients"""
return list(self.session.query(Client))
return list(self.m_session.query(Client))
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
return self.session.query(Client).filter( == id).count() != 0
return self.m_session.query(Client).filter( == id).count() != 0
def protocols(self):
"""Returns all protocol objects.
return list(self.session.query(Protocol))
return list(self.m_session.query(Protocol))
def has_protocol(self, name):
"""Tells if a certain protocol is available"""
return self.session.query(Protocol).filter( == name).count() != 0
return self.m_session.query(Protocol).filter( == name).count() != 0
def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
return self.session.query(Protocol).filter( == name).one()
return self.m_session.query(Protocol).filter( == name).one()
def groups(self):
"""Returns the names of all registered groups"""
......@@ -304,7 +276,7 @@ class Database(Database):
fobj = self.session.query(File).filter(
fobj = self.m_session.query(File).filter(
retval = []
for p in ids:
retval.extend([k.make_path(prefix, suffix) for k in fobj if == p])
......@@ -324,7 +296,7 @@ class Database(Database):
fobj = self.session.query(File).filter(File.path.in_(paths))
fobj = self.m_session.query(File).filter(File.path.in_(paths))
for p in paths:
retval.extend([ for k in fobj if k.path == p])
return retval
......@@ -361,7 +333,7 @@ class Database(Database):
fobj = self.session.query(File).filter_by(id=id).one()
fobj = self.m_session.query(File).filter_by(id=id).one()
fullpath = os.path.join(directory, str(fobj.path) + extension)
fulldir = os.path.dirname(fullpath)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment