Commit 52799b42 authored by Laurent EL SHAFEY
Browse files

First attempt at using classes from models.py directly

parent b6df9c10
......@@ -6,6 +6,6 @@
"""
from .query import Database
from .models import Client, File, Protocol, ProtocolPurpose
__all__ = ['Database']
__all__ = dir()
......@@ -10,10 +10,10 @@ import os
from .models import *
def add_files(session, imagedir):
def add_files(session, imagedir, verbose):
"""Add files (and clients) to the BANCA database."""
def add_file(session, filename, client_dict):
def add_file(session, filename, client_dict, verbose):
"""Parse a single filename and add it to the list.
Also add a client entry if not already in the database."""
......@@ -24,51 +24,36 @@ def add_files(session, imagedir):
session.add(Client(int(v[0]), v[1], v[2], v[5]))
client_dict[v[0]] = True
session_id = int(v[3].split('s')[1])
if verbose: print "Adding file '%s'..." %(os.path.basename(filename).split('.')[0], )
session.add(File(int(v[0]), os.path.basename(filename).split('.')[0], v[4], v[6], session_id))
file_list = os.listdir(imagedir)
client_dict = {}
for filename in file_list:
add_file(session, os.path.join(imagedir, filename), client_dict)
add_file(session, os.path.join(imagedir, filename), client_dict, verbose)
def add_subworlds(session, verbose):
    """Adds splits in the world set, based on the client ids.

    Two subworlds, ``onethird`` and ``twothirds``, are created. Each one is
    stored as a single ``Subworld`` row and its member clients are attached
    through the ``clients`` relationship.

    Keyword parameters:

    session
      The database session to which the new rows are added.

    verbose
      If true, prints a progress message for each subworld and each client.

    Raises a ``RuntimeError`` if one of the listed client ids cannot be found
    in the database (files and clients must be added before subworlds).
    """

    # Client identifiers composing each subworld, in insertion order
    subworlds = [
        ("onethird", [9003, 9005, 9027, 9033, 9035, 9043, 9049, 9053, 9055, 9057]),
        ("twothirds", [9001, 9007, 9009, 9011, 9013, 9015, 9017, 9019, 9021, 9023,
                       9025, 9029, 9031, 9037, 9039, 9041, 9045, 9047, 9051, 9059]),
    ]

    for name, client_ids in subworlds:
        if verbose: print("Adding subworld '%s'" % (name,))
        su = Subworld(name)
        session.add(su)
        # Push the row to the database and reload it before attaching clients
        session.flush()
        session.refresh(su)
        for c_id in client_ids:
            if verbose: print("Adding client '%d' to subworld '%s'..." % (c_id, name))
            client = session.query(Client).filter(Client.id == c_id).first()
            if client is None:
                # Appending None to the relationship would only fail later, at flush time
                raise RuntimeError("Client '%d' of subworld '%s' is not in the database" % (c_id, name))
            su.clients.append(client)
"""
def add_sessions(session, verbose):
""Adds relations between sessions and scenarios""
for i in range(1,5):
session.add(Session(i,'controlled'))
......@@ -76,75 +61,117 @@ def add_sessions(session):
session.add(Session(i,'degraded'))
for i in range(9,13):
session.add(Session(i,'adverse'))
"""
def _append_protocol_files(session, purpose, client_group, session_ids, impostors, verbose):
    """Attaches to ``purpose`` the files of ``client_group`` recorded in ``session_ids``.

    When ``impostors`` is false, only files whose claimed identity equals the
    real identity are selected; when it is true, only files where the two
    identities differ are selected. Files are appended session by session,
    ordered by file id within each session.
    """
    if impostors:
        identity = File.real_id != File.claimed_id
    else:
        identity = File.real_id == File.claimed_id

    for sid in session_ids:
        q = session.query(File).join(Client).\
            filter(Client.sgroup == client_group).\
            filter(and_(File.session_id == sid, identity)).\
            order_by(File.id)
        for f in q:
            if verbose: print(" Adding protocol file '%s'..." % (f.path,))
            purpose.files.append(f)


def add_protocols(session, verbose):
    """Adds protocols, their purposes and the files attached to each purpose.

    For every protocol, five ``ProtocolPurpose`` rows are created:
    ``('world', 'train')``, ``('dev', 'enrol')``, ``('dev', 'probe')``,
    ``('eval', 'enrol')`` and ``('eval', 'probe')``. The ``dev`` group uses the
    clients of group ``g1`` and the ``eval`` group those of group ``g2``.

    Keyword parameters:

    session
      The database session to which the new rows are added.

    verbose
      If true, prints a progress message for each protocol, purpose and file.
    """

    # 1. DEFINITIONS
    # Numbers in the lists correspond to session identifiers. Each entry is
    # [enrol sessions, client probe sessions, impostor probe sessions].
    protocol_definitions = {}
    protocol_definitions['Mc'] = [[1], [2, 3, 4], [1, 2, 3, 4]]
    protocol_definitions['Md'] = [[5], [6, 7, 8], [5, 6, 7, 8]]
    protocol_definitions['Ma'] = [[9], [10, 11, 12], [9, 10, 11, 12]]
    protocol_definitions['Ud'] = [[1], [6, 7, 8], [5, 6, 7, 8]]
    protocol_definitions['Ua'] = [[1], [10, 11, 12], [9, 10, 11, 12]]
    protocol_definitions['P'] = [[1],
                                 [2, 3, 4, 6, 7, 8, 10, 11, 12],
                                 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]
    protocol_definitions['G'] = [[1, 5, 9],
                                 [2, 3, 4, 6, 7, 8, 10, 11, 12],
                                 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]

    # 2. ADDITIONS TO THE SQL DATABASE
    protocolPurpose_list = [('world', 'train'), ('dev', 'enrol'), ('dev', 'probe'),
                            ('eval', 'enrol'), ('eval', 'probe')]
    # Client group whose files are used for each protocol group
    client_group_of = {'world': 'world', 'dev': 'g1', 'eval': 'g2'}

    for proto in protocol_definitions:
        enrol, probe_c, probe_i = protocol_definitions[proto]

        # Add protocol
        p = Protocol(proto)
        if verbose: print("Adding protocol %s..." % (proto,))
        session.add(p)
        # flush/refresh so that p.id is available for the purposes below
        session.flush()
        session.refresh(p)

        # Add protocol purposes
        for group, purpose in protocolPurpose_list:
            pu = ProtocolPurpose(p.id, group, purpose)
            if verbose: print(" Adding protocol purpose ('%s','%s')..." % (group, purpose))
            session.add(pu)
            session.flush()
            session.refresh(pu)

            # Add files attached with this protocol purpose
            client_group = client_group_of[group]
            if purpose == 'train':
                # Training uses every file of the world clients, whatever the session
                q = session.query(File).join(Client).\
                    filter(Client.sgroup == client_group).\
                    order_by(File.id)
                for f in q:
                    if verbose: print(" Adding protocol file '%s'..." % (f.path,))
                    pu.files.append(f)
            elif purpose == 'enrol':
                _append_protocol_files(session, pu, client_group, enrol, False, verbose)
            else:
                # Probes: true client accesses first, then impostor accesses
                _append_protocol_files(session, pu, client_group, probe_c, False, verbose)
                _append_protocol_files(session, pu, client_group, probe_i, True, verbose)
def create_tables(args):
......@@ -153,11 +180,7 @@ def create_tables(args):
from bob.db.utils import create_engine_try_nolock
engine = create_engine_try_nolock(args.type, args.files[0], echo=(args.verbose >= 2))
File.metadata.create_all(engine)
Client.metadata.create_all(engine)
Subworld.metadata.create_all(engine)
Session.metadata.create_all(engine)
Protocol.metadata.create_all(engine)
Base.metadata.create_all(engine)
# Driver API
# ==========
......@@ -180,10 +203,10 @@ def create(args):
# the real work...
create_tables(args)
s = session_try_nolock(args.type, args.files[0], echo=(args.verbose >= 2))
add_files(s, args.imagedir)
add_subworlds(s)
add_sessions(s)
add_protocols(s)
add_files(s, args.imagedir, args.verbose)
add_subworlds(s, args.verbose)
#add_sessions(s, args.verbose)
add_protocols(s, args.verbose)
s.commit()
s.close()
......
......@@ -5,19 +5,35 @@
"""Table models and functionality for the BANCA database.
"""
from sqlalchemy import Column, Integer, String, ForeignKey, or_, and_
import os, numpy
import bob.db.utils
from sqlalchemy import Table, Column, Integer, String, ForeignKey, or_, and_
from bob.db.sqlalchemy_migration import Enum, relationship
from sqlalchemy.orm import backref
from sqlalchemy.ext.declarative import declarative_base
# Declarative base class shared by all the table models of this module
Base = declarative_base()
# Many-to-many association table linking rows of 'subworld' to rows of 'client'
subworld_client_association = Table('subworld_client_association', Base.metadata,
Column('subworld_id', Integer, ForeignKey('subworld.id')),
Column('client_id', Integer, ForeignKey('client.id')))
# Many-to-many association table linking rows of 'protocolPurpose' to rows of 'file'
protocolPurpose_file_association = Table('protocolPurpose_file_association', Base.metadata,
Column('protocolPurpose_id', Integer, ForeignKey('protocolPurpose.id')),
Column('file_id', Integer, ForeignKey('file.id')))
class Client(Base):
"""Database clients, marked by an integer identifier and the group they belong to"""
__tablename__ = 'client'
# Key identifier for the client
id = Column(Integer, primary_key=True)
# Gender to which the client belongs to
gender = Column(Enum('m','f'))
# Group to which the client belongs to
sgroup = Column(Enum('g1','g2','world')) # do NOT use group (SQL keyword)
# Language spoken by the client
language = Column(Enum('en','fr','sp'))
def __init__(self, id, gender, group, language):
......@@ -27,25 +43,29 @@ class Client(Base):
self.language = language
def __repr__(self):
return "<Client('%d', '%s', '%s', '%s')>" % (self.id, self.gender, self.sgroup, self.language)
return "Client('%d', '%s', '%s', '%s')" % (self.id, self.gender, self.sgroup, self.language)
class Subworld(Base):
    """Database clients belonging to the world group are split in two disjoint subworlds,
    onethird and twothirds"""

    __tablename__ = 'subworld'

    # Key identifier for this Subworld object
    id = Column(Integer, primary_key=True)
    # Name of the subworld (unique, at most 20 characters)
    name = Column(String(20), unique=True)

    # for Python: A direct link to the clients of this subworld, resolved through
    # the subworld_client_association table; each Client gets a 'subworld' backref
    clients = relationship("Client", secondary=subworld_client_association, backref=backref("subworld", order_by=id))

    def __init__(self, name):
        """Creates a subworld with the given name; clients are attached
        afterwards through the ``clients`` relationship."""
        self.name = name

    def __repr__(self):
        # Must return the string: printing it and returning None breaks repr()
        return "Subworld('%s')" % (self.name,)
"""
class Session(Base):
__tablename__ = 'session'
......@@ -58,33 +78,128 @@ class Session(Base):
def __repr__(self):
return "<Session('%d', '%s')>" % (self.id, self.scenario)
"""
class File(Base):
    """Generic file container"""

    __tablename__ = 'file'

    # Key identifier for the file
    id = Column(Integer, primary_key=True)
    # Key identifier of the client associated with this file
    real_id = Column(Integer, ForeignKey('client.id')) # for SQL
    # Unique path to this file inside the database
    path = Column(String(100), unique=True)
    # Identifier of the claimed client associated with this file
    claimed_id = Column(Integer) # not always the id of an existing client model -> not a ForeignKey
    # Identifier of the shot
    shot_id = Column(Integer)
    # Identifier of the session
    session_id = Column(Integer)

    # For Python: A direct link to the client object that this file belongs to
    real_client = relationship("Client", backref=backref("files", order_by=id))

    def __init__(self, real_id, path, claimed_id, shot_id, session_id):
        """Stores the given column values on the new object."""
        self.real_id = real_id
        self.path = path
        self.claimed_id = claimed_id
        self.shot_id = shot_id
        self.session_id = session_id

    def __repr__(self):
        return "File('%s')" % self.path

    def make_path(self, directory=None, extension=None):
        """Wraps the current path so that a complete path is formed

        Keyword parameters:

        directory
          An optional directory name that will be prefixed to the returned result.

        extension
          An optional extension that will be suffixed to the returned filename. The
          extension normally includes the leading ``.`` character as in ``.jpg`` or
          ``.hdf5``.

        Returns a string containing the newly generated file path.
        """
        # None (or any other false value) stands for "no prefix" / "no suffix"
        return os.path.join(directory or '', self.path + (extension or ''))

    def save(self, data, directory=None, extension='.hdf5'):
        """Saves the input data at the specified location and using the given
        extension.

        Keyword parameters:

        data
          The data blob to be saved (normally a :py:class:`numpy.ndarray`).

        directory
          If not empty or None, this directory is prefixed to the final file
          destination

        extension
          The extension of the filename - this will control the type of output and
          the codec for saving the input blob.
        """
        path = self.make_path(directory, extension)
        # NOTE(review): this module only imports ``bob.db.utils``; confirm that
        # ``bob.utils.makedirs_safe`` and ``bob.io.save`` are reachable through
        # the ``bob`` package in the targeted Bob version.
        bob.utils.makedirs_safe(os.path.dirname(path))
        bob.io.save(data, path)
class Protocol(Base):
    """BANCA protocols"""

    __tablename__ = 'protocol'

    # Primary key of the protocol row
    id = Column(Integer, primary_key=True)
    # Protocol name; unique and limited to 20 characters
    name = Column(String(20), unique=True)

    def __init__(self, name):
        """Creates a protocol carrying the given name."""
        self.name = name

    def __repr__(self):
        template = "Protocol('%s')"
        return template % (self.name,)
class ProtocolPurpose(Base):
    """BANCA protocol purposes"""

    __tablename__ = 'protocolPurpose'

    # Allowed values for the group and purpose columns
    group_choices = ('world', 'dev', 'eval')
    purpose_choices = ('train', 'enrol', 'probe')

    # Primary key of the protocol purpose row
    id = Column(Integer, primary_key=True)
    # Key of the protocol this purpose is part of (SQL side of the link)
    protocol_id = Column(Integer, ForeignKey('protocol.id'))
    # Group of this protocol purpose, one of group_choices
    sgroup = Column(Enum(*group_choices))
    # Purpose of this protocol purpose, one of purpose_choices
    purpose = Column(Enum(*purpose_choices))

    # Python side of the link: the owning Protocol, which in turn gets a
    # 'purposes' backref
    protocol = relationship("Protocol", backref=backref("purposes", order_by=id))
    # Python side of the link: the File objects attached to this purpose through
    # protocolPurpose_file_association; each File gets a 'protocolPurposes' backref
    files = relationship("File", secondary=protocolPurpose_file_association, backref=backref("protocolPurposes", order_by=id))

    def __init__(self, protocol_id, sgroup, purpose):
        """Stores the protocol key, the group and the purpose on the new object."""
        self.protocol_id = protocol_id
        self.sgroup = sgroup
        self.purpose = purpose

    def __repr__(self):
        fields = (self.protocol.name, self.sgroup, self.purpose)
        return "ProtocolPurpose('%s', '%s', '%s')" % fields
"""
class Protocol(Base):
__tablename__ = 'protocol'
......@@ -103,3 +218,4 @@ class Protocol(Base):
def __repr__(self):
return "<Protocol('%d', '%s', '%s')>" % (self.session_id, self.name, self.purpose)
"""
......@@ -39,6 +39,12 @@ class Database(object):
return self.session is not None
def assert_validity(self):
    """Raise a RuntimeError if the database backend is not available

    Returns ``None`` when ``self.is_valid()`` is true.
    """
    if self.is_valid():
        return
    # Call form of raise: valid on both Python 2 and Python 3
    raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))
def __group_replace_alias__(self, l):
"""Replace 'dev' by 'g1' and 'eval' by 'g2' in a list of groups, and
returns the new list"""
......@@ -60,6 +66,18 @@ class Database(object):
raise RuntimeError, 'Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid)
return l
def groups(self):
    """Gives the names of all the registered groups, as declared by
    ``ProtocolPurpose.group_choices``"""
    registered = ProtocolPurpose.group_choices
    return registered
def client_groups(self):
    """Returns the names of the BANCA client groups, i.e. the values a
    client's ``sgroup`` may take. These differ from the protocol groups
    returned by ``groups()``."""
    # Falls back on the values of the Client.sgroup enumeration when the
    # Client class does not expose a 'group_choices' attribute
    return getattr(Client, 'group_choices', ('g1', 'g2', 'world'))
def clients(self, protocol=None, groups=None, gender=None, language=None, subworld=None):
"""Returns a set of clients for the specific query by the user.
......@@ -88,6 +106,8 @@ class Database(object):
properties.
"""
self.assert_validity()
groups = self.__group_replace_alias__(groups)
VALID_GROUPS = ('g1', 'g2', 'world')
VALID_GENDERS = ('m', 'f')
......@@ -108,8 +128,7 @@ class Database(object):
q = q.filter(Client.gender.in_(gender)).\
filter(Client.language.in_(language)).\
order_by(Client.id)
for id in [k.id for k in q]:
retval.append(id)
retval += list(q)
if 'g1' in groups or 'g2' in groups:
q = self.session.query(Client).filter(Client.sgroup != 'world').\
......@@ -117,8 +136,7 @@ class Database(object):
filter(Client.gender.in_(gender)).\
filter(Client.language.in_(language)).\
order_by(Client.id)
for id in [k.id for k in q]:
retval.append(id)
retval += list(q)
return retval
......@@ -224,6 +242,11 @@ class Database(object):
return self.zclients(protocol, groups)
def has_client_id(self, id):
    """Tells whether a client with the given integer identifier exists

    Returns True if at least one client row has this identifier, False
    otherwise.
    """
    self.assert_validity()
    matching = self.session.query(Client).filter(Client.id==id)
    return matching.count() != 0
def get_client_id_from_model_id(self, model_id):
"""Returns the client_id attached to the given model_id
......@@ -249,42 +272,6 @@ class Database(object):
"""
return model_id
def get_client_id_from_file_id(self, file_id):
"""Returns the client_id (real client id) attached to the given file_id
Keyword Parameters:
file_id
The file_id to consider
Returns: The client_id attached to the given file_id
"""
q = self.session.query(File).\
filter(File.id == file_id)
if q.count() !=1:
#throw exception?
return None