From b6353d5e221d11d281d6c9cdd0a049919622559d Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Wed, 12 May 2021 00:15:34 +0200
Subject: [PATCH] Set references in probes to prevent negative spoof

---
 bob/bio/face/database/replaymobile.py | 105 ++++++++++++++------------
 1 file changed, 57 insertions(+), 48 deletions(-)

diff --git a/bob/bio/face/database/replaymobile.py b/bob/bio/face/database/replaymobile.py
index 392609db..f48cac93 100644
--- a/bob/bio/face/database/replaymobile.py
+++ b/bob/bio/face/database/replaymobile.py
@@ -54,35 +54,6 @@ def load_frame_from_file_replaymobile(file_name, frame, should_flip):
     image = numpy.transpose(image, (0, 2, 1))
     return image
 
-def read_frame_annotation_file_replaymobile(file_name, frame, annotations_type="json"):
-    """Returns the bounding-box for one frame of a video file of replay-mobile.
-
-    Given an annnotation file location and a frame number, returns the bounding
-    box coordinates corresponding to the frame.
-
-    The replay-mobile annotation files are composed of 4 columns and N rows for
-    N frames of the video:
-
-    120 230 40 40
-    125 230 40 40
-    ...
-    <x> <y> <w> <h>
-
-    Parameters
-    ----------
-
-    file_name: str
-        The annotation file name (relative to annotations_path).
-
-    frame: int
-        The video frame index.
-    """
-    logger.debug(f"Reading annotation file '{file_name}', frame {frame}.")
-
-    video_annotations = read_annotation_file(file_name, annotation_type=annotations_type)
-    # read_annotation_file returns an ordered dict with string keys
-    return video_annotations[f"{frame}"]
-
 class ReplayMobileCSVFrameSampleLoader(CSVToSampleLoaderBiometrics):
     """A loader transformer returning a specific frame of a video file.
 
@@ -101,39 +72,79 @@ class ReplayMobileCSVFrameSampleLoader(CSVToSampleLoaderBiometrics):
             dataset_original_directory=dataset_original_directory,
         )
         self.reference_id_equal_subject_id = reference_id_equal_subject_id
+        self.references_list = []
 
     def convert_row_to_sample(self, row, header):
         """Creates a set of samples given a row of the CSV protocol definition.
         """
-        path = row[0]
-        reference_id = row[1]
-        id = row[2] # Will be used as 'key'
+        fields = dict([[str(h).lower(), r] for h, r in zip(header, row)])
 
-        kwargs = dict([[str(h).lower(), r] for h, r in zip(header[3:], row[3:])])
         if self.reference_id_equal_subject_id:
-            kwargs["subject_id"] = reference_id
+            fields["subject_id"] = fields["reference_id"]
         else:
-            if "subject_id" not in kwargs:
+            if "subject_id" not in fields:
                 raise ValueError(f"`subject_id` not available in {header}")
-        if "should_flip" not in kwargs:
+        if "should_flip" not in fields:
             raise ValueError(f"`should_flip` not available in {header}")
+        if "purpose" not in fields:
+            raise ValueError(f"`purpose` not available in {header}")
+
+        kwargs = {k: fields[k] for k in fields.keys() - {"id",}}
+
+        # Record each enrolled reference_id once, in first-seen order
+        if fields["purpose"].lower() == "enroll" and fields["reference_id"] not in self.references_list:
+            self.references_list.append(fields["reference_id"])
+        # Set the references list in the probes for vanilla-biometrics
+        if fields["purpose"].lower() != "enroll":
+            if fields["attack_type"]:
+                # Attacks only compare to the target (no `spoof_neg`); keep it a list, like the bona-fide branch
+                kwargs["references"] = [fields["reference_id"]]
+            else:
+                kwargs["references"] = self.references_list
         # One row leads to multiple samples (different frames)
         all_samples = [DelayedSample(
             functools.partial(
                 load_frame_from_file_replaymobile,
-                file_name=os.path.join(self.dataset_original_directory, path + self.extension),
+                file_name=os.path.join(self.dataset_original_directory, fields["path"] + self.extension),
                 frame=frame,
                 should_flip=kwargs["should_flip"]=="TRUE",
             ),
-            key=f"{id}_{frame}",
-            path=path,
-            reference_id=reference_id,
+            key=f"{fields['id']}_{frame}",
             frame=frame,
             **kwargs,
         ) for frame in range(12,251,24)]
         return all_samples
 
 
+def read_frame_annotation_file_replaymobile(file_name, frame, annotations_type="json"):
+    """Returns the bounding-box for one frame of a video file of replay-mobile.
+
+    Given an annotation file location and a frame number, returns the bounding
+    box coordinates corresponding to the frame.
+
+    The replay-mobile annotation files are composed of 4 columns and N rows for
+    N frames of the video:
+
+    120 230 40 40
+    125 230 40 40
+    ...
+    <x> <y> <w> <h>
+
+    Parameters
+    ----------
+
+    file_name: str
+        The annotation file name (relative to annotations_path).
+
+    frame: int
+        The video frame index.
+    """
+    logger.debug(f"Reading annotation file '{file_name}', frame {frame}.")
+
+    video_annotations = read_annotation_file(file_name, annotation_type=annotations_type)
+    # read_annotation_file returns an ordered dict with string keys
+    return video_annotations[f"{frame}"]
+
 class FrameBoundingBoxAnnotationLoader(AnnotationsLoader):
     """A transformer that adds bounding-box to a sample from annotations files.
 
@@ -144,7 +155,7 @@ class FrameBoundingBoxAnnotationLoader(AnnotationsLoader):
     """
     def __init__(self,
         annotation_directory=None,
-        annotation_extension=".face",
+        annotation_extension=".json",
         **kwargs
     ):
         super().__init__(
@@ -161,12 +172,7 @@ class FrameBoundingBoxAnnotationLoader(AnnotationsLoader):
 
         annotated_samples = []
         for x in X:
-
-            # Build the path to the annotation files structure
-            annotation_file = os.path.join(
-                self.annotation_directory, x.path + self.annotation_extension
-            )
-
+            # Adds the annotations as delayed_attributes, loading them when needed
             annotated_samples.append(
                 DelayedSample(
                     x._load,
@@ -176,6 +182,7 @@ class FrameBoundingBoxAnnotationLoader(AnnotationsLoader):
                             read_frame_annotation_file_replaymobile,
                             file_name=f"{self.annotation_directory}:{x.path}{self.annotation_extension}",
                             frame=int(x.frame),
+                            annotations_type=self.annotation_type,
                         )
                     ),
                 )
@@ -183,6 +190,7 @@ class FrameBoundingBoxAnnotationLoader(AnnotationsLoader):
 
         return annotated_samples
 
+
 class ReplayMobileBioDatabase(CSVDataset):
     """Database interface that loads a csv definition for replay-mobile
 
@@ -258,7 +266,8 @@ class ReplayMobileBioDatabase(CSVDataset):
                     annotation_extension=annotations_extension,
                 ),
             ),
+            fetch_probes=False,
             **kwargs
         )
-        self.annotation_type = "bounding-box"
+        self.annotation_type = "eyes-center"
         self.fixed_positions = None
-- 
GitLab