Skip to content
Snippets Groups Projects
Commit 27e2f462 authored by Ketan Kotwal's avatar Ketan Kotwal
Browse files

code cleanup for databases and their configs

parent f04958a7
No related branches found
No related tags found
No related merge requests found
......@@ -8,17 +8,13 @@ mask-based presentation attacks
# Configuration module for PAD experiments on the MLFP database.
# Data paths are resolved through bob's global resource configuration
# (``bob.extension.rc``), so no paths are hard-coded here.
from bob.paper.nir_patch_pooling.database import MLFPDatabase
from bob.extension import rc
# Protocol name, reused both for the database object and for the
# module-level ``protocol`` entry consumed by the experiment runner.
PROTOCOL = "grandtest"
database = MLFPDatabase(
    protocol=PROTOCOL,
    original_directory=rc["bob.db.mlfp.directory"],  # location of raw MLFP data
    original_extension=".hdf5",
    annotation_directory=rc["bob.db.mlfp.annotation_directory"],  # precomputed annotations
    training_depends_on_protocol=True,
)
# Groups evaluated by the pipeline; this config uses no "eval" group.
groups = ["train", "dev"]
protocol = PROTOCOL
#------------------------------------------------------------------------------
......@@ -9,18 +9,18 @@ from bob.paper.nir_patch_pooling.database import WMCAMask
# Configuration module for PAD experiments on the WMCA Mask database.
# Paths come from bob's global resource configuration (``bob.extension.rc``).
from bob.extension import rc
# NOTE(review): two successive PROTOCOL assignments — the second one wins,
# so the effective protocol is "grandtest"; confirm "grandtest-nir-50" was
# meant to be removed (this looks like a leftover from an edit).
PROTOCOL = "grandtest-nir-50"
PROTOCOL = "grandtest"
# ``WMCAMask`` is expected to be in scope from an import earlier in this
# file (not visible in this span) — TODO confirm.
database = WMCAMask(
    protocol = PROTOCOL,
    original_directory = rc["bob.db.wmca_mask.directory"],  # raw WMCA data
    annotation_directory = rc["bob.db.wmca_mask.annotation_directory"],
    landmark_detect_method = "mtcnn"  # face landmark detector used for annotations
)
# Groups evaluated by the pipeline: train, dev and eval.
groups = ["train", "dev", "eval"]
protocol = PROTOCOL
#---
#------------------------------------------------------------------------------
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Implementation of dataset interface of MLFP for PAD.
Implementation of database interface of MLFP dataset for Face PAD.
This protocol caters to only NIR subset of MLFP dataset.
@author: Ketan Kotwal
"""
# Imports
from bob.pad.base.database import FileListPadDatabase
from bob.pad.base.database import PadFile
from bob.pad.face.database import VideoPadFile
from bob.bio.video import FrameSelector, FrameContainer
import bob.io.base
......@@ -23,63 +22,41 @@ import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class RGBPadFile(VideoPadFile):
"""
A high level implementation of the File class for the RGB images database.
"""
#------------------------------------------------------------------------------
class File(VideoPadFile):
def __init__(self, attack_type, client_id, path, file_id=None):
super(RGBPadFile, self).__init__(attack_type, client_id, path, file_id)
#----------------------------------------------------------
def load(self, directory=None, extension=None, frame_selector=FrameSelector(selection_style='all')):
"""
Overridden version of the load method defined in the ``VideoPadFile``.
**Parameters:**
``directory`` : :py:class:`str`
String containing the path to the MIFS database.
Default: None
``extension`` : :py:class:`str`
Extension of the video files in the MIFS database.
Default: None
``frame_selector`` : ``FrameSelector``
The frame selector to use.
**Returns:**
``video_data`` : FrameContainer
Video data stored in the FrameContainer, see ``bob.bio.video.utils.FrameContainer``
for further details.
"""
super(File, self).__init__(attack_type, client_id, path, file_id)
#------------------------------------------------------------------------------
def load(self, directory=None, extension=None,
frame_selector=FrameSelector(selection_style='all')):
path = self.make_path(directory=directory, extension=extension)
hdf_file = h5py.File(path)
fc = FrameContainer()
data = h5py.File(path)
for i, data1 in enumerate(data.keys()):
frame = data[data1]["array"].value
fc.add(i, frame, None)
for idx, frame_data in enumerate(hdf_file.keys()):
frame = hdf_file[frame_data]["array"].value
fc.add(idx, frame, None)
data.close()
#data = np.expand_dims(data, axis=0) # upgrade to 4D (video)
#video_data = frame_selector(data) # video data
hdf_file.close()
return fc
#----------------------------------------------------------
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#----------------------------------------------------------
class MLFPDatabase(FileListPadDatabase):
"""
A high level implementation of the Database class for the MLFP
dataset.
A high level implementation of the Database class for NIR content of the
MLFP dataset.
"""
def __init__(
......@@ -87,31 +64,38 @@ class MLFPDatabase(FileListPadDatabase):
name = "MLFP",
original_directory = None,
original_extension = ".hdf5",
protocol = "grandtest",
annotation_directory = None,
pad_file_class = RGBPadFile,
**kwargs):
pad_file_class = File,
**kwargs
):
self.annotation_directory = annotation_directory
filelists_directory = pkg_resources.resource_filename( __name__, "/lists/mlfp/")
self.filelists_directory = filelists_directory
super(MLFPDatabase, self).__init__(
filelists_directory=filelists_directory,
name=name,
protocol=protocol,
original_directory=original_directory,
original_extension=original_extension,
pad_file_class=pad_file_class,
annotation_directory=annotation_directory,
**kwargs)
filelists_directory = filelists_directory,
name = name,
original_directory = original_directory,
original_extension = original_extension,
annotation_directory = annotation_directory,
pad_file_class = pad_file_class,
**kwargs,
)
self.annotation_directory = annotation_directory
logger.info("Dataset: {}".format(self.name))
logger.info("Original directory: {}; Annotation directory: {}".format(self.original_directory, self.annotation_directory))
logger.info("Original directory: {}; Annotation directory: {}"\
.format(self.original_directory, self.annotation_directory))
#------------------------------------------------------------------------------
#----------------------------------------------------------
def annotations(self, f):
if self.annotation_directory is None:
raise ValueError("Annotation Directory is not provided.")
file_path = os.path.join(self.annotation_directory, f.path + ".json")
# if file exists, load the annotations
......@@ -123,12 +107,13 @@ class MLFPDatabase(FileListPadDatabase):
if not annotations: # if dictionary is empty
logger.warning("Empty annotations for %s", f.path)
return None
return annotations
else:
else:
logger.warning("Annotation file for %s does not exist. (Overall path: %s)", f.path, file_path)
return None
#----------------------------------------------------------
#------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
......@@ -19,13 +18,13 @@ import json
import os
import bob.io.base
import pkg_resources
#---
# Constants
valid_protocols = ["grandtest", "cv"]
stream_batl_convention = {"nir" : "infrared", "nirhq" : "infrared_high_def"}
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
#------------------------------------------------------------------------------
#--------
class File(VideoFile):
def __init__(self, path, client_id, session_id, presenter_id, type_id, pai_id):
......@@ -36,26 +35,35 @@ class File(VideoFile):
session_id = session_id,
presenter_id = presenter_id,
type_id = type_id,
pai_id = pai_id)
pai_id = pai_id
)
self.id = path
#------------------------------------------------------------------------------
def load(self, directory=None, extension=None,
# auxiliary function to load existing preprocessed files stored as *.hdf5
def load_aux(self, directory=None, extension=None,
frame_selector=FrameSelector()):
path = self.make_path(directory, extension)
print(path)
# if loading a preprocessed data
if path.endswith('hdf5'):
with bob.io.base.HDF5File(path) as f:
return FrameContainer(hdf5=f)
else:
raise NotImplementedError
raise NotImplementedError("No loading method for {} extension"\
.format(extension))
#---------
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
class WMCAMask(FileListPadDatabase):
class MaskDatabase(FileListPadDatabase):
"""
A high level implementation of the Database class for the WMCA Mask PAD
database.
......@@ -63,15 +71,14 @@ class MaskDatabase(FileListPadDatabase):
def __init__(
self,
name="Mask Database",
original_directory=None,
original_extension=".h5",
protocol="grandtest",
annotation_directory=None,
name = "WMCAMask",
original_directory = None,
original_extension = ".h5",
protocol = "grandtest",
annotation_directory = None,
pad_file_class = BatlPadFile,
low_level_pad_file_class = File,
landmark_detect_method = "mtcnn",
**kwargs
**kwargs,
):
......@@ -79,30 +86,26 @@ class MaskDatabase(FileListPadDatabase):
**Parameters:**
``original_directory`` : str or None
original directory refers to the location of XCSMAD parent directory
original directory refers to the location of WMCA (or WMCA Mask)
parent directory
``original_extension`` : str or None
extension of original data
``groups`` : str or [str]
The groups for which the clients should be returned.
Usually, groups are one or more elements of ['train', 'dev', 'eval'].
Default: ['train', 'dev', 'eval'].
``protocol`` : str
The protocol for which the clients should be retrieved.
Default: 'grandtest'.
Default: 'grandtest-nir-50'.
"""
filelists_directory = pkg_resources.resource_filename( __name__, "lists/")
filelists_directory = pkg_resources.resource_filename( __name__, "lists/wmca_mask/")
self.filelists_directory = filelists_directory
# init the parent class using super.
super(MaskDatabase, self).__init__(
super(WMCAMask, self).__init__(
filelists_directory = filelists_directory,
name = name,
protocol=protocol,
protocol = protocol,
original_directory = original_directory,
original_extension = original_extension,
pad_file_class = low_level_pad_file_class,
......@@ -112,39 +115,22 @@ class MaskDatabase(FileListPadDatabase):
self.low_level_pad_file_class = low_level_pad_file_class
self.pad_file_class = pad_file_class
self.annotation_directory = annotation_directory
self.landmark_detect_method = landmark_detect_method
self.protocol = protocol
#----
def _split_protocol(self, protocol):
logger.info("Dataset: {}".format(self.name))
logger.info("Original directory: {}; Annotation directory: {}"\
.format(self.original_directory, self.annotation_directory))
split_protocol = protocol.split("-") + [None, None]
#------------------------------------------------------------------------------
protocol = split_protocol[0]
stream_type = split_protocol[1]
if(stream_type in stream_batl_convention):
stream_type = stream_batl_convention[stream_type]
else:
stream_type = "color"
num_frames = split_protocol[2]
if(num_frames is not None):
num_frames = int(num_frames)
else:
num_frames = 50
return protocol, stream_type, num_frames
#----
# override the _make_pad function in bob.pad.base since we want the PAD
# files to be of BatlPad class
# override the _make_pad function in bob.pad.base
def _make_pad(self, files):
low_level_files = []
video_pad_files = []
for f in files:
path = f.path
client_id = f.client_id
info = path.split("/")[-1]
......@@ -152,19 +138,24 @@ class MaskDatabase(FileListPadDatabase):
presenter_id = int(info.split("_")[2])
type_id = int(info.split("_")[3])
pai_id = int(info.split("_")[4])
v_file = self.low_level_pad_file_class(path=path,
client_id=client_id,
session_id=session_id,
presenter_id=presenter_id,
type_id=type_id,
pai_id=pai_id)
low_level_files.append(v_file)
video_file = self.low_level_pad_file_class(
path = path,
client_id = client_id,
session_id = session_id,
presenter_id = presenter_id,
type_id = type_id,
pai_id = pai_id
)
video_pad_files.append(video_file)
return low_level_files
#----
return video_pad_files
def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
#------------------------------------------------------------------------------
def objects(self, groups=None, protocol=None, purposes=None,
model_ids=None, **kwargs):
# default the parameters if the values are not provided
if protocol is None:
......@@ -174,33 +165,32 @@ class MaskDatabase(FileListPadDatabase):
groups = ["train", "dev", "eval"]
# parse the protocol to extract necessary information
protocol, stream_type, num_frames = self._split_protocol(protocol)
protocol = protocol
stream_type = "nir"
num_frames = 50
# obtain the file list using the parent class's functionality
files = super(MaskDatabase, self).objects(groups=groups, protocol=protocol, purposes=purposes, model_ids=model_ids, **kwargs)
files = super(WMCAMask, self).objects(groups=groups, protocol=protocol,
purposes=purposes, model_ids=model_ids, **kwargs)
# create objects for each file where the class is BATLPadFile
files = [self.pad_file_class(f=f, stream_type=stream_type, max_frames=num_frames) for f in files]
files = [self.pad_file_class(f=f, stream_type=stream_type,
max_frames=num_frames) for f in files]
return files
#----
#------------------------------------------------------------------------------
def annotations(self, f):
"""
Computes annotations for a given file object ``f``, which
is an instance of the ``BatlPadFile`` class.
NOTE: you can pre-compute annotation in your first experiment
and then reuse them in other experiments setting
``self.annotations_temp_dir`` path of this class, where
precomputed annotations will be saved.
Returns annotations for a given file object ``f``.
The annotations must be precomputed using the script provided with the
package.
**Parameters:**
``f`` : :py:class:`object`
An instance of ``BatlPadFile`` defined above.
An instance of file object defined above.
**Returns:**
......@@ -214,45 +204,28 @@ class MaskDatabase(FileListPadDatabase):
face bounding box and landmarks in frame N.
"""
file_path = os.path.join(self.annotation_directory, f.f.path + ".json")
# if annotations do not exist, then generate.
if not os.path.isfile(file_path):
if self.annotation_directory is None:
raise ValueError("Annotation Directory is not provided.")
f.stream_type = "color"
f.reference_stream_type = "color"
f.warp_to_reference = False
f.convert_to_rgb = False
f.crop = None
f.video_data_only = True
video = f.load(directory=self.original_directory, extension=self.original_extension)
annotations = {}
for idx, image in enumerate(video.as_array()):
frame_annotations = detect_face_landmarks_in_image(image, method=self.landmark_detect_method)
if frame_annotations:
annotations[str(idx)] = frame_annotations
if self.annotation_directory: # if directory is not an empty string
bob.io.base.create_directories_safe(directory=os.path.split(file_path)[0], dryrun=False)
with open(file_path, 'w+') as json_file:
json_file.write(json.dumps(annotations))
file_path = os.path.join(self.annotation_directory, f.f.path + ".json")
# if annotations exist, then load.
else:
if os.path.isfile(file_path):
with open(file_path, 'r') as json_file:
with open(file_path, "r") as json_file:
annotations = json.load(json_file)
if not annotations: # if dictionary is empty
logger.warning("Empty annotations for %s", f.path)
return None
return annotations
# if dictionary is empty
if not annotations:
else:
logger.warning("Annotation file for %s does not exist. (Overall path: %s)", f.path, file_path)
return None
return annotations
#-----
#------------------------------------------------------------------------------
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment