# youtube.py: Bob database interface for the YouTube Faces dataset.
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
from bob.pipelines import DelayedSample, SampleSet
from bob.bio.video.utils import VideoLikeContainer, select_frames
from functools import partial
import copy
from bob.extension import rc
from bob.extension.download import get_file
import bob.io.base
import os
import logging

logger = logging.getLogger(__name__)


class YoutubeDatabase(Database):
    """
    This package contains the access API and descriptions for the `YouTube Faces` database.
    It only contains the Bob accessor methods to use the DB directly from python, with our certified protocols.
    The actual raw data for the `YouTube Faces` database should be downloaded from the original URL (though we were not able to contact the corresponding Professor).

    .. warning::

      To use this dataset protocol, you need to have the original files of the YOUTUBE datasets.
      Once you have it downloaded, please run the following command to set the path for Bob

        .. code-block:: sh

            bob config set bob.bio.video.youtube.directory [YOUTUBE PATH]



    In this interface we implement the 10 original protocols of the `YouTube Faces` database ('fold0', 'fold1', 'fold2', 'fold3', 'fold4', 'fold5', 'fold6', 'fold7', 'fold8', 'fold9')


    The code below allows you to fetch the gallery and probes of the "fold0" protocol.

    .. code-block:: python

        >>> from bob.bio.video.database import YoutubeDatabase
        >>> youtube = YoutubeDatabase(protocol="fold0")
        >>>
        >>> # Fetching the gallery
        >>> references = youtube.references()
        >>> # Fetching the probes
        >>> probes = youtube.probes()


    Parameters
    ----------

        protocol: str
           One of the Youtube above mentioned protocols

        annotation_type: str
           One of the supported annotation types

        fixed_positions:
           Forwarded unchanged to the base ``Database`` constructor

        original_directory: str
           Original directory

        extension: str
           Default file extension; only files with this extension are loaded
           as video frames

        annotation_extension: str
           Suffix appended to the directory that contains a video's shots in
           order to locate the annotation file of that directory

        frame_selector:
           Pointer to a function that does frame selection. It is called with
           the number of available frames and must return the indices of the
           frames to load. When ``None``, all frames are used.

    """

    def __init__(
        self,
        protocol,
        annotation_type="bounding-box",
        fixed_positions=None,
        original_directory=rc.get("bob.bio.video.youtube.directory", ""),
        extension=".jpg",
        annotation_extension=".labeled_faces.txt",
        frame_selector=None,
    ):

        self._check_protocol(protocol)

        original_directory = original_directory or ""
        if not os.path.exists(original_directory):
            # Lazy %-formatting: the previous message contained a literal
            # "f{original_directory}" and never showed the offending path.
            logger.warning(
                "Invalid or non existent `original_directory`: %s. "
                "Please, do `bob config set bob.bio.video.youtube.directory PATH` "
                "to set the Youtube data directory.",
                original_directory,
            )

        # Download (or fetch from the local cache) the protocol definition files
        urls = YoutubeDatabase.urls()
        cache_subdir = os.path.join("datasets", "youtube_protocols")
        self.filename = get_file(
            "youtube_protocols-6962cd2e.tar.gz",
            urls,
            file_hash="8a4792872ff30b37eab7f25790b0b10d",
            extract=True,
            cache_subdir=cache_subdir,
        )
        self.protocol_path = os.path.dirname(self.filename)

        # Per-protocol caches filled lazily by `references()` and `probes()`
        self.references_dict = {}
        self.probes_dict = {}

        # Dict that holds a `subject_id` as a key and has
        # filenames as values
        self.subject_id_files = {}
        self.reference_id_to_subject_id = None
        self.reference_id_to_sample = None
        self.load_file_client_id()
        self.original_directory = original_directory
        self.extension = extension
        self.annotation_extension = annotation_extension
        self.frame_selector = frame_selector

        super().__init__(
            name="youtube",
            protocol=protocol,
            allow_scoring_with_all_biometric_references=False,
            annotation_type=annotation_type,
            # Forward the caller's value instead of silently discarding it
            fixed_positions=fixed_positions,
            memory_demanding=True,
        )

    def load_file_client_id(self):
        """Loads the label and name files shipped with the protocol archive.

        Fills three attributes:

        - ``reference_id_to_subject_id``: first row of ``Youtube_labels.mat.hdf5``
          cast to int
        - ``reference_id_to_sample``: the lines of ``Youtube_names.txt``
          (each line matches one entry of the labels)
        - ``subject_id_files``: dict mapping each label to the list of names
          that carry this label
        """

        self.subject_id_files = {}

        # List containing the client ID
        # Each element of this file matches a line in Youtube_names.txt
        self.reference_id_to_subject_id = bob.io.base.load(
            os.path.join(self.protocol_path, "Youtube_labels.mat.hdf5")
        )[0].astype("int")

        names_path = os.path.join(self.protocol_path, "Youtube_names.txt")
        # Context manager so the file handle is closed deterministically
        with open(names_path) as names_file:
            self.reference_id_to_sample = [
                line.rstrip("\n") for line in names_file
            ]

        for label, name in zip(
            self.reference_id_to_subject_id, self.reference_id_to_sample
        ):
            self.subject_id_files.setdefault(int(label), []).append(name)

    def _load_pairs(self):
        """Returns the two columns of the pair list of the current protocol.

        The fold index is parsed from the protocol name (``"fold<N>"``) and
        used to slice the third axis of ``Youtube_splits.mat.hdf5``.

        Returns
        -------

        tuple
            ``(split[:, 0], split[:, 1]]`` of the selected fold: the first
            column is used to build the references and the second one to
            build the probes.
        """
        # Parse everything after the "fold" prefix rather than the last
        # character only, so multi-digit fold names would also be handled
        fold = int(self.protocol[len("fold") :])

        split = bob.io.base.load(
            os.path.join(self.protocol_path, "Youtube_splits.mat.hdf5")
        )[:, :, fold].astype(int)

        return split[:, 0], split[:, 1]

    def _load_video_from_path(self, path):
        """Loads the selected frames of the directory ``path``.

        Parameters
        ----------

        path: str
            Directory that contains one image file per frame

        Returns
        -------

        VideoLikeContainer
            Container with the loaded frames as ``data`` and the
            corresponding file names as ``indices``
        """
        # Honor the configured extension (defaults to ".jpg")
        files = sorted(
            name
            for name in os.listdir(path)
            if os.path.splitext(name)[1] == self.extension
        )

        # If there's no frame selector, uses all frames
        if self.frame_selector is None:
            files_indices = select_frames(
                len(files),
                max_number_of_frames=None,
                selection_style="all",
                step_size=None,
            )
        else:
            files_indices = self.frame_selector(len(files))

        # A set gives constant-time membership tests in the loop below
        selected = set(files_indices)

        data, indices = [], []
        for i, file_name in enumerate(files):
            if i not in selected:
                continue
            indices.append(file_name)
            data.append(bob.io.base.load(os.path.join(path, file_name)))

        return VideoLikeContainer(data=data, indices=indices)

    def _make_sample_set(self, reference_id, subject_id, sample_path, references=None):
        """Builds a ``SampleSet`` holding one delayed video sample.

        Parameters
        ----------

        reference_id:
            Identifier used (as a string) for both ``key`` and ``reference_id``

        subject_id:
            Identifier of the subject, stored as a string

        sample_path: str
            Path of the video relative to ``original_directory``

        references: list
            Optional list of reference ids to compare with; it is only set
            on the ``SampleSet`` when given (i.e. for probes)
        """

        path = os.path.join(self.original_directory, sample_path)

        kwargs = {} if references is None else {"references": references}

        # Delaying the annotation loading
        delayed_attributes = {"annotations": partial(self._annotations, path)}

        return SampleSet(
            key=str(reference_id),
            reference_id=str(reference_id),
            subject_id=str(subject_id),
            **kwargs,
            samples=[
                DelayedSample(
                    key=str(sample_path),
                    load=partial(self._load_video_from_path, path),
                    annotations=None,
                    delayed_attributes=delayed_attributes,
                )
            ],
        )

    def _annotations(self, path):
        """Returns the annotations for the given file id as a dictionary of dictionaries, e.g. {'1.56.jpg' : {'topleft':(y,x), 'bottomright':(y,x)}, '1.57.jpg' : {'topleft':(y,x), 'bottomright':(y,x)}, ...}.
        Here, the key of the dictionary is the full image file name of the original image.

        Parameters
        ----------

        path: str
            The path containing the frame sequence of a user

        Raises
        ------

        ValueError
            If ``original_directory`` is ``None``

        """

        if self.original_directory is None:
            raise ValueError(
                "Please specify the 'original_directory' in the constructor of this class to get the annotations."
            )

        # The annotation file sits next to the parent directory of `path`
        # and is named after it: <parent directory><annotation_extension>
        annotation_file = os.path.dirname(path) + self.annotation_extension

        annots = {}

        with open(annotation_file) as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank (e.g. trailing) lines
                    continue

                fields = line.split(",")
                # First field is a backslash-separated path whose third
                # component is the frame file name
                frame_name = fields[0].split("\\")[2]

                # coordinates are: center x, center y, width, height
                center_x, center_y = float(fields[2]), float(fields[3])
                half_width = float(fields[4]) / 2.0
                half_height = float(fields[5]) / 2.0

                # extract the bounding box information
                annots[frame_name] = {
                    "topleft": (center_y - half_height, center_x - half_width),
                    "bottomright": (center_y + half_height, center_x + half_width),
                }

        # return the annotations as returned by the call function of the
        # Annotation object
        return annots

    def background_model_samples(self):
        """Returns ``None``: this interface provides no background model samples."""
        return None

    def references(self, group="dev"):
        """Returns the reference sample sets of the current protocol.

        One ``SampleSet`` is created per pair of the protocol, from the first
        element of the pair. The result is cached per protocol.

        Parameters
        ----------

        group: str
            Must be one of :py:meth:`groups`
        """
        self._check_group(group)
        if self.protocol not in self.references_dict:
            # NOTE(review): an id that occurs in several pairs yields several
            # sample sets with the same `reference_id` -- confirm it is intended.
            enroll_ids, _ = self._load_pairs()
            self.references_dict[self.protocol] = [
                self._make_sample_set(
                    reference_id,
                    self.reference_id_to_subject_id[reference_id],
                    self.reference_id_to_sample[reference_id],
                )
                for reference_id in enroll_ids
            ]

        return self.references_dict[self.protocol]

    def probes(self, group="dev"):
        """Returns the probe sample sets of the current protocol.

        One ``SampleSet`` is created per pair of the protocol, from the second
        element of the pair. Its ``references`` are the ids (as strings) of all
        first elements that are paired with this probe. The result is cached
        per protocol.

        Parameters
        ----------

        group: str
            Must be one of :py:meth:`groups`
        """
        self._check_group(group)
        if self.protocol not in self.probes_dict:
            enroll_ids, probe_ids = self._load_pairs()

            # Computing reference list
            probe_to_reference_ids = {}
            for enroll_id, probe_id in zip(enroll_ids, probe_ids):
                probe_to_reference_ids.setdefault(probe_id, []).append(
                    str(enroll_id)
                )

            # Now assembling the samplesets
            self.probes_dict[self.protocol] = [
                self._make_sample_set(
                    probe_id,
                    self.reference_id_to_subject_id[probe_id],
                    self.reference_id_to_sample[probe_id],
                    # Each sample set gets its own list (the entries are
                    # immutable strings, so a shallow copy is sufficient)
                    list(probe_to_reference_ids[probe_id]),
                )
                for probe_id in probe_ids
            ]

        return self.probes_dict[self.protocol]

    def all_samples(self):
        """Returns the references followed by the probes of the current protocol."""
        return self.references() + self.probes()

    def groups(self):
        """Returns the list of available groups (only ``"dev"``)."""
        return ["dev"]

    @staticmethod
    def urls():
        """Returns the URLs from which the protocol archive can be downloaded."""
        return [
            "https://www.idiap.ch/software/bob/databases/latest/youtube_protocols-6962cd2e.tar.gz",
            "http://www.idiap.ch/software/bob/databases/latest/youtube_protocols-6962cd2e.tar.gz",
        ]

    @staticmethod
    def protocols():
        """Returns the names of the available protocols: ``fold0`` to ``fold9``."""
        return [f"fold{fold}" for fold in range(10)]

    def _check_protocol(self, protocol):
        """Raises ``AssertionError`` if ``protocol`` is not one of :py:meth:`protocols`."""
        valid = self.protocols()
        if protocol not in valid:
            # Raised explicitly (rather than via `assert`) so that the check
            # is not stripped under `python -O`; the exception type is kept.
            raise AssertionError(f"Invalid protocol `{protocol}` not in {valid}")

    def _check_group(self, group):
        """Raises ``AssertionError`` if ``group`` is not one of :py:meth:`groups`."""
        valid = self.groups()
        if group not in valid:
            # Raised explicitly (rather than via `assert`) so that the check
            # is not stripped under `python -O`; the exception type is kept.
            raise AssertionError(f"Invalid group `{group}` not in {valid}")