Skip to content
Snippets Groups Projects

Small fix on LFW

Merged Tiago de Freitas Pereira requested to merge lfw-fix into master
1 unresolved thread
1 file
+ 50
37
Compare changes
  • Side-by-side
  • Inline
+ 50
37
@@ -157,7 +157,7 @@ class LFWDatabase(Database):
super().__init__(
name="lfw",
protocol=protocol,
allow_scoring_with_all_biometric_references=protocol[0] == 'o',
allow_scoring_with_all_biometric_references=protocol[0] == "o",
annotation_type=annotation_type,
fixed_positions=fixed_positions,
memory_demanding=False,
@@ -165,7 +165,6 @@ class LFWDatabase(Database):
self.load_pairs()
def _extract_funneled(self, annotation_path):
"""Interprets the annotation string as if it came from the funneled images.
Inspired by: https://gitlab.idiap.ch/bob/bob.db.lfw/-/blob/5ac22c5b77aae971de6b73cbe23f26d6a5632072/bob/db/lfw/models.py#L69
@@ -283,14 +282,28 @@ class LFWDatabase(Database):
self._create_probe_reference_dict()
elif self.protocol[0] == 'o':
self.pairs = {"enroll":{}, "training-unknown":[], "probe":{}, "o1":[], "o2":[]}
elif self.protocol[0] == "o":
self.pairs = {
"enroll": {},
"training-unknown": [],
"probe": {},
"o1": [],
"o2": [],
}
# parse directory for open-set protocols
for d in os.listdir(os.path.join(self.original_directory,self.image_relative_path)):
dd = os.path.join(self.original_directory,self.image_relative_path,d)
for d in os.listdir(
os.path.join(self.original_directory, self.image_relative_path)
):
dd = os.path.join(self.original_directory, self.image_relative_path, d)
if os.path.isdir(dd):
# count the number of images
images = sorted([os.path.splitext(i)[0] for i in os.listdir(dd) if os.path.splitext(i)[1] == self.extension])
images = sorted(
[
os.path.splitext(i)[0]
for i in os.listdir(dd)
if os.path.splitext(i)[1] == self.extension
]
)
if len(images) > 3:
# take the first three images for enrollment
@@ -310,7 +323,6 @@ class LFWDatabase(Database):
def protocols():
return ["view2", "o1", "o2", "o3"]
def background_model_samples(self):
"""This function returns the training set for the open-set protocols o1, o2 and o3.
It returns the :py:meth:`references` and the training samples with known unknowns, which get the subject id "unknown".
@@ -339,12 +351,13 @@ class LFWDatabase(Database):
# load annotations
if self.annotation_directory is not None:
annotation_path = os.path.join(
self.annotation_directory, self.make_path_from_filename(image) + self.annotation_extension,
self.annotation_directory,
self.make_path_from_filename(image) + self.annotation_extension,
)
annotations = self._extract(annotation_path)
else:
annotations = None
data[image] = (image_path,annotations)
data[image] = (image_path, annotations)
# generate one sampleset from images of the unknown unknowns
sset = SampleSet(
@@ -358,7 +371,7 @@ class LFWDatabase(Database):
annotations=data[image][1],
)
for image in data
]
],
)
return enrollmentset + [sset]
@@ -368,7 +381,7 @@ class LFWDatabase(Database):
where that probe should be compared with.
"""
if self.protocol[0] == 'o':
if self.protocol[0] == "o":
return
self.probe_reference_keys = {}
@@ -380,7 +393,6 @@ class LFWDatabase(Database):
self.probe_reference_keys[value].append(key)
def probes(self, group="dev"):
if self.protocol not in self.probes_dict:
self.probes_dict[self.protocol] = []
@@ -394,7 +406,8 @@ class LFWDatabase(Database):
)
if self.annotation_directory is not None:
annotation_path = os.path.join(
self.annotation_directory, key + self.annotation_extension,
self.annotation_directory,
key + self.annotation_extension,
)
annotations = self._extract(annotation_path)
else:
@@ -410,6 +423,7 @@ class LFWDatabase(Database):
samples=[
DelayedSample(
key=key,
reference_id=key,
load=partial(bob.io.image.load, image_path),
annotations=annotations,
)
@@ -417,24 +431,18 @@ class LFWDatabase(Database):
)
self.probes_dict[self.protocol].append(sset)
elif self.protocol[0] == 'o':
elif self.protocol[0] == "o":
# add known probes
# collect probe samples:
probes = [
(image,key)
for key in self.pairs["probe"]
for image in self.pairs["probe"][key]
]
(image, key)
for key in self.pairs["probe"]
for image in self.pairs["probe"][key]
]
if self.protocol in ("o1", "o3"):
probes += [
(image,"unknown")
for image in self.pairs["o1"]
]
probes += [(image, "unknown") for image in self.pairs["o1"]]
if self.protocol in ("o2", "o3"):
probes += [
(image,"unknown")
for image in self.pairs["o2"]
]
probes += [(image, "unknown") for image in self.pairs["o2"]]
for image, key in probes:
# get image path
@@ -446,7 +454,9 @@ class LFWDatabase(Database):
# load annotations
if self.annotation_directory is not None:
annotation_path = os.path.join(
self.annotation_directory, self.make_path_from_filename(image) + self.annotation_extension,
self.annotation_directory,
self.make_path_from_filename(image)
+ self.annotation_extension,
)
annotations = self._extract(annotation_path)
else:
@@ -460,6 +470,7 @@ class LFWDatabase(Database):
samples=[
DelayedSample(
key=image,
reference_id=image,
load=partial(bob.io.image.load, image_path),
annotations=annotations,
)
@@ -467,10 +478,8 @@ class LFWDatabase(Database):
)
self.probes_dict[self.protocol].append(sset)
return self.probes_dict[self.protocol]
def references(self, group="dev"):
if self.protocol not in self.references_dict:
@@ -486,7 +495,8 @@ class LFWDatabase(Database):
)
if self.annotation_directory is not None:
annotation_path = os.path.join(
self.annotation_directory, key + self.annotation_extension,
self.annotation_directory,
key + self.annotation_extension,
)
annotations = self._extract(annotation_path)
else:
@@ -499,13 +509,14 @@ class LFWDatabase(Database):
samples=[
DelayedSample(
key=key,
reference_id=key,
load=partial(bob.io.image.load, image_path),
annotations=annotations,
)
],
)
self.references_dict[self.protocol].append(sset)
elif self.protocol[0] == 'o':
elif self.protocol[0] == "o":
for key in self.pairs["enroll"]:
data = {}
for image in self.pairs["enroll"][key]:
@@ -518,12 +529,14 @@ class LFWDatabase(Database):
# load annotations
if self.annotation_directory is not None:
annotation_path = os.path.join(
self.annotation_directory, self.make_path_from_filename(image) + self.annotation_extension,
self.annotation_directory,
self.make_path_from_filename(image)
+ self.annotation_extension,
)
annotations = self._extract(annotation_path)
else:
annotations = None
data[image] = (image_path,annotations)
data[image] = (image_path, annotations)
# generate one sampleset from several (should be 3) images of the same person
sset = SampleSet(
@@ -533,17 +546,17 @@ class LFWDatabase(Database):
samples=[
DelayedSample(
key=image,
reference_id=key,
load=partial(bob.io.image.load, data[image][0]),
annotations=data[image][1],
)
for image in data
]
],
)
self.references_dict[self.protocol].append(sset)
return self.references_dict[self.protocol]
def groups(self):
return ["dev"]
@@ -552,7 +565,7 @@ class LFWDatabase(Database):
if self.protocol == "view2":
return self.references() + self.probes()
elif self.protocol[0] == 'o':
elif self.protocol[0] == "o":
return self.background_model_samples() + self.probes()
def _check_protocol(self, protocol):
Loading