Skip to content
Snippets Groups Projects
Commit 096a685a authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Mark queries as read-only to avoid lock contention

parent f35c95fb
No related branches found
No related tags found
No related merge requests found
...@@ -136,19 +136,19 @@ class Database(object): ...@@ -136,19 +136,19 @@ class Database(object):
# real-accesses are simpler to query # real-accesses are simpler to query
if 'enroll' in cls: if 'enroll' in cls:
q = self.session.query(RealAccess).join(File).join(Client).filter(Client.set.in_(groups)).filter(RealAccess.purpose=='enroll').filter(File.light.in_(light)).order_by(Client.id) q = self.session.query(RealAccess).with_lockmode('read').join(File).join(Client).filter(Client.set.in_(groups)).filter(RealAccess.purpose=='enroll').filter(File.light.in_(light)).order_by(Client.id)
for key, value in [(k.file.id, k.file.path) for k in q]: for key, value in [(k.file.id, k.file.path) for k in q]:
retval[key] = make_path(str(value), directory, extension) retval[key] = make_path(str(value), directory, extension)
# real-accesses are simpler to query # real-accesses are simpler to query
if 'real' in cls: if 'real' in cls:
q = self.session.query(RealAccess).join(File).join(Client).filter(RealAccess.protocols.contains(protocol)).filter(Client.set.in_(groups)).filter(File.light.in_(light)).order_by(Client.id) q = self.session.query(RealAccess).with_lockmode('read').join(File).join(Client).filter(RealAccess.protocols.contains(protocol)).filter(Client.set.in_(groups)).filter(File.light.in_(light)).order_by(Client.id)
for key, value in [(k.file.id, k.file.path) for k in q]: for key, value in [(k.file.id, k.file.path) for k in q]:
retval[key] = make_path(str(value), directory, extension) retval[key] = make_path(str(value), directory, extension)
# attacks will have to be filtered a little bit more # attacks will have to be filtered a little bit more
if 'attack' in cls: if 'attack' in cls:
q = self.session.query(Attack).join(File).join(Client).filter(Attack.protocols.contains(protocol)).filter(Client.set.in_(groups)).filter(Attack.attack_support.in_(support)).filter(File.light.in_(light)).order_by(Client.id) q = self.session.query(Attack).with_lockmode('read').join(File).join(Client).filter(Attack.protocols.contains(protocol)).filter(Client.set.in_(groups)).filter(Attack.attack_support.in_(support)).filter(File.light.in_(light)).order_by(Client.id)
for key, value in [(k.file.id, k.file.path) for k in q]: for key, value in [(k.file.id, k.file.path) for k in q]:
retval[key] = make_path(str(value), directory, extension) retval[key] = make_path(str(value), directory, extension)
...@@ -161,7 +161,7 @@ class Database(object): ...@@ -161,7 +161,7 @@ class Database(object):
if not self.is_valid(): if not self.is_valid():
raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE)
return tuple([k.name for k in self.session.query(Protocol)]) return tuple([k.name for k in self.session.query(Protocol).with_lockmode('read')])
def has_protocol(self, name): def has_protocol(self, name):
"""Tells if a certain protocol is available""" """Tells if a certain protocol is available"""
...@@ -169,7 +169,7 @@ class Database(object): ...@@ -169,7 +169,7 @@ class Database(object):
if not self.is_valid(): if not self.is_valid():
raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE)
return self.session.query(Protocol).filter(Protocol.name==name).count() != 0 return self.session.query(Protocol).with_lockmode('read').filter(Protocol.name==name).count() != 0
def protocol(self, name): def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises """Returns the protocol object in the database given a certain name. Raises
...@@ -178,7 +178,7 @@ class Database(object): ...@@ -178,7 +178,7 @@ class Database(object):
if not self.is_valid(): if not self.is_valid():
raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE)
return self.session.query(Protocol).filter(Protocol.name==name).one() return self.session.query(Protocol).with_lockmode('read').filter(Protocol.name==name).one()
def groups(self): def groups(self):
"""Returns the names of all registered groups""" """Returns the names of all registered groups"""
...@@ -228,7 +228,7 @@ class Database(object): ...@@ -228,7 +228,7 @@ class Database(object):
if not self.is_valid(): if not self.is_valid():
raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE)
fobj = self.session.query(File).filter(File.id.in_(ids)) fobj = self.session.query(File).with_lockmode('read').filter(File.id.in_(ids))
retval = [] retval = []
for p in ids: for p in ids:
retval.extend([os.path.join(prefix, str(k.path) + suffix) retval.extend([os.path.join(prefix, str(k.path) + suffix)
...@@ -250,7 +250,7 @@ class Database(object): ...@@ -250,7 +250,7 @@ class Database(object):
if not self.is_valid(): if not self.is_valid():
raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE)
fobj = self.session.query(File).filter(File.path.in_(paths)) fobj = self.session.query(File).with_lockmode('read').filter(File.path.in_(paths))
retval = [] retval = []
for p in paths: for p in paths:
retval.extend([k.id for k in fobj if k.path == p]) retval.extend([k.id for k in fobj if k.path == p])
...@@ -284,7 +284,7 @@ class Database(object): ...@@ -284,7 +284,7 @@ class Database(object):
from bob.io import save from bob.io import save
fobj = self.session.query(File).filter_by(id=id).one() fobj = self.session.query(File).with_lockmode('read').filter_by(id=id).one()
fullpath = os.path.join(directory, str(fobj.path) + extension) fullpath = os.path.join(directory, str(fobj.path) + extension)
fulldir = os.path.dirname(fullpath) fulldir = os.path.dirname(fullpath)
utils.makedirs_safe(fulldir) utils.makedirs_safe(fulldir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment