Commit cd691561 authored by Amir Mohammadi's avatar Amir Mohammadi Committed by GitHub

Merge pull request #6 from bioidiap/refactoring_2016

Refactoring 2016
parents af4dcda7 6151d0e0
......@@ -19,32 +19,32 @@
"""Table models and functionality for the Mobio database.
"""
import os, numpy
import bob.db.base.utils
from sqlalchemy import Table, Column, Integer, String, ForeignKey, or_, and_, not_
from sqlalchemy import Table, Column, Integer, String, ForeignKey
from bob.db.base.sqlalchemy_migration import Enum, relationship
from sqlalchemy.orm import backref
from sqlalchemy.ext.declarative import declarative_base
import bob.db.verification.utils
import bob.db.base
Base = declarative_base()
subworld_client_association = Table('subworld_client_association', Base.metadata,
Column('subworld_id', Integer, ForeignKey('subworld.id')),
Column('client_id', Integer, ForeignKey('client.id')))
Column('subworld_id', Integer, ForeignKey('subworld.id')),
Column('client_id', Integer, ForeignKey('client.id')))
subworld_file_association = Table('subworld_file_association', Base.metadata,
Column('subworld_id', Integer, ForeignKey('subworld.id')),
Column('file_id', Integer, ForeignKey('file.id')))
Column('subworld_id', Integer, ForeignKey('subworld.id')),
Column('file_id', Integer, ForeignKey('file.id')))
tmodel_file_association = Table('tmodel_file_association', Base.metadata,
Column('tmodel_id', String, ForeignKey('tmodel.id')),
Column('file_id', Integer, ForeignKey('file.id')))
Column('tmodel_id', String, ForeignKey('tmodel.id')),
Column('file_id', Integer, ForeignKey('file.id')))
protocolPurpose_file_association = Table('protocolPurpose_file_association', Base.metadata,
Column('protocolPurpose_id', Integer, ForeignKey('protocolPurpose.id')),
Column('file_id', Integer, ForeignKey('file.id')))
Column('protocolPurpose_id', Integer, ForeignKey('protocolPurpose.id')),
Column('file_id', Integer, ForeignKey('file.id')))
class Client(Base):
"""Database clients, marked by an integer identifier and the group they belong to"""
......@@ -54,11 +54,11 @@ class Client(Base):
# Key identifier for the client
id = Column(Integer, primary_key=True)
# Gender to which the client belongs to
gender_choices = ('female','male')
gender_choices = ('female', 'male')
gender = Column(Enum(*gender_choices))
# Group to which the client belongs to
group_choices = ('dev','eval','world')
sgroup = Column(Enum(*group_choices)) # do NOT use group (SQL keyword)
group_choices = ('dev', 'eval', 'world')
sgroup = Column(Enum(*group_choices)) # do NOT use group (SQL keyword)
# Institute to which the client belongs to
institute_choices = ('idiap', 'manchester', 'surrey', 'oulu', 'brno', 'avignon')
institute = Column(Enum(*institute_choices))
......@@ -72,6 +72,7 @@ class Client(Base):
def __repr__(self):
return "Client('%d', '%s')" % (self.id, self.sgroup)
class Subworld(Base):
"""Database clients belonging to the world group are split in subsets"""
......@@ -93,6 +94,7 @@ class Subworld(Base):
def __repr__(self):
return "Subworld('%s')" % (self.name)
class TModel(Base):
"""T-Norm models"""
......@@ -102,8 +104,8 @@ class TModel(Base):
id = Column(Integer, primary_key=True)
# Model id (only unique for a given protocol)
mid = Column(String(9))
client_id = Column(Integer, ForeignKey('client.id')) # for SQL
protocol_id = Column(Integer, ForeignKey('protocol.id')) # for SQL
client_id = Column(Integer, ForeignKey('client.id')) # for SQL
protocol_id = Column(Integer, ForeignKey('protocol.id')) # for SQL
# for Python: A direct link to the client
client = relationship("Client", backref=backref("tmodels", order_by=id))
......@@ -120,7 +122,8 @@ class TModel(Base):
def __repr__(self):
return "TModel('%s', '%s')" % (self.mid, self.protocol_id)
class File(Base, bob.db.verification.utils.File):
class File(Base, bob.db.base.File):
"""Generic file container"""
__tablename__ = 'file'
......@@ -128,13 +131,13 @@ class File(Base, bob.db.verification.utils.File):
# Key identifier for the file
id = Column(Integer, primary_key=True)
# Key identifier of the client associated with this file
client_id = Column(Integer, ForeignKey('client.id')) # for SQL
client_id = Column(Integer, ForeignKey('client.id')) # for SQL
# Unique path to this file inside the database
path = Column(String(100), unique=True)
# Identifier of the session
session_id = Column(Integer)
# Speech type
speech_type_choices = ('p','l','r','f')
speech_type_choices = ('p', 'l', 'r', 'f')
speech_type = Column(Enum(*speech_type_choices))
# Identifier of the shot
shot_id = Column(Integer)
......@@ -152,7 +155,7 @@ class File(Base, bob.db.verification.utils.File):
def __init__(self, client_id, path, session_id, speech_type, shot_id, environment, device, channel_id):
# call base class constructor
bob.db.verification.utils.File.__init__(self, client_id = client_id, path = path)
bob.db.base.File.__init__(self, client_id=client_id, path=path)
# fill the remaining bits of the file information
self.session_id = session_id
......@@ -162,6 +165,7 @@ class File(Base, bob.db.verification.utils.File):
self.device = device
self.channel_id = channel_id
class Protocol(Base):
"""MOBIO protocols"""
......@@ -171,7 +175,7 @@ class Protocol(Base):
id = Column(Integer, primary_key=True)
# Name of the protocol associated with this object
name = Column(String(20), unique=True)
gender_choices = ('female','male')
gender_choices = ('female', 'male')
gender = Column(Enum(*gender_choices))
def __init__(self, name, gender):
......@@ -181,6 +185,7 @@ class Protocol(Base):
def __repr__(self):
return "Protocol('%s','%s')" % (self.name, self.gender)
class ProtocolPurpose(Base):
"""MOBIO protocol purposes"""
......@@ -189,7 +194,7 @@ class ProtocolPurpose(Base):
# Unique identifier for this protocol purpose object
id = Column(Integer, primary_key=True)
# Id of the protocol associated with this protocol purpose object
protocol_id = Column(Integer, ForeignKey('protocol.id')) # for SQL
protocol_id = Column(Integer, ForeignKey('protocol.id')) # for SQL
# Group associated with this protocol purpose object
group_choices = Client.group_choices
sgroup = Column(Enum(*group_choices))
......
......@@ -20,27 +20,26 @@
MOBIO database in the most obvious ways.
"""
import os
import six
from bob.db.base import utils
from .models import *
from .driver import Interface
from sqlalchemy import and_, not_
import bob.db.verification.utils
import bob.db.base
SQLITE_FILE = Interface().files()[0]
class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.utils.ZTDatabase):
class Database(bob.db.base.SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
It provides many different ways to probe for the characteristics of the data
and for the data itself inside the database.
"""
def __init__(self, original_directory = None, original_extension = None, annotation_directory = None, annotation_extension = '.pos'):
def __init__(self, original_directory=None, original_extension=None, annotation_directory=None, annotation_extension='.pos'):
# call base class constructors to open a session to the database
bob.db.verification.utils.SQLiteDatabase.__init__(self, SQLITE_FILE, File)
bob.db.verification.utils.ZTDatabase.__init__(self, original_directory=original_directory, original_extension=original_extension)
super(Database, self).__init__(SQLITE_FILE, File)
self.annotation_directory = annotation_directory
self.annotation_extension = annotation_extension
......@@ -72,24 +71,28 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
"""Tells if a certain subworld is available"""
self.assert_validity()
return self.query(Subworld).filter(Subworld.name==name).count() != 0
return self.query(Subworld).filter(Subworld.name == name).count() != 0
def _replace_protocol_alias(self, protocol):
if protocol == 'male': return 'mobile0-male'
elif protocol == 'female': return 'mobile0-female'
else: return protocol
if protocol == 'male':
return 'mobile0-male'
elif protocol == 'female':
return 'mobile0-female'
else:
return protocol
def _replace_protocols_alias(self, protocol):
#print(protocol)
# print(protocol)
if protocol:
from six import string_types
if isinstance(protocol, string_types):
#print([self._replace_protocol_alias(protocol)])
# print([self._replace_protocol_alias(protocol)])
return [self._replace_protocol_alias(protocol)]
else:
#print(list(set(self._replace_protocol_alias(k) for k in protocols)))
# print(list(set(self._replace_protocol_alias(k) for k in protocols)))
return list(set(self._replace_protocol_alias(k) for k in protocols))
else: return None
else:
return None
def clients(self, protocol=None, groups=None, subworld=None, gender=None):
"""Returns a list of Clients for the specific query by the user.
......@@ -134,8 +137,10 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
retval += list(q)
dev_eval = []
if 'dev' in groups: dev_eval.append('dev')
if 'eval' in groups: dev_eval.append('eval')
if 'dev' in groups:
dev_eval.append('dev')
if 'eval' in groups:
dev_eval.append('eval')
if dev_eval:
protocol_gender = None
if protocol:
......@@ -154,13 +159,13 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
return self.query(Client).filter(Client.id==id).count() != 0
return self.query(Client).filter(Client.id == id).count() != 0
def client(self, id):
"""Returns the Client object in the database given a certain id. Raises
an error if that does not exist."""
return self.query(Client).filter(Client.id==id).one()
return self.query(Client).filter(Client.id == id).one()
def tclients(self, protocol=None, groups=None, subworld='onethird', gender=None):
"""Returns a set of T-Norm clients for the specific query by the user.
......@@ -348,7 +353,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
return model_id
def objects(self, protocol=None, purposes=None, model_ids=None,
groups=None, classes=None, subworld=None, gender=None, device=None):
groups=None, classes=None, subworld=None, gender=None, device=None):
"""Returns a set of Files for the specific query by the user.
Keyword Parameters:
......@@ -412,7 +417,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
retval = []
if 'world' in groups and 'train' in purposes:
q = self.query(File).join(Client).filter(Client.sgroup == 'world').join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup == 'world'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup == 'world'))
if subworld:
q = q.join((Subworld, File.subworld)).filter(Subworld.name.in_(subworld))
if gender:
......@@ -427,7 +432,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
if ('dev' in groups or 'eval' in groups):
if('enroll' in purposes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'enroll'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'enroll'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
......@@ -440,7 +445,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
if('probe' in purposes):
if('client' in classes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
......@@ -452,7 +457,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
if('impostor' in classes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
......@@ -462,7 +467,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
q = q.order_by(File.client_id, File.session_id, File.speech_type, File.shot_id, File.device)
retval += list(q)
return list(set(retval)) # To remove duplicates
return list(set(retval)) # To remove duplicates
def tobjects(self, protocol=None, model_ids=None, groups=None, subworld='onethird', gender=None, speech_type=None, device=None):
"""Returns a set of filenames for enrolling T-norm models for score
......@@ -506,14 +511,13 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
subworld = self.check_parameters_for_validity(subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(gender, "gender", self.genders(), [])
import collections
if(model_ids is None):
model_ids = ()
elif isinstance(model_ids, six.string_types):
model_ids = (model_ids,)
# Now query the database
q = self.query(File,Protocol).filter(Protocol.name.in_(protocol)).join(Client)
q = self.query(File, Protocol).filter(Protocol.name.in_(protocol)).join(Client)
if subworld:
q = q.join((Subworld, File.subworld)).filter(Subworld.name.in_(subworld))
q = q.join((TModel, File.tmodels)).filter(TModel.protocol_id == Protocol.id)
......@@ -529,7 +533,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
retval = [v[0] for v in q]
return list(retval)
def zobjects(self, protocol=None, model_ids=None, groups=None, subworld='onethird', gender=None, speech_type=['r','f'], device=['mobile']):
def zobjects(self, protocol=None, model_ids=None, groups=None, subworld='onethird', gender=None, speech_type=['r', 'f'], device=['mobile']):
"""Returns a set of Files to perform Z-norm score normalization.
Keyword Parameters:
......@@ -581,7 +585,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
# Now query the database
q = self.query(File).join(Client).filter(Client.sgroup == 'world').join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup == 'world'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup == 'world'))
if subworld:
q = q.join((Subworld, File.subworld)).filter(Subworld.name.in_(subworld))
if gender:
......@@ -615,7 +619,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
annotation_file = file.make_path(self.annotation_directory, self.annotation_extension)
# return the annotations as read from file
return bob.db.verification.utils.read_annotation_file(annotation_file, 'eyecenter')
return bob.db.base.read_annotation_file(annotation_file, 'eyecenter')
def protocol_names(self):
"""Returns all registered protocol names"""
......@@ -632,13 +636,13 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
def has_protocol(self, name):
"""Tells if a certain protocol is available"""
return self.query(Protocol).filter(Protocol.name==self._replace_protocol_alias(name)).count() != 0
return self.query(Protocol).filter(Protocol.name == self._replace_protocol_alias(name)).count() != 0
def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
return self.query(Protocol).filter(Protocol.name==self._replace_protocol_alias(name)).one()
return self.query(Protocol).filter(Protocol.name == self._replace_protocol_alias(name)).one()
def protocol_purposes(self):
"""Returns all registered protocol purposes"""
......
......@@ -13,7 +13,6 @@ develop = src/bob.extension
src/bob.core
src/bob.io.base
src/bob.db.base
src/bob.db.verification.utils
.
; options for bob.buildout extension
......@@ -27,7 +26,6 @@ bob.blitz = git https://github.com/bioidiap/bob.blitz
bob.core = git https://github.com/bioidiap/bob.core
bob.io.base = git https://github.com/bioidiap/bob.io.base
bob.db.base = git https://github.com/bioidiap/bob.db.base
bob.db.verification.utils = git https://github.com/bioidiap/bob.db.verification.utils
[scripts]
recipe = bob.buildout:scripts
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment