Commit 9ddc9d0b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Fix deprecation warnings

parent 11253a9c
Pipeline #49158 passed with stages
in 6 minutes and 33 seconds
......@@ -4,13 +4,15 @@
from .models import Client, File, DEFAULT_DATADIR
import bob.db.base
from bob.db.base.utils import check_parameters_for_validity
class Database(bob.db.base.Database):
"""Wrapper class for the AT&T (aka ORL) database of faces (http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html).
This class defines a simple protocol for training, enrollment and probe by splitting the few images of the database in a reasonable manner.
Due to the small size of the database, there is only a 'dev' group, and I did not define an 'eval' group."""
def __init__(self, original_directory=DEFAULT_DATADIR, original_extension='.pgm'):
def __init__(self, original_directory=DEFAULT_DATADIR, original_extension=".pgm"):
"""**Constructor Documentation**
Generates a database.
......@@ -27,9 +29,11 @@ class Database(bob.db.base.Database):
self.original_directory = original_directory
self.original_extension = original_extension
# initialize members
self.m_groups = ('world', 'dev')
self.m_purposes = ('enroll', 'probe')
self.m_training_clients = set([1, 2, 5, 6, 10, 11, 12, 14, 16, 17, 20, 21, 24, 26, 27, 29, 33, 34, 36, 39])
self.m_groups = ("world", "dev")
self.m_purposes = ("enroll", "probe")
self.m_training_clients = set(
[1, 2, 5, 6, 10, 11, 12, 14, 16, 17, 20, 21, 24, 26, 27, 29, 33, 34, 36, 39]
)
self.m_enroll_files = set([2, 4, 5, 7, 9])
def groups(self, protocol=None):
......@@ -53,12 +57,12 @@ class Database(bob.db.base.Database):
protocol
Ignored.
"""
groups = self.check_parameters_for_validity(groups, "group", self.m_groups)
groups = check_parameters_for_validity(groups, "group", self.m_groups)
ids = set()
if 'world' in groups:
if "world" in groups:
ids |= self.m_training_clients
if 'dev' in groups:
if "dev" in groups:
ids |= Client.m_valid_client_ids - self.m_training_clients
return [Client(id) for id in ids]
......@@ -75,12 +79,12 @@ class Database(bob.db.base.Database):
Ignored.
"""
groups = self.check_parameters_for_validity(groups, "group", self.m_groups)
groups = check_parameters_for_validity(groups, "group", self.m_groups)
ids = set()
if 'world' in groups:
if "world" in groups:
ids |= self.m_training_clients
if 'dev' in groups:
if "dev" in groups:
ids |= Client.m_valid_client_ids - self.m_training_clients
return sorted(list(ids))
......@@ -127,8 +131,9 @@ class Database(bob.db.base.Database):
protocol
ignored.
"""
return File._from_file_id(file_id, self.original_directory,
self.original_extension).client_id
return File._from_file_id(
file_id, self.original_directory, self.original_extension
).client_id
def get_client_id_from_model_id(self, model_id, groups=None, protocol=None):
"""Returns the client id from the given model id.
......@@ -172,7 +177,7 @@ class Database(bob.db.base.Database):
"""
# check if groups set are valid
groups = self.check_parameters_for_validity(groups, "group", self.m_groups)
groups = check_parameters_for_validity(groups, "group", self.m_groups)
# collect the ids to retrieve
ids = set(self.client_ids(groups))
......@@ -180,36 +185,51 @@ class Database(bob.db.base.Database):
# check the desired client ids for sanity
if isinstance(model_ids, int):
model_ids = (model_ids,)
model_ids = self.check_parameters_for_validity(model_ids, "model", list(Client.m_valid_client_ids))
model_ids = check_parameters_for_validity(
model_ids, "model", list(Client.m_valid_client_ids)
)
# calculate the intersection between the ids and the desired client ids
ids = ids & set(model_ids)
# check that the purposes are valid
if 'dev' in groups:
purposes = self.check_parameters_for_validity(purposes, "purpose", self.m_purposes)
if "dev" in groups:
purposes = check_parameters_for_validity(
purposes, "purpose", self.m_purposes
)
else:
purposes = self.m_purposes
# go through the dataset and collect all desired files
retval = []
if 'enroll' in purposes:
if "enroll" in purposes:
for client_id in ids:
for file_id in self.m_enroll_files:
retval.append(File(client_id, file_id,
self.original_directory, self.original_extension))
if 'probe' in purposes:
retval.append(
File(
client_id,
file_id,
self.original_directory,
self.original_extension,
)
)
if "probe" in purposes:
file_ids = File.m_valid_file_ids - self.m_enroll_files
# for probe, we use all clients of the given groups
for client_id in self.client_ids(groups):
for file_id in file_ids:
retval.append(File(client_id, file_id,
self.original_directory, self.original_extension))
retval.append(
File(
client_id,
file_id,
self.original_directory,
self.original_extension,
)
)
return retval
def paths(self, file_ids, prefix=None, suffix=None, preserve_order=True):
"""Returns a full file paths considering particular file ids, a given
directory and an extension
......@@ -233,10 +253,12 @@ class Database(bob.db.base.Database):
file ids.
"""
files = [File._from_file_id(id, self.original_directory, self.original_extension) for id in file_ids]
files = [
File._from_file_id(id, self.original_directory, self.original_extension)
for id in file_ids
]
return [f.make_path(prefix, suffix) for f in files]
def reverse(self, paths, preserve_order=True):
"""Reverses the lookup: from certain paths, return a list of
File objects
......@@ -253,4 +275,7 @@ class Database(bob.db.base.Database):
Returns a list (that may be empty).
"""
return [File._from_path(p, self.original_directory, self.original_extension) for p in paths]
return [
File._from_path(p, self.original_directory, self.original_extension)
for p in paths
]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment