Commit 52799b42 authored by Laurent EL SHAFEY

First attempt at using classes from models.py directly

parent b6df9c10
......@@ -6,6 +6,6 @@
"""
from .query import Database
from .models import Client, File, Protocol, ProtocolPurpose
__all__ = ['Database']
__all__ = dir()
This diff is collapsed.
......@@ -5,19 +5,35 @@
"""Table models and functionality for the BANCA database.
"""
from sqlalchemy import Column, Integer, String, ForeignKey, or_, and_
import os, numpy
import bob.db.utils
from sqlalchemy import Table, Column, Integer, String, ForeignKey, or_, and_
from bob.db.sqlalchemy_migration import Enum, relationship
from sqlalchemy.orm import backref
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
# Association table pairing subworld.id with client.id. It is passed as the
# `secondary` argument of the Subworld.clients relationship, so a client can be
# linked to subworlds without a foreign-key column on either mapped class.
subworld_client_association = Table('subworld_client_association', Base.metadata,
Column('subworld_id', Integer, ForeignKey('subworld.id')),
Column('client_id', Integer, ForeignKey('client.id')))
# Association table pairing protocolPurpose.id with file.id. It is passed as the
# `secondary` argument of the ProtocolPurpose.files relationship.
protocolPurpose_file_association = Table('protocolPurpose_file_association', Base.metadata,
Column('protocolPurpose_id', Integer, ForeignKey('protocolPurpose.id')),
Column('file_id', Integer, ForeignKey('file.id')))
class Client(Base):
"""Database clients, marked by an integer identifier and the group they belong to"""
__tablename__ = 'client'
# Key identifier for the client
id = Column(Integer, primary_key=True)
# Gender of the client
gender = Column(Enum('m','f'))
# Group to which the client belongs
sgroup = Column(Enum('g1','g2','world')) # do NOT use group (SQL keyword)
# Language spoken by the client
language = Column(Enum('en','fr','sp'))
def __init__(self, id, gender, group, language):
......@@ -27,25 +43,29 @@ class Client(Base):
self.language = language
def __repr__(self):
return "<Client('%d', '%s', '%s', '%s')>" % (self.id, self.gender, self.sgroup, self.language)
return "Client('%d', '%s', '%s', '%s')" % (self.id, self.gender, self.sgroup, self.language)
class Subworld(Base):
"""Database clients belonging to the world group are split in two disjoint subworlds,
onethird and twothirds"""
__tablename__ = 'subworld'
# Key identifier for this Subworld object
id = Column(Integer, primary_key=True)
name = Column(Enum('onethird','twothirds'))
client_id = Column(Integer, ForeignKey('client.id')) # for SQL
# Subworld to which the client belongs
name = Column(String(20), unique=True)
# for Python
real_client = relationship("Client", backref=backref("client_subworld"))
# for Python: A direct link to the client
clients = relationship("Client", secondary=subworld_client_association, backref=backref("subworld", order_by=id))
def __init__(self, name, client_id):
def __init__(self, name):
self.name = name
self.client_id = client_id
def __repr__(self):
print "<Subworld('%s', '%d')>" % (self.name, self.client_id)
return "Subworld('%s')" % (self.name)
"""
class Session(Base):
__tablename__ = 'session'
......@@ -58,33 +78,128 @@ class Session(Base):
def __repr__(self):
return "<Session('%d', '%s')>" % (self.id, self.scenario)
"""
class File(Base):
"""Generic file container"""
__tablename__ = 'file'
# Key identifier for the file
id = Column(Integer, primary_key=True)
# Key identifier of the client associated with this file
real_id = Column(Integer, ForeignKey('client.id')) # for SQL
# Unique path to this file inside the database
path = Column(String(100), unique=True)
claimed_id = Column(Integer) # not always the id of an existing client model
shot = Column(Integer)
session_id = Column(Integer, ForeignKey('session.id'))
# Identifier of the claimed client associated with this file
claimed_id = Column(Integer) # not always the id of an existing client model -> not a ForeignKey
# Identifier of the shot
shot_id = Column(Integer)
# Identifier of the session
session_id = Column(Integer)
# for Python
session = relationship("Session", backref=backref("session_file"))
real_client = relationship("Client", backref=backref("real_client_file"))
# For Python: A direct link to the client object that this file belongs to
real_client = relationship("Client", backref=backref("files", order_by=id))
def __init__(self, real_id, path, claimed_id, shot, session_id):
def __init__(self, real_id, path, claimed_id, shot_id, session_id):
self.real_id = real_id
self.path = path
self.claimed_id = claimed_id
self.shot = shot
self.shot_id = shot_id
self.session_id = session_id
def __repr__(self):
print "<File('%s')>" % self.path
return "File('%s')" % self.path
def make_path(self, directory=None, extension=None):
    """Build a complete file path from the path stored on this object.

    Keyword parameters:

    directory
      Optional directory that is prefixed to the result. ``None`` or any
      other false value adds no prefix.

    extension
      Optional suffix appended verbatim to the stored path. It normally
      includes the leading ``.`` character, as in ``.jpg`` or ``.hdf5``.
      ``None`` or any other false value adds nothing.

    Returns a string containing the newly generated file path.
    """
    # Collapse None (and other false values) to the empty string so that the
    # concatenation and the join below always operate on strings.
    prefix = directory or ''
    suffix = extension or ''
    return os.path.join(prefix, self.path + suffix)
def save(self, data, directory=None, extension='.hdf5'):
    """Write ``data`` to the location derived from this file's path.

    Keyword parameters:

    data
      The data blob to be saved (normally a :py:class:`numpy.ndarray`).

    directory
      If not empty or None, this directory is prefixed to the final file
      destination.

    extension
      The extension of the filename. It is appended to the stored path and
      handed, as part of the destination, to the saving routine.
    """
    destination = self.make_path(directory, extension)
    parent_dir = os.path.dirname(destination)
    # NOTE(review): the visible imports of this module only bring in
    # ``bob.db.utils``; whether ``bob.utils`` and ``bob.io`` resolve here
    # depends on what the ``bob`` package itself loads -- confirm.
    bob.utils.makedirs_safe(parent_dir)
    bob.io.save(data, destination)
class Protocol(Base):
    """A BANCA protocol, stored in the ``protocol`` table."""

    __tablename__ = 'protocol'

    # Primary key of the protocol row
    id = Column(Integer, primary_key=True)
    # Protocol name, declared as a unique String(20) column
    name = Column(String(20), unique=True)

    def __init__(self, name):
        """Create a protocol object carrying the given ``name``."""
        self.name = name

    def __repr__(self):
        """Return a constructor-like text form, e.g. ``Protocol('Mc')``."""
        template = "Protocol('%s')"
        return template % (self.name,)
class ProtocolPurpose(Base):
    """A (protocol, group, purpose) combination of the BANCA database.

    Each object is tied to one Protocol through ``protocol_id`` and to File
    objects through the ``protocolPurpose_file_association`` table.
    """

    __tablename__ = 'protocolPurpose'

    # Allowed values of the two enumerated columns below. They are kept as
    # class attributes so that they can be read from outside the class.
    group_choices = ('world', 'dev', 'eval')
    purpose_choices = ('train', 'enrol', 'probe')

    # Unique identifier of this protocol purpose object
    id = Column(Integer, primary_key=True)
    # Identifier of the protocol this object is attached to (for SQL)
    protocol_id = Column(Integer, ForeignKey('protocol.id'))
    # Group associated with this protocol purpose object
    sgroup = Column(Enum(*group_choices))
    # Purpose associated with this protocol purpose object
    purpose = Column(Enum(*purpose_choices))

    # For Python: direct link to the Protocol object this ProtocolPurpose
    # belongs to; the reverse side is exposed on Protocol as ``purposes``.
    protocol = relationship(
        "Protocol",
        backref=backref("purposes", order_by=id))
    # For Python: direct link to the File objects associated with this
    # ProtocolPurpose; the reverse side is exposed on File as
    # ``protocolPurposes``.
    files = relationship(
        "File",
        secondary=protocolPurpose_file_association,
        backref=backref("protocolPurposes", order_by=id))

    def __init__(self, protocol_id, sgroup, purpose):
        """Store the protocol identifier, the group and the purpose."""
        self.purpose = purpose
        self.sgroup = sgroup
        self.protocol_id = protocol_id

    def __repr__(self):
        """Return a text form built from the protocol name, group and purpose."""
        fields = (self.protocol.name, self.sgroup, self.purpose)
        return "ProtocolPurpose('%s', '%s', '%s')" % fields
"""
class Protocol(Base):
__tablename__ = 'protocol'
......@@ -103,3 +218,4 @@ class Protocol(Base):
def __repr__(self):
return "<Protocol('%d', '%s', '%s')>" % (self.session_id, self.name, self.purpose)
"""
......@@ -39,6 +39,12 @@ class Database(object):
return self.session is not None
def assert_validity(self):
    """Raise a RuntimeError if the database backend is not available.

    Availability is decided by ``self.is_valid()``. When it returns a true
    value this method does nothing and returns ``None``.

    Raises:
      RuntimeError: when ``self.is_valid()`` returns a false value. The
        message names the database (``INFO.name()``) and the location where
        it was expected (``SQLITE_FILE``).
    """
    if self.is_valid():
        return
    # Call-style raise: the ``raise E, "msg"`` statement form is a syntax
    # error on Python 3, whereas this form is accepted by Python 2 and 3.
    raise RuntimeError(
        "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()"
        % (INFO.name(), SQLITE_FILE))
def __group_replace_alias__(self, l):
"""Replace 'dev' by 'g1' and 'eval' by 'g2' in a list of groups, and
returns the new list"""
......@@ -60,6 +66,18 @@ class Database(object):
raise RuntimeError, 'Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid)
return l
def groups(self):
    """Returns the names of all registered groups.

    The names are read from ``ProtocolPurpose.group_choices`` and returned
    unchanged.
    """
    registered = ProtocolPurpose.group_choices
    return registered
def client_groups(self):
    """Returns the names of the client groups of the BANCA database.

    These are the group names stored on Client objects ('g1', 'g2' and
    'world'), as opposed to the protocol-level names returned by
    ``groups()``.
    """
    # The Client model, as declared in models.py, lists its groups only
    # inline in the ``sgroup`` Enum and exposes no ``group_choices``
    # attribute, so a plain ``Client.group_choices`` raises AttributeError.
    # Use the attribute when the model provides it and otherwise fall back to
    # the values of that Enum.
    # NOTE(review): assumes ``Client`` here is the class from models.py --
    # confirm against this module's imports.
    return getattr(Client, 'group_choices', ('g1', 'g2', 'world'))
def clients(self, protocol=None, groups=None, gender=None, language=None, subworld=None):
"""Returns a set of clients for the specific query by the user.
......@@ -88,6 +106,8 @@ class Database(object):
properties.
"""
self.assert_validity()
groups = self.__group_replace_alias__(groups)
VALID_GROUPS = ('g1', 'g2', 'world')
VALID_GENDERS = ('m', 'f')
......@@ -108,8 +128,7 @@ class Database(object):
q = q.filter(Client.gender.in_(gender)).\
filter(Client.language.in_(language)).\
order_by(Client.id)
for id in [k.id for k in q]:
retval.append(id)
retval += list(q)
if 'g1' in groups or 'g2' in groups:
q = self.session.query(Client).filter(Client.sgroup != 'world').\
......@@ -117,8 +136,7 @@ class Database(object):
filter(Client.gender.in_(gender)).\
filter(Client.language.in_(language)).\
order_by(Client.id)
for id in [k.id for k in q]:
retval.append(id)
retval += list(q)
return retval
......@@ -224,6 +242,11 @@ class Database(object):
return self.zclients(protocol, groups)
def has_client_id(self, id):
    """Tell whether a client with the given integer identifier exists.

    Returns True if at least one Client row has this ``id``, False
    otherwise. The database validity is asserted first.
    """
    self.assert_validity()
    matching_clients = self.session.query(Client).filter(Client.id == id)
    return matching_clients.count() != 0
def get_client_id_from_model_id(self, model_id):
"""Returns the client_id attached to the given model_id
......@@ -249,42 +272,6 @@ class Database(object):
"""
return model_id
def get_client_id_from_file_id(self, file_id):
    """Look up the client_id (real client id) attached to a file.

    Keyword Parameters:

    file_id
      The file_id to consider

    Returns: The ``real_id`` of the file with this identifier, or ``None``
    when the query does not match exactly one file.
    """
    matches = self.session.query(File).filter(File.id == file_id)
    if matches.count() == 1:
        return matches.first().real_id
    #throw exception?
    return None
def get_internal_path_from_file_id(self, file_id):
    """Look up the unique "internal path" attached to a file.

    Keyword Parameters:

    file_id
      The file_id to consider

    Returns: The ``path`` of the file with this identifier, or ``None`` when
    the query does not match exactly one file.
    """
    matches = self.session.query(File).filter(File.id == file_id)
    if matches.count() == 1:
        return matches.first().path
    #throw exception?
    return None
def objects(self, directory=None, extension=None, protocol=None,
purposes=None, model_ids=None, groups=None, classes=None,
languages=None, subworld=None):
......@@ -313,10 +300,9 @@ class Database(object):
the model_ids is performed.
groups
One of the groups ("g1", "g2", "world") or a tuple with several of them.
One of the groups ("dev", "eval", "world") or a tuple with several of them.
If 'None' is given (this is the default), it is considered the same as a
tuple with all possible values.
Note that 'dev' is an alias to 'g1' and 'test' an alias to 'g2'
classes
The classes (types of accesses) to be retrieved ('client', 'impostor')
......@@ -353,10 +339,11 @@ class Database(object):
if directory: return os.path.join(directory, stem + extension)
return stem + extension
groups = self.__group_replace_alias__(groups)
self.assert_validity()
VALID_PROTOCOLS = ('Mc', 'Md', 'Ma', 'Ud', 'Ua', 'P', 'G')
VALID_PURPOSES = ('enrol', 'probe')
VALID_GROUPS = ('g1', 'g2', 'world')
VALID_GROUPS = ('dev', 'eval', 'world')
VALID_LANGUAGES = ('en', 'fr', 'sp')
VALID_CLASSES = ('client', 'impostor')
VALID_SUBWORLDS = ('onethird', 'twothirds')
......@@ -368,63 +355,52 @@ class Database(object):
classes = self.__check_validity__(classes, "class", VALID_CLASSES)
subworld = self.__check_validity__(subworld, "subworld", VALID_SUBWORLDS)
retval = {}
retval = []
if(isinstance(model_ids,str)):
model_ids = (model_ids,)
if 'world' in groups:
q = self.session.query(File).join(Client).join(ProtocolPurpose, File.protocolPurposes).join(Protocol)
if len(subworld) == 1:
q = self.session.query(File).join(Client).join(Subworld).filter(Subworld.name.in_(subworld))
else:
q = self.session.query(File).join(Client)
q = q.join(Subworld).filter(Subworld.name.in_(subworld))
q = q.filter(Client.sgroup == 'world').\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup == 'world')).\
filter(Client.language.in_(languages))
if model_ids:
q = q.filter(File.real_id.in_(model_ids))
q = q.order_by(File.real_id, File.session_id, File.claimed_id, File.shot)
for k in q:
retval[k.id] = (make_path(k.path, directory, extension), k.claimed_id, k.claimed_id, k.real_id, k.path)
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.real_id, File.session_id, File.claimed_id, File.shot_id)
retval += list(q)
if ('g1' in groups or 'g2' in groups):
if ('dev' in groups or 'eval' in groups):
if('enrol' in purposes):
q = self.session.query(File).join(Client).join(Session).join(Protocol).\
filter(File.claimed_id == File.real_id).\
filter(Client.sgroup.in_(groups)).\
filter(Client.language.in_(languages)).\
filter(Protocol.name.in_(protocol)).\
filter(Protocol.purpose == 'enrol')
q = self.session.query(File).join(Client).join(ProtocolPurpose, File.protocolPurposes).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'enrol'))
if model_ids:
q = q.filter(File.claimed_id.in_(model_ids))
q = q.order_by(File.claimed_id, File.session_id, File.real_id, File.shot)
for k in q:
retval[k.id] = (make_path(k.path, directory, extension), k.claimed_id, k.claimed_id, k.real_id, k.path)
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.real_id, File.session_id, File.claimed_id, File.shot_id)
retval += list(q)
if('probe' in purposes):
if('client' in classes):
q = self.session.query(File).join(Client).join(Session).join(Protocol).\
filter(File.claimed_id == File.real_id).\
filter(Client.sgroup.in_(groups)).\
filter(Client.language.in_(languages)).\
filter(Protocol.name.in_(protocol)).\
filter(Protocol.purpose == 'probe')
q = self.session.query(File).join(Client).join(ProtocolPurpose, File.protocolPurposes).join(Protocol).\
filter(File.real_id == File.claimed_id).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
if model_ids:
q = q.filter(File.claimed_id.in_(model_ids))
q = q.order_by(File.claimed_id, File.session_id, File.real_id, File.shot)
for k in q:
retval[k.id] = (make_path(k.path, directory, extension), k.claimed_id, k.claimed_id, k.real_id, k.path)
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.real_id, File.session_id, File.claimed_id, File.shot_id)
retval += list(q)
if('impostor' in classes):
q = self.session.query(File).join(Client).join(Session).join(Protocol).\
filter(File.claimed_id != File.real_id).\
filter(Client.sgroup.in_(groups)).\
filter(Client.language.in_(languages)).\
filter(Protocol.name.in_(protocol)).\
filter(or_(Protocol.purpose == 'probe', Protocol.purpose == 'probeImpostor'))
q = self.session.query(File).join(Client).join(ProtocolPurpose, File.protocolPurposes).join(Protocol).\
filter(File.real_id != File.claimed_id).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
if model_ids:
q = q.filter(File.claimed_id.in_(model_ids))
for k in q:
retval[k.id] = (make_path(k.path, directory, extension), k.claimed_id, k.claimed_id, k.real_id, k.path)
return retval
q = q.order_by(File.real_id, File.session_id, File.claimed_id, File.shot_id)
retval += list(q)
return list(set(retval)) # To remove duplicates
def files(self, directory=None, extension=None, protocol=None,
purposes=None, model_ids=None, groups=None, classes=None,
......@@ -587,7 +563,7 @@ class Database(object):
numbers if you wish to save processing results later on.
"""
retval = {}
retval = []
d = self.tobjects(directory, extension, protocol, model_ids, groups, languages)
for k in d: retval[k] = d[k][0]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment