Commit 258a4918 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Ported Youtube face database to the new interface

parent 3c6303a1
from bob.bio.base.pipelines.vanilla_biometrics.legacy import DatabaseConnector
from bob.bio.video.database import YoutubeBioDatabase
from bob.bio.video.database import YoutubeDatabase
database = YoutubeDatabase(protocol="fold0")
database = DatabaseConnector(
YoutubeBioDatabase(
protocol="fold1",
models_depend_on_protocol=True,
training_depends_on_protocol=True,
all_files_options={"subworld": "fivefolds"},
extractor_training_options={"subworld": "fivefolds"},
projector_training_options={"subworld": "fivefolds"},
enroller_training_options={"subworld": "fivefolds"},
)
)
from .youtube import YoutubeDatabase
from .database import VideoBioFile
from .youtube import YoutubeBioDatabase
# gets sphinx autodoc done right - don't remove it
......@@ -18,8 +18,5 @@ def __appropriate__(*args):
obj.__module__ = __name__
__appropriate__(
VideoBioFile,
YoutubeBioDatabase,
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
__appropriate__(YoutubeDatabase, VideoBioFile)
__all__ = [_ for _ in dir() if not _.startswith("_")]
"""
YOUTUBE database implementation of bob.bio.base.database.ZTDatabase interface.
It is an extension of an SQL-based database interface, which directly talks to YOUTUBE database, for
verification experiments (good to use in bob.bio.base framework).
"""
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
from bob.pipelines import DelayedSample, SampleSet
from bob.bio.video.utils import VideoLikeContainer, select_frames
from functools import partial
import copy
from bob.extension import rc
from bob.extension.download import get_file
import bob.io.base
import os
import logging
import os
logger = logging.getLogger(__name__)
import bob.io.base
from bob.bio.base.database import ZTBioDatabase
from bob.extension import rc
from ..utils import VideoLikeContainer, select_frames
from .database import VideoBioFile
class YoutubeDatabase(Database):
"""
This package contains the access API and descriptions for the `YouTube Faces` database.
It only contains the Bob accessor methods to use the DB directly from python, with our certified protocols.
The actual raw data for the `YouTube Faces` database should be downloaded from the original URL (though we were not able to contact the corresponding Professor).
.. warning::
To use this dataset protocol, you need to have the original files of the YOUTUBE datasets.
Once you have it downloaded, please run the following command to set the path for Bob
.. code-block:: sh
bob config set bob.bio.face.youtube.directory [YOUTUBE PATH]
In this interface we implement the 10 original protocols of the `YouTube Faces` database ('fold0', 'fold1', 'fold2', 'fold3', 'fold4', 'fold5', 'fold6', 'fold7', 'fold8', 'fold9')
The code below allows you to fetch the gallery and probes of the "fold0" protocol.
.. code-block:: python
>>> from bob.bio.video.database import YoutubeDatabase
>>> youtube = YoutubeDatabase(protocol="fold0")
>>>
>>> # Fetching the gallery
>>> references = youtube.references()
>>> # Fetching the probes
>>> probes = youtube.probes()
"""
def __init__(
    self,
    protocol,
    annotation_type="bounding-box",
    fixed_positions=None,
    original_directory=rc.get("bob.bio.face.youtube.directory"),
    extension=".jpg",
    selection_style="first",
    max_number_of_frames=None,
    step_size=None,
    annotation_extension=".labeled_faces.txt",
):
    """Set up the YouTube Faces database for one evaluation fold.

    Parameters
    ----------

    protocol: str
        Name of the fold to use; must be one of the values returned by
        ``protocols()``.

    annotation_type: str
        Forwarded unchanged to the base-class constructor.

    fixed_positions:
        Forwarded unchanged to the base-class constructor.

    original_directory: str
        Root directory of the raw data. Defaults to the value of the
        ``bob.bio.face.youtube.directory`` configuration key. A warning is
        logged (no exception is raised) when it is ``None`` or missing.

    extension: str
        Stored as ``self.extension``.

    selection_style, max_number_of_frames, step_size:
        Stored on the instance; presumably the frame-selection options
        used when a video is loaded -- confirm against the loader.

    annotation_extension: str
        Suffix appended to a video directory name to locate its
        annotation file.
    """
    self._check_protocol(protocol)

    if original_directory is None or not os.path.exists(original_directory):
        # Lazy %-formatting: the path is now actually interpolated (the
        # previous message contained a literal "f{original_directory}"),
        # and the hint names the YouTube configuration key, not the LFW one.
        logger.warning(
            "Invalid or non existent `original_directory`: %s. "
            "Please, do `bob config set bob.bio.face.youtube.directory PATH` "
            "to set the YouTube Faces data directory.",
            original_directory,
        )

    # Fetch (or reuse from the local cache) the archive with the protocol
    # definition files and remember the directory it lives in.
    urls = YoutubeDatabase.urls()
    cache_subdir = os.path.join("datasets", "youtube_protocols")
    self.filename = get_file(
        "youtube_protocols-6962cd2e.tar.gz",
        urls,
        file_hash="8a4792872ff30b37eab7f25790b0b10d",
        extract=True,
        cache_subdir=cache_subdir,
    )
    self.protocol_path = os.path.dirname(self.filename)

    # Per-protocol caches filled lazily by `references` and `probes`
    self.references_dict = {}
    self.probes_dict = {}

    # Dict that holds a `subject_id` as a key and has
    # filenames as values
    self.subject_id_files = {}
    self.reference_id_to_subject_id = None
    self.reference_id_to_sample = None
    self.load_file_client_id()

    self.selection_style = selection_style
    self.max_number_of_frames = max_number_of_frames
    self.step_size = step_size

    self.original_directory = original_directory
    self.extension = extension
    self.annotation_extension = annotation_extension

    super().__init__(
        name="youtube",
        protocol=protocol,
        allow_scoring_with_all_biometric_references=False,
        annotation_type=annotation_type,
        # Forward the caller's value instead of silently discarding it
        fixed_positions=fixed_positions,
        memory_demanding=True,
    )
def load_file_client_id(self):
self.subject_id_files = {}
# List containing the client ID
# Each element of this file matches a line in Youtube_names.txt
self.reference_id_to_subject_id = bob.io.base.load(
os.path.join(self.protocol_path, "Youtube_labels.mat.hdf5")
)[0].astype("int")
self.reference_id_to_sample = [
x.rstrip("\n")
for x in open(
os.path.join(self.protocol_path, "Youtube_names.txt")
).readlines()
]
class YoutubeBioFile(VideoBioFile):
def __init__(self, f, **kwargs):
super().__init__(client_id=f.client_id, path=f.path, file_id=f.id, **kwargs)
self._f = f
for l, n in zip(self.reference_id_to_subject_id, self.reference_id_to_sample):
key = int(l)
if key not in self.subject_id_files:
self.subject_id_files[key] = []
def files(self):
base_dir = self.make_path(self.original_directory, "")
# collect all files from the data directory
files = [os.path.join(base_dir, f) for f in sorted(os.listdir(base_dir))]
# filter files with the given extension
if self.original_extension is not None:
files = [
f for f in files if os.path.splitext(f)[1] == self.original_extension
]
return files
self.subject_id_files[key].append(n.rstrip("\n"))
def _load_pairs(self):
    """Load the pair indices of the current fold.

    The fold number is taken from the last character of ``self.protocol``
    and used to slice the third axis of ``Youtube_splits.mat.hdf5``.

    Returns
    -------

    tuple
        ``(first_column, second_column)`` of the selected slice, cast to
        ``int``. ``references`` consumes the first element as reference
        ids.
    """
    fold_index = int(self.protocol[-1])
    splits_file = os.path.join(self.protocol_path, "Youtube_splits.mat.hdf5")
    all_folds = bob.io.base.load(splits_file)
    fold_pairs = all_folds[:, :, fold_index].astype(int)
    first_column = fold_pairs[:, 0]
    second_column = fold_pairs[:, 1]
    return first_column, second_column
def _load_video_from_path(self, path):
files = sorted(
[x for x in os.listdir(path) if os.path.splitext(x)[1] == ".jpg"]
)
def load(self, *args, **kwargs):
files = self.files()
files_indices = select_frames(
len(files),
max_number_of_frames=self.max_number_of_frames,
......@@ -43,85 +152,150 @@ class YoutubeBioFile(VideoBioFile):
for i, file_name in enumerate(files):
if i not in files_indices:
continue
file_name = os.path.join(path, file_name)
indices.append(os.path.basename(file_name))
data.append(bob.io.base.load(file_name))
return VideoLikeContainer(data=data, indices=indices)
def _make_sample_set(self, reference_id, subject_id, sample_path, references=None):
class YoutubeBioDatabase(ZTBioDatabase):
"""
YouTube Faces database implementation of :py:class:`bob.bio.base.database.ZTBioDatabase` interface.
It is an extension of an SQL-based database interface, which directly talks to :py:class:`bob.db.youtube.Database` database, for
verification experiments (good to use in ``bob.bio`` framework).
"""
path = os.path.join(self.original_directory, sample_path)
def __init__(
self,
original_directory=rc["bob.db.youtube.directory"],
original_extension=".jpg",
annotation_extension=".labeled_faces.txt",
**kwargs,
):
from bob.db.youtube.query import Database as LowLevelDatabase
kwargs = {} if references is None else {"references": references}
self._db = LowLevelDatabase(
original_directory, original_extension, annotation_extension
)
# call base class constructors to open a session to the database
super(YoutubeBioDatabase, self).__init__(
name="youtube",
original_directory=original_directory,
original_extension=original_extension,
annotation_extension=annotation_extension,
# Delaying the annotation loading
delayed_annotations = partial(self._annotations, path)
delayed_attributes = {"annotations": delayed_annotations}
return SampleSet(
key=str(reference_id),
reference_id=str(reference_id),
subject_id=str(subject_id),
**kwargs,
samples=[
DelayedSample(
key=str(sample_path),
load=partial(self._load_video_from_path, path),
annotations=None,
delayed_attributes={"annotations": delayed_annotations},
)
],
)
@property
def original_directory(self):
return self._db.original_directory
def _annotations(self, path):
"""Returns the annotations for the given file id as a dictionary of dictionaries, e.g. {'1.56.jpg' : {'topleft':(y,x), 'bottomright':(y,x)}, '1.57.jpg' : {'topleft':(y,x), 'bottomright':(y,x)}, ...}.
Here, the key of the dictionary is the full image file name of the original image.
@original_directory.setter
def original_directory(self, value):
self._db.original_directory = value
Parameters
----------
def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
return self._db.model_ids(groups=groups, protocol=protocol)
path: str
The path containing the frame sequence of a user
def tmodel_ids_with_protocol(self, protocol=None, groups=None, **kwargs):
return self._db.tmodel_ids(protocol=protocol, groups=groups, **kwargs)
"""
def _populate_files_attrs(self, files):
for f in files:
f.original_directory = self.original_directory
f.original_extension = self.original_extension
f.annotation_extension = self.annotation_extension
return files
if self.original_directory is None:
raise ValueError(
"Please specify the 'original_directory' in the constructor of this class to get the annotations."
)
def objects(
self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs
):
retval = self._db.objects(
groups=groups,
protocol=protocol,
purposes=purposes,
model_ids=model_ids,
**kwargs,
)
return self._populate_files_attrs([YoutubeBioFile(f) for f in retval])
directory = os.path.dirname(path)
shot_id = os.path.basename(path)
def tobjects(self, groups=None, protocol=None, model_ids=None, **kwargs):
retval = self._db.tobjects(
groups=groups, protocol=protocol, model_ids=model_ids, **kwargs
)
return self._populate_files_attrs([YoutubeBioFile(f) for f in retval])
annotation_file = os.path.join(directory + self.annotation_extension)
annots = {}
with open(annotation_file) as f:
for line in f:
splits = line.rstrip().split(",")
# shot_id = int(splits[0].split("\\")[1])
index = splits[0].split("\\")[2]
# coordinates are: center x, center y, width, height
(center_y, center_x, d_y, d_x) = (
float(splits[3]),
float(splits[2]),
float(splits[5]) / 2.0,
float(splits[4]) / 2.0,
)
# extract the bounding box information
annots[index] = {
"topleft": (center_y - d_y, center_x - d_x),
"bottomright": (center_y + d_y, center_x + d_x),
}
# return the annotations as returned by the call function of the
# Annotation object
return annots
def background_model_samples(self):
    """Return ``None``: this interface provides no background-model samples."""
    return None
def references(self, group="dev"):
    """Return the enrollment sample sets of the current protocol.

    The list is built on first use and cached in ``self.references_dict``
    under ``self.protocol``; later calls return the very same list object.

    Parameters
    ----------

    group: str
        Must be one of the values returned by ``groups()``.

    Returns
    -------

    list
        One sample set (as produced by ``_make_sample_set``) per pair of
        the current fold, in pair order.
    """
    self._check_group(group)

    if self.protocol not in self.references_dict:
        pairs = self._load_pairs()

        # Build into a local list and publish it only once complete, so a
        # failure while loading the pairs or building a sample set does not
        # leave an empty or partial list cached for subsequent calls.
        sample_sets = []
        for reference_id, _ in zip(pairs[0], pairs[1]):
            subject_id = self.reference_id_to_subject_id[reference_id]
            sample_path = self.reference_id_to_sample[reference_id]
            sample_sets.append(
                self._make_sample_set(reference_id, subject_id, sample_path)
            )
        self.references_dict[self.protocol] = sample_sets

    return self.references_dict[self.protocol]
def zobjects(self, groups=None, protocol=None, **kwargs):
retval = self._db.zobjects(groups=groups, protocol=protocol, **kwargs)
return self._populate_files_attrs([YoutubeBioFile(f) for f in retval])
def probes(self, group="dev"):
self._check_group(group)
if self.protocol not in self.probes_dict:
self.probes_dict[self.protocol] = []
pairs = self._load_pairs()
def annotations(self, myfile):
return self._db.annotations(myfile._f)
# Computing reference list
probe_to_reference_id_dict = dict()
for e, p in zip(pairs[0], pairs[1]):
if p not in probe_to_reference_id_dict:
probe_to_reference_id_dict[p] = []
probe_to_reference_id_dict[p].append(str(e))
def client_id_from_model_id(self, model_id, group="dev"):
return self._db.get_client_id_from_file_id(model_id)
# Now assembling the samplesets
for _, p in zip(pairs[0], pairs[1]):
reference_id = p
suject_id = self.reference_id_to_subject_id[reference_id]
sample_path = self.reference_id_to_sample[reference_id]
references = copy.deepcopy(probe_to_reference_id_dict[p])
sampleset = self._make_sample_set(
reference_id, suject_id, sample_path, references
)
self.probes_dict[self.protocol].append(sampleset)
return self.probes_dict[self.protocol]
def all_samples(self):
    """Return the reference sample sets followed by the probe sample sets."""
    enrolled = self.references()
    probed = self.probes()
    return enrolled + probed
def groups(self):
    """Return the names of the supported groups (only ``"dev"``)."""
    supported = ["dev"]
    return supported
@staticmethod
def urls():
    """Return the download locations of the protocol archive.

    The HTTPS location is listed first, the plain HTTP one second.
    """
    mirrors = (
        "https://www.idiap.ch/software/bob/databases/latest/youtube_protocols-6962cd2e.tar.gz",
        "http://www.idiap.ch/software/bob/databases/latest/youtube_protocols-6962cd2e.tar.gz",
    )
    return list(mirrors)
@staticmethod
def protocols():
    """Return the ten protocol names, ``"fold0"`` through ``"fold9"``."""
    names = []
    for index in range(10):
        names.append(f"fold{index}")
    return names
def _check_protocol(self, protocol):
assert protocol in self.protocols(), "Unvalid protocol `{}` not in {}".format(
protocol, self.protocols()
)
def _check_group(self, group):
assert group in self.groups(), "Unvalid group `{}` not in {}".format(
group, self.groups()
)
from nose.plugins.skip import SkipTest
import bob.bio.base
from bob.bio.base.test.utils import db_available
from bob.bio.base.test.test_database_implementations import check_database_zt
from bob.bio.face.test.test_databases import _check_annotations
import pkg_resources
@db_available("youtube")
def test_youtube():
database = bob.bio.base.load_resource(
"youtube", "database", preferred_package="bob.bio.video"
)
try:
check_database_zt(database, training_depends=True, models_depend=True)
except IOError as e:
raise SkipTest(
"The database could not be queried; probably the db.sql3 file is missing. Here is the error: '%s'"
% e
)
try:
if database.database.original_directory is None:
raise SkipTest("The annotations cannot be queried as original_directory is None")
_check_annotations(database, limit_files=1000, topleft=True, framed=True)
except IOError as e:
raise SkipTest(
"The annotations could not be queried; probably the annotation files are missing. Here is the error: '%s'"
% e
)
def test_new_youtube():
from bob.bio.video.database import YoutubeDatabase
for protocol in [f"fold{i}" for i in range(10)]:
@db_available("youtube")
def test_youtube_load_method():
database = bob.bio.base.load_resource(
"youtube", "database", preferred_package="bob.bio.video"
)
database.database.original_directory = pkg_resources.resource_filename(
"bob.bio.video", "test/data"
)
youtube_db_sample = [
sample
for sample_set in database.references(group="dev")
for sample in sample_set
if sample.key == "Aaron_Eckhart/0"
][0]
database = YoutubeDatabase("fold0")
references = database.references()
probes = database.probes()
frame_container = youtube_db_sample.data
assert len(references) == 500
assert len(probes) == 500
assert len(frame_container) == 2
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment