Commit 29d9c92c authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Use the SQL class from bob.db.base to benefit from improvements

parent b2f19681
Pipeline #26229 passed with stage
in 15 minutes and 22 seconds
......@@ -9,6 +9,7 @@
import os
from sqlalchemy import Table, Column, Integer, String, ForeignKey
from bob.db.base.sqlalchemy_migration import Enum, relationship
import bob.db.base
import bob.db.base.utils
from sqlalchemy.orm import backref
from sqlalchemy.ext.declarative import declarative_base
......@@ -43,7 +44,7 @@ class Client(Base):
return "Client('%s', '%s')" % (self.id, self.set)
class File(Base):
class File(Base, bob.db.base.File):
"""Generic file container"""
__tablename__ = 'file'
......@@ -71,33 +72,11 @@ class File(Base):
self.client = client
self.path = path
self.light = light
bob.db.base.File.__init__(path)
def __repr__(self):
return "File('%s')" % self.path
def make_path(self, directory=None, extension=None):
"""Wraps the current path so that a complete path is formed
Keyword parameters:
directory
An optional directory name that will be prefixed to the returned result.
extension
An optional extension that will be suffixed to the returned filename. The
extension normally includes the leading ``.`` character as in ``.jpg`` or
``.hdf5``.
Returns a string containing the newly generated file path.
"""
if not directory:
directory = ''
if not extension:
extension = ''
return str(os.path.join(directory, self.path + extension))
def videofile(self, directory=None):
"""Returns the path to the database video file for this object
......@@ -199,27 +178,6 @@ class File(Base):
return vin #bob.io.base.load(self.make_path(directory, extension))
def save(self, data, directory=None, extension='.hdf5'):
"""Saves the input data at the specified location and using the given
extension.
Keyword parameters:
data
The data blob to be saved (normally a :py:class:`numpy.ndarray`).
directory
[optional] If not empty or None, this directory is prefixed to the final
file destination
extension
[optional] The extension of the filename - this will control the type of
output and the codec for saving the input blob.
"""
path = self.make_path(directory, extension)
bob.io.base.create_directories_safe(os.path.dirname(path))
bob.io.base.save(data, path)
# Intermediate mapping from RealAccess's to Protocol's
realaccesses_protocols = Table('realaccesses_protocols', Base.metadata,
......
......@@ -8,7 +8,7 @@ replay attack database in the most obvious ways.
"""
import os
from bob.db.base import utils, Database
from bob.db.base import utils, SQLiteDatabase
from .models import *
from .driver import Interface
......@@ -17,50 +17,17 @@ INFO = Interface()
SQLITE_FILE = INFO.files()[0]
class Database(Database):
class Database(SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
It provides many different ways to probe for the characteristics of the data
and for the data itself inside the database.
"""
def __init__(self, original_directory=None, original_extension=None):
# opens a session to the database - keep it open until the end
self.connect()
super(Database, self).__init__(original_directory, original_extension)
def __del__(self):
"""Releases the opened file descriptor"""
if self.session:
try:
# Since the dispose function re-creates a pool
# which might fail in some conditions, e.g., when this destructor is called during the exit of the python interpreter
self.session.close()
self.session.bind.dispose()
except TypeError:
# ... I can just ignore the according exception...
pass
except AttributeError:
pass
def connect(self):
"""Tries connecting or re-connecting to the database"""
if not os.path.exists(SQLITE_FILE):
self.session = None
else:
self.session = utils.session_try_readonly(INFO.type(), SQLITE_FILE)
def is_valid(self):
"""Returns if a valid session has been opened for reading the database"""
return self.session is not None
def assert_validity(self):
"""Raise a RuntimeError if the database backend is not available"""
if not self.is_valid():
raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))
def __init__(self, original_directory=None, original_extension=None,
**kwargs):
super(Database, self).__init__(
SQLITE_FILE, File, original_directory, original_extension, **kwargs)
def objects(self, support=Attack.attack_support_choices,
protocol='grandtest', groups=Client.set_choices, cls=('attack', 'real'),
......@@ -113,7 +80,8 @@ class Database(Database):
return check_validity((l,), obj, valid, default)
for k in l:
if k not in valid:
raise RuntimeError('Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid))
raise RuntimeError(
'Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid))
return l
# check if groups set are valid
......@@ -132,7 +100,8 @@ class Database(Database):
if not protocol:
protocol = 'grandtest' # default
VALID_PROTOCOLS = [k.name for k in self.protocols()]
protocol = check_validity(protocol, "protocol", VALID_PROTOCOLS, ('grandtest',))
protocol = check_validity(
protocol, "protocol", VALID_PROTOCOLS, ('grandtest',))
# checks client identity validity
VALID_CLIENTS = [k.id for k in self.clients()]
......@@ -147,7 +116,7 @@ class Database(Database):
# real-accesses are simpler to query
if 'enroll' in cls:
q = self.session.query(File).join(RealAccess).join(Client)
q = self.m_session.query(File).join(RealAccess).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -160,7 +129,8 @@ class Database(Database):
# real-accesses are simpler to query
if 'real' in cls:
q = self.session.query(File).join(RealAccess).join((Protocol, RealAccess.protocols)).join(Client)
q = self.m_session.query(File).join(RealAccess).join(
(Protocol, RealAccess.protocols)).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -173,7 +143,8 @@ class Database(Database):
# attacks will have to be filtered a little bit more
if 'attack' in cls:
q = self.session.query(File).join(Attack).join((Protocol, Attack.protocols)).join(Client)
q = self.m_session.query(File).join(Attack).join(
(Protocol, Attack.protocols)).join(Client)
if groups:
q = q.filter(Client.set.in_(groups))
if clients:
......@@ -215,7 +186,8 @@ class Database(Database):
"""
import warnings
warnings.warn("The method Database.files() is deprecated, use Database.objects() for more powerful object retrieval", DeprecationWarning)
warnings.warn(
"The method Database.files() is deprecated, use Database.objects() for more powerful object retrieval", DeprecationWarning)
return dict([(k.id, k.make_path(directory, extension)) for k in self.objects(**object_query)])
......@@ -223,33 +195,33 @@ class Database(Database):
"""Returns an iterable with all known clients"""
self.assert_validity()
return list(self.session.query(Client))
return list(self.m_session.query(Client))
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
self.assert_validity()
return self.session.query(Client).filter(Client.id == id).count() != 0
return self.m_session.query(Client).filter(Client.id == id).count() != 0
def protocols(self):
"""Returns all protocol objects.
"""
self.assert_validity()
return list(self.session.query(Protocol))
return list(self.m_session.query(Protocol))
def has_protocol(self, name):
"""Tells if a certain protocol is available"""
self.assert_validity()
return self.session.query(Protocol).filter(Protocol.name == name).count() != 0
return self.m_session.query(Protocol).filter(Protocol.name == name).count() != 0
def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
self.assert_validity()
return self.session.query(Protocol).filter(Protocol.name == name).one()
return self.m_session.query(Protocol).filter(Protocol.name == name).one()
def groups(self):
"""Returns the names of all registered groups"""
......@@ -304,7 +276,7 @@ class Database(Database):
self.assert_validity()
fobj = self.session.query(File).filter(File.id.in_(ids))
fobj = self.m_session.query(File).filter(File.id.in_(ids))
retval = []
for p in ids:
retval.extend([k.make_path(prefix, suffix) for k in fobj if k.id == p])
......@@ -324,7 +296,7 @@ class Database(Database):
self.assert_validity()
fobj = self.session.query(File).filter(File.path.in_(paths))
fobj = self.m_session.query(File).filter(File.path.in_(paths))
for p in paths:
retval.extend([k.id for k in fobj if k.path == p])
return retval
......@@ -361,7 +333,7 @@ class Database(Database):
self.assert_validity()
fobj = self.session.query(File).filter_by(id=id).one()
fobj = self.m_session.query(File).filter_by(id=id).one()
fullpath = os.path.join(directory, str(fobj.path) + extension)
fulldir = os.path.dirname(fullpath)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment