Skip to content
Snippets Groups Projects
Commit c0ebe918 authored by Yannick DAYER
Browse files

Revert to use the bob.bio.base CSVDataset

parent de615967
No related branches found
No related tags found
1 merge request: !106 Vulnerability framework - CSV datasets
#!/usr/bin/env python
# Yannick Dayer <yannick.dayer@idiap.ch>
from bob.pipelines.datasets import FileListDatabase, CSVToSamples
from bob.bio.base.database import CSVDataset, CSVToSampleLoaderBiometrics
from bob.pipelines.datasets.sample_loaders import AnnotationsLoader
from bob.pipelines.sample import DelayedSample
from bob.extension.download import list_dir, search_file
from bob.db.base.utils import check_parameters_for_validity
from bob.db.base.annotations import read_annotation_file
from bob.io.video import reader as video_reader
from bob.bio.base.pipelines.vanilla_biometrics import Database
from bob.extension.download import get_file
from bob.io.video import reader
from bob.extension import rc
import bob.core
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.pipeline import make_pipeline
import functools
import os.path
import logging
import numpy
# Module-level logger.
# NOTE(review): the second assignment rebinds `logger`, so the object returned
# by `logging.getLogger(__name__)` is discarded and only the logger returned by
# `bob.core.log.setup("bob.bio.face")` is used by the rest of the module.
logger = logging.getLogger(__name__)
logger = bob.core.log.setup("bob.bio.face")
class VideoReader(TransformerMixin, BaseEstimator):
    """Transformer turning file-list samples into delayed video samples.

    Each output sample keeps its input sample as parent and gets a ``load``
    callable that opens ``<data_path>/<sample.PATH><file_ext>`` with
    ``bob.io.video.reader``.

    Parameters
    ----------
    data_path: str
        Directory the samples' ``PATH`` attributes are joined to.
    file_ext: str
        Extension appended to each sample's ``PATH`` (``".mov"`` by default).
    """

    def __init__(self, data_path, file_ext=".mov"):
        self.data_path = data_path
        self.file_ext = file_ext

    def transform(self, X, y=None):
        """Return one ``DelayedSample`` per sample of ``X``; ``y`` is unused."""
        return [
            DelayedSample(
                load=functools.partial(
                    video_reader,
                    os.path.join(self.data_path, sample.PATH + self.file_ext),
                ),
                parent=sample,
            )
            for sample in X
        ]
def load_frame_from_file_replaymobile(file_name, frame, capturing_device):
"""Loads a single frame from a video file for replay-mobile.
This function uses bob's video reader utility that does not load the full
video in memory to just access one frame.
Parameters
----------
# No-op fit: returns the transformer itself without using X or y.
def fit(self, X, y=None):
return self
file_name: str
The video file to load the frames from
# Returns the extra tags of this transformer: it is declared stateless and as
# not requiring a call to fit.
# NOTE(review): presumably consumed by scikit-learn's estimator-tag mechanism — confirm.
def _more_tags(self):
return {
"stateless": True,
"requires_fit": False,
}
frame: int
The index of the frame to load.
capturing_device: str
'mobile' devices' frames will be flipped vertically.
Other devices' frames will not be flipped.
def get_frame_from_sample(video_sample, frame_id):
"""Returns one frame's data from a replay-mobile video sample.
Returns
-------
Flips the image according to the sample's metadata.
images: 3D numpy array
The frame of the video in bob format (channel, height, width)
"""
frame = video_sample.data[frame_id]
if video_sample.SHOULD_FLIP: # TODO include this field in the csv files
frame = numpy.flip(frame, 2)
logger.debug(f"Extracting frame {frame} from '{file_name}'")
video_reader = reader(file_name)
image = video_reader[frame]
# Image captured by the 'mobile' device are flipped vertically.
# (Images were captured horizontally and bob.io.video does not read the
# metadata correctly, whether it was on the right or left side)
if capturing_device == "mobile":
image = numpy.flip(image, 2)
# Convert to bob format (channel, height, width)
frame = numpy.transpose(frame, (0, 2, 1))
return frame
image = numpy.transpose(image, (0, 2, 1))
return image
class VideoToFrames(TransformerMixin, BaseEstimator):
"""Transformer that creates a list of frame samples from a video sample.
def read_frame_annotation_file_replaymobile(file_name, frame):
"""Returns the bounding-box for one frame of a video file of replay-mobile.
Parameters
----------
Given an annotation file location and a frame number, returns the bounding
box coordinates corresponding to the frame.
frame_indices: None or Sequence[int]
The list of frames to keep. Will keep all the frames if None or empty.
"""
# Stores the frame selection as given (None by default); the filtering itself
# happens in `transform`.
def __init__(self, frame_indices=None):
super().__init__()
self.frame_indices = frame_indices
def transform(self, X, y=None):
    """Expand every video sample of ``X`` into one delayed sample per frame.

    A frame is kept when ``self.frame_indices`` is None or empty, or when its
    index is contained in ``self.frame_indices``. Each frame sample has the
    video sample as parent, a ``frame`` attribute holding the frame index and
    a ``key`` of the form ``"<video ID>_<frame index>"``. ``y`` is unused.
    """
    frame_samples = []
    for video_sample in X:
        for frame_id in range(len(video_sample.data)):
            # Skip the frames excluded by an explicit, non-empty selection.
            if self.frame_indices and frame_id not in self.frame_indices:
                continue
            frame_samples.append(
                DelayedSample(
                    load=functools.partial(
                        get_frame_from_sample, video_sample, frame_id
                    ),
                    parent=video_sample,
                    frame=frame_id,
                    key=f"{video_sample.ID}_{frame_id}",
                )
            )
    return frame_samples
The replay-mobile annotation files are composed of 4 columns and N rows for
N frames of the video:
# No-op fit: returns the transformer itself without using X or y.
def fit(self, X, y=None):
return self
120 230 40 40
125 230 40 40
...
<x> <y> <w> <h>
# Returns the extra tags of this transformer: it is declared stateless and as
# not requiring a call to fit.
# NOTE(review): presumably consumed by scikit-learn's estimator-tag mechanism — confirm.
def _more_tags(self):
return {
"stateless": True,
"requires_fit": False,
}
Parameters
----------
file_name: str
The complete annotation file path and name (with extension).
# Reads the annotation file `file_name` through `read_annotation_file` (with
# annotation_type="json") and returns the entry stored under the string form
# of `frame_id`.
def read_frame_annotations_file(file_name, frame_id):
"""Reads an annotations file and extracts one frame's annotations.
frame: int
The video frame index.
"""
video_annotations = read_annotation_file(file_name, annotation_type="json")
# read_annotation_file returns an ordered dict with string keys
return video_annotations[f"{frame_id}"]
logger.debug(f"Reading annotation file '{file_name}', frame {frame}.")
if not file_name:
return None
class AnnotationsAdder(TransformerMixin, BaseEstimator):
"""Transformer that adds an 'annotations' field to the samples.
if not os.path.exists(file_name):
raise IOError(f"The annotation file '{file_name}' was not found")
This reads a json file containing coordinates for each frame of a video.
"""
# Stores the location prefix used by `transform` to build the per-video
# annotation file names.
def __init__(self, annotation_directory):
self.annotation_directory=annotation_directory
def transform(self, X, y=None):
    """Return copies of the samples of ``X`` with a delayed 'annotations' field.

    For each sample, the annotations are read lazily by
    ``read_frame_annotations_file`` from
    ``"<annotation_directory>:<sample.PATH>.json"`` for the sample's
    ``frame``. The data loader of the input sample is reused as is. ``y`` is
    unused.
    """

    def with_annotations(sample):
        # Build the annotation loader first so that the sample's PATH and
        # frame are read before its data loader is picked up.
        load_annotations = functools.partial(
            read_frame_annotations_file,
            file_name=f"{self.annotation_directory}:{sample.PATH}.json",
            frame_id=sample.frame,
        )
        return DelayedSample(
            load=sample._load,
            parent=sample,
            delayed_attributes={"annotations": load_annotations},
        )

    return [with_annotations(sample) for sample in X]
with open(file_name, 'r') as f:
# One line is one frame, each line contains a bounding box coordinates
line = f.readlines()[frame]
# No-op fit: returns the transformer itself without using X or y.
def fit(self, X, y=None):
return self
positions = line.split(' ')
# Returns the extra tags of this transformer: it is declared stateless and as
# not requiring a call to fit.
# NOTE(review): presumably consumed by scikit-learn's estimator-tag mechanism — confirm.
def _more_tags(self):
return {
"stateless": True,
"requires_fit": False,
}
if len(positions) != 4:
raise ValueError(f"The content of '{file_name}' was not correct for frame {frame} ({positions})")
annotations = {
'topleft': (float(positions[1]), float(positions[0])),
'bottomright':(
float(positions[1])+float(positions[3]),
float(positions[0])+float(positions[2])
)
}
class CSVToBioSamples(CSVToSamples):
    """Iterator over the samples described by a CSV file.

    Currently yields exactly what the parent ``CSVToSamples`` iterator yields.
    """

    def __iter__(self):
        # TODO test that fields are present? (attack_type for vuln?)
        yield from super().__iter__()
return annotations
class ReplayMobileCSVFrameSampleLoader(CSVToSampleLoaderBiometrics):
"""A loader transformer returning a specific frame of a video file.
class ReplayMobileBioDatabase(FileListDatabase, Database):
"""Database for Replay-mobile-img for vulnerability analysis
This is specifically tailored for replay-mobile. It uses a specific loader
that takes the capturing device as input.
"""
def __init__(
self,
dataset_protocols_path,
protocol,
data_path,
data_extension=".mov",
annotations_path=None,
**kwargs,
dataset_original_directory="",
extension="",
reference_id_equal_subject_id=True,
):
super().__init__(
dataset_protocols_path,
protocol,
reader_cls=CSVToBioSamples,
transformer=make_pipeline(
VideoReader(data_path=data_path, file_ext=data_extension),
VideoToFrames(range(12,251,24)),
AnnotationsAdder(annotations_path),
data_loader=None,
extension=extension,
dataset_original_directory=dataset_original_directory,
)
self.reference_id_equal_subject_id = reference_id_equal_subject_id
def convert_row_to_sample(self, row, header):
    """Creates a set of samples given a row of the CSV protocol definition.

    Parameters
    ----------
    row: sequence
        One row of the protocol file. ``row[0]`` is the file path (joined to
        ``self.dataset_original_directory`` with ``self.extension``
        appended), ``row[1]`` the reference id and ``row[2]`` the id used to
        build the key of each sample. The remaining columns are metadata.
    header: sequence
        The column names. Names from the fourth column on are lower-cased
        and used as metadata field names.

    Returns
    -------
    list of DelayedSample
        One sample per frame index in ``range(12, 251, 24)``, each loading
        its frame through ``load_frame_from_file_replaymobile``.

    Raises
    ------
    ValueError
        If ``self.reference_id_equal_subject_id`` is false and the row has
        no ``subject_id`` column.
    KeyError
        If the row has no ``capturing_device`` column.
    """
    path, reference_id, sample_id = row[0], row[1], row[2]  # sample_id is used as 'key'

    metadata = {
        str(name).lower(): value for name, value in zip(header[3:], row[3:])
    }
    if self.reference_id_equal_subject_id:
        metadata["subject_id"] = reference_id
    elif "subject_id" not in metadata:
        raise ValueError(f"`subject_id` not available in {header}")

    # Both values are identical for every frame of the row.
    video_file = os.path.join(
        self.dataset_original_directory, path + self.extension
    )
    capturing_device = metadata["capturing_device"]

    # One row leads to multiple samples (different frames)
    return [
        DelayedSample(
            functools.partial(
                load_frame_from_file_replaymobile,
                file_name=video_file,
                frame=frame,
                capturing_device=capturing_device,
            ),
            key=f"{sample_id}_{frame}",
            path=path,
            reference_id=reference_id,
            frame=frame,
            **metadata,
        )
        for frame in range(12, 251, 24)
    ]
class FrameBoundingBoxAnnotationLoader(AnnotationsLoader):
"""A transformer that adds bounding-box to a sample from annotations files.
Parameters
----------
annotation_directory: str or None
"""
def __init__(self, annotation_directory=None, annotation_extension=".face", **kwargs):
    """Forward the annotation settings to the parent ``AnnotationsLoader``.

    Parameters
    ----------
    annotation_directory: str or None
        Passed to the parent initializer unchanged.
    annotation_extension: str
        Passed to the parent initializer unchanged (``".face"`` by default).
    **kwargs
        Any other keyword argument, forwarded to the parent initializer.
    """
    super().__init__(
        annotation_directory=annotation_directory,
        annotation_extension=annotation_extension,
        **kwargs,
    )
self.annotations_path = self.dataset_protocols_path if not annotations_path else annotations_path # TODO default to protocol_path?
self.annotation_type = "eyes-center"
self.fixed_positions = None
def groups(self):
    """Return the group names found for the current protocol.

    The entries listed by ``list_dir`` for ``self.protocol`` under
    ``self.dataset_protocols_path`` (with ``files=False``) are returned with
    their extension, if any, stripped.
    """
    entries = list_dir(self.dataset_protocols_path, self.protocol, files=False)
    return [os.path.splitext(entry)[0] for entry in entries]
def list_file(self, group, purpose):
    """Locate the CSV protocol file of a group for a given purpose.

    ``purpose`` must be ``"enroll"``, ``"probe"`` or ``"train"``; any other
    value raises ``ValueError``. The result is whatever ``search_file``
    returns for ``<protocol>/<group>/<purpose file name>.csv`` inside
    ``self.dataset_protocols_path``.
    """
    known_purposes = (
        ("enroll", "for_models"),
        ("probe", "for_probes"),
        ("train", "train_world"),
    )
    for known_purpose, purpose_name in known_purposes:
        if purpose == known_purpose:
            break
    else:
        raise ValueError(f"Unknown purpose '{purpose}'.")

    # Protocol files are in the form <db_name>/{dev,eval,train}/{for_models,for_probes}.csv
    return search_file(
        self.dataset_protocols_path,
        os.path.join(self.protocol, group, purpose_name + ".csv"),
    )
def transform(self, X):
"""Adds the bounding-box annotations to a series of samples.
"""
if self.annotation_directory is None:
return None
def get_reader(self, group, purpose): # TODO use the standard csv format instead?
key = (self.protocol, group, purpose)
if key not in self.readers:
self.readers[key] = self.reader_cls(
list_file=self.list_file(group, purpose), transformer=self.transformer
annotated_samples = []
for x in X:
# Build the path to the annotation files structure
annotation_file = os.path.join(
self.annotation_directory, x.path + self.annotation_extension
)
reader = self.readers[key]
return reader
annotated_samples.append(
DelayedSample(
x._load,
parent=x,
delayed_attributes=dict(
annotations=functools.partial(
read_frame_annotation_file_replaymobile,
file_name=annotation_file,
frame=int(x.frame),
)
),
)
)
def samples(self, groups, purpose):
groups = check_parameters_for_validity(
groups, "groups", self.groups(), self.groups()
)
all_samples = []
for grp in groups:
return annotated_samples
for sample in self.get_reader(grp, purpose):
all_samples.append(sample)
class ReplayMobileBioDatabase(CSVDataset):
"""Database interface that loads a csv definition for replay-mobile
return all_samples
Looks for the protocol definition files (structure of CSV files). If not
present, downloads them.
Then sets the data and annotation paths from __init__ parameters or from
the configuration (``bob config`` command).
# Training samples: delegates to `samples` with the "train" group and the
# "train" purpose.
def background_model_samples(self):
return self.samples(groups="train", purpose="train")
Parameters
----------
# Enrollment samples of `group`: delegates to `samples` with the "enroll" purpose.
def references(self, group):
return self.samples(groups=group, purpose="enroll")
protocol_name: str
The protocol to use
# Probe samples of `group`: delegates to `samples` with the "probe" purpose.
def probes(self, group):
return self.samples(groups=group, purpose="probe")
protocol_definition_path: str or None
Specifies a path to download the database definition to.
If None: Downloads and uses the ``bob_data_folder`` config.
(See :py:func:`bob.extension.download.get_file`)
# Delegates to the parent class implementation, passing `groups` by keyword.
def all_samples(self, groups):
return super().all_samples(groups=groups)
data_path: str or None
Overrides the config-defined data location.
If None: uses the ``bob.db.replaymobile.directory`` config.
If None and the config does not exist, set as cwd.
annotation_path: str or None
Overrides the config-defined annotation files location.
If None: uses the ``bob.db.replaymobile.annotation_directory`` config.
If None and the config does not exist, set as
``{data_path}/faceloc/rect``.
"""
def __init__(
    self,
    protocol_name="bio-grandtest",
    protocol_definition_path=None,
    data_path=None,
    annotation_path=None,
    **kwargs
):
    """Resolve the definition, data and annotation locations, then build the dataset.

    Parameters
    ----------
    protocol_name: str
        The protocol to use.
    protocol_definition_path: str or None
        Location of the CSV protocol definition. If None, it is obtained
        with ``get_file("replay-mobile-csv.tar.gz", ...)``.
    data_path: str or None
        Location of the video files. If None, the
        ``bob.db.replaymobile.directory`` configuration value is used,
        falling back to an empty string.
    annotation_path: str or None
        Location of the annotation files. If None, the
        ``bob.db.replaymobile.annotation_directory`` configuration value is
        used, falling back to ``{data_path}/faceloc/rect/``.
    **kwargs
        Forwarded to the parent initializer.
    """
    if protocol_definition_path is None:
        # Downloading database description files if it is not specified
        definition_urls = [
            "https://www.idiap.ch/software/bob/databases/latest/replay-mobile-csv.tar.gz",
            "http://www.idiap.ch/software/bob/databases/latest/replay-mobile-csv.tar.gz",
        ]
        protocol_definition_path = get_file(
            "replay-mobile-csv.tar.gz", definition_urls
        )

    if data_path is None:
        # Defaults to cwd if config not defined
        data_path = rc.get("bob.db.replaymobile.directory", "")

    # Must come after data_path is resolved: the fallback is derived from it.
    if annotation_path is None:
        # Defaults to {data_path}/faceloc/rect if config not defined
        fallback_annotation_path = os.path.join(data_path, "faceloc/rect/")
        annotation_path = rc.get(
            "bob.db.replaymobile.annotation_directory",
            fallback_annotation_path,
        )

    logger.info(f"Database: Loading database definition from '{protocol_definition_path}'.")
    logger.info(f"Database: Defining data files path as '{data_path}'.")
    logger.info(f"Database: Defining annotation files path as '{annotation_path}'.")

    # Rows are first turned into per-frame samples, then annotated.
    frame_loader = ReplayMobileCSVFrameSampleLoader(
        dataset_original_directory=data_path,
        extension=".mov",
    )
    annotation_loader = FrameBoundingBoxAnnotationLoader(
        annotation_directory=annotation_path,
        annotation_extension=".face",
    )
    super().__init__(
        protocol_definition_path,
        protocol_name,
        csv_to_sample_loader=make_pipeline(frame_loader, annotation_loader),
        **kwargs
    )

    self.annotation_type = "bounding-box"
    self.fixed_positions = None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment