diff --git a/bob/pad/face/config/casiasurf.py b/bob/pad/face/config/casiasurf.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8489402b80959eb5a2307d896c5888f34fb451 --- /dev/null +++ b/bob/pad/face/config/casiasurf.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# encoding: utf-8 + +from bob.pad.face.database import CasiaSurfPadDatabase +from bob.extension import rc + +database = CasiaSurfPadDatabase( + protocol='all', + original_directory=rc['bob.db.casiasurf.directory'], + original_extension=".jpg", +) diff --git a/bob/pad/face/config/casiasurf_color.py b/bob/pad/face/config/casiasurf_color.py new file mode 100644 index 0000000000000000000000000000000000000000..4d7cde9212006855cff26726c3fa3e70accc3de7 --- /dev/null +++ b/bob/pad/face/config/casiasurf_color.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# encoding: utf-8 + +from bob.pad.face.database import CasiaSurfPadDatabase +from bob.extension import rc + +database = CasiaSurfPadDatabase( + protocol='color', + original_directory=rc['bob.db.casiasurf.directory'], + original_extension=".jpg", +) diff --git a/bob/pad/face/database/__init__.py b/bob/pad/face/database/__init__.py index 217442075022a4dbcaed66cdd292c9dfff6d223d..753f5ef010ae763e4886a7213d05a2a327aeeabd 100644 --- a/bob/pad/face/database/__init__.py +++ b/bob/pad/face/database/__init__.py @@ -7,6 +7,7 @@ from .mifs import MIFSPadDatabase from .batl import BatlPadDatabase from .celeb_a import CELEBAPadDatabase from .maskattack import MaskAttackPadDatabase +from .casiasurf import CasiaSurfPadDatabase # gets sphinx autodoc done right - don't remove it def __appropriate__(*args): @@ -33,7 +34,8 @@ __appropriate__( MIFSPadDatabase, BatlPadDatabase, CELEBAPadDatabase, - MaskAttackPadDatabase + MaskAttackPadDatabase, + CasiaSurfPadDatabase ) __all__ = [_ for _ in dir() if not _.startswith('_')] diff --git a/bob/pad/face/database/casiasurf.py b/bob/pad/face/database/casiasurf.py new file mode 100644 index 0000000000000000000000000000000000000000..6590c3af4eaa3c98b5a2e3dc12aa78e4a09214ea --- /dev/null +++ b/bob/pad/face/database/casiasurf.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +import os +import numpy as np +import bob.io.video +from bob.bio.video import FrameSelector, FrameContainer +from bob.pad.face.database import VideoPadFile +from bob.pad.base.database import PadDatabase + +from bob.extension import rc + +class CasiaSurfPadFile(VideoPadFile): + """ + A high level implementation of the File class for the CASIA-SURF database. + + Note that this does not represent a file per se, but rather a sample + that may contain more than one file. + + Attributes + ---------- + f : :py:class:`object` + An instance of the Sample class defined in the low level db interface + of the CASIA-SURF database, in the bob.db.casiasurf.models.py file. + + """ + + def __init__(self, s, stream_type): + """ Init + + Parameters + ---------- + s : :py:class:`object` + An instance of the Sample class defined in the low level db interface + of the CASIA-SURF database, in the bob.db.casiasurf.models.py file. + stream_type: str of list of str + The streams to be loaded. + """ + self.s = s + self.stream_type = stream_type + if not isinstance(s.attack_type, str): + attack_type = str(s.attack_type) + else: + attack_type = s.attack_type + + super(CasiaSurfPadFile, self).__init__( + client_id=s.id, + file_id=s.id, + attack_type=attack_type, + path=s.id) + + + def load(self, directory=rc['bob.db.casiasurf.directory'], extension='.jpg', frame_selector=FrameSelector(selection_style='all')): + """Overloaded version of the load method defined in ``VideoPadFile``. + + Parameters + ---------- + directory : :py:class:`str` + String containing the path to the CASIA-SURF database + extension : :py:class:`str` + Extension of the image files + frame_selector : :py:class:`bob.bio.video.FrameSelector` + The frame selector to use. + + Returns + ------- + dict: + image data for multiple streams stored in the dictionary. + The structure of the dictionary: ``data={"stream1_name" : numpy array, "stream2_name" : numpy array}`` + Names of the streams are defined in ``self.stream_type``. + """ + return self.s.load(directory, extension, modality=self.stream_type) + + +class CasiaSurfPadDatabase(PadDatabase): + """High level implementation of the Database class for the 3DMAD database. + + Note that at the moment, this database only contains a training and validation set. + + The protocol specifies the modality(ies) to load. + + Attributes + ---------- + db : :py:class:`bob.db.casiasurf.Database` + the low-level database interface + low_level_group_names : list of :py:obj:`str` + the group names in the low-level interface (world, dev, test) + high_level_group_names : list of :py:obj:`str` + the group names in the high-level interface (train, dev, eval) + + """ + + def __init__(self, protocol='all', original_directory=rc['bob.db.casiasurf.directory'], original_extension='.jpg', **kwargs): + """Init function + + Parameters + ---------- + protocol : :py:class:`str` + The name of the protocol that defines the default experimental setup for this database. + original_directory : :py:class:`str` + The directory where the original data of the database are stored. + original_extension : :py:class:`str` + The file name extension of the original data. + + """ + + from bob.db.casiasurf import Database as LowLevelDatabase + self.db = LowLevelDatabase() + + self.low_level_group_names = ('train', 'validation', 'test') + self.high_level_group_names = ('train', 'dev', 'eval') + + super(CasiaSurfPadDatabase, self).__init__( + name='casiasurf', + protocol=protocol, + original_directory=original_directory, + original_extension=original_extension, + **kwargs) + + @property + def original_directory(self): + return self.db.original_directory + + + @original_directory.setter + def original_directory(self, value): + self.db.original_directory = value + + def objects(self, + groups=None, + protocol='all', + purposes=None, + model_ids=None, + **kwargs): + """Returns a list of CasiaSurfPadFile objects, which fulfill the given restrictions. + + Parameters + ---------- + groups : list of :py:class:`str` + The groups of which the clients should be returned. + Usually, groups are one or more elements of ('train', 'dev', 'eval') + protocol : :py:class:`str` + The protocol for which the samples should be retrieved. + purposes : :py:class:`str` + The purposes for which Sample objects should be retrieved. + Usually it is either 'real' or 'attack', but could be 'unknown' as well + model_ids + This parameter is not supported in PAD databases yet. + + Returns + ------- + samples : :py:class:`CasiaSurfPadFilePadFile` + A list of CasiaSurfPadFile objects. + """ + + groups = self.convert_names_to_lowlevel(groups, self.low_level_group_names, self.high_level_group_names) + + if groups is not None: + + # for training + lowlevel_purposes = [] + if 'train' in groups and purposes == 'real': + lowlevel_purposes.append('real') + if 'train' in groups and purposes == 'attack': + lowlevel_purposes.append('attack') + + # for dev and eval + if ('validation' in groups or 'test' in groups) and 'attack' in purposes: + lowlevel_purposes.append('unknown') + + samples = self.db.objects(groups=groups, purposes=lowlevel_purposes, **kwargs) + samples = [CasiaSurfPadFile(s, stream_type=protocol) for s in samples] + return samples + + + def annotations(self, file): + """No annotations are provided with this DB + """ + return None diff --git a/bob/pad/face/test/test_databases.py b/bob/pad/face/test/test_databases.py index 38fe4510cdfed0372d14f62a612493ca2c16bf96..3ed9051075549159ce82de3f51c8ffe3ab4f2289 100644 --- a/bob/pad/face/test/test_databases.py +++ b/bob/pad/face/test/test_databases.py @@ -200,3 +200,22 @@ def test_aggregated_db(): "The database could not be queried; probably the db.sql3 file is missing. Here is the error: '%s'" % e) + +# Test the casiasurf database +@db_available('casiasurf') +def test_casiasurf(): + casiasurf = bob.bio.base.load_resource( + 'casiasurf', + 'database', + preferred_package='bob.pad.face', + package_prefix='bob.pad.') + try: + assert len(casiasurf.objects(groups=['train', 'dev'], purposes='real')) == 8942 + assert len(casiasurf.objects(groups=['train'], purposes='attack')) == 20324 + assert len(casiasurf.objects(groups=['dev'], purposes='real')) == 0 + assert len(casiasurf.objects(groups=['dev'], purposes='attack')) == 9608 + + except IOError as e: + raise SkipTest( + "The database could not be queried; probably the db.sql3 file is missing. Here is the error: '%s'" + % e) diff --git a/conda/meta.yaml b/conda/meta.yaml index 15ea77a3ace3461765597ad4f14b5a686e17be38..3283e022c579f3688d430d5a5a5fb8f1c57a811c 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -65,6 +65,8 @@ test: - bob.db.replaymobile - bob.db.msu_mfsd_mod - bob.db.mobio + - bob.db.maskattack + - bob.db.casiasurf - bob.rppg.base about: diff --git a/setup.py b/setup.py index bdd290b8954e7a34de517cdd0ff11462021d1ff4..2b2339db588cfdf150d9712353b824b0bc80eb12 100644 --- a/setup.py +++ b/setup.py @@ -75,6 +75,8 @@ setup( 'batl-db-rgb-ir-d-grandtest = bob.pad.face.config.database.batl.batl_db_rgb_ir_d_grandtest:database', 'celeb-a = bob.pad.face.config.celeb_a:database', 'maskattack = bob.pad.face.config.maskattack:database', + 'casiasurf-color = bob.pad.face.config.casiasurf_color:database', + 'casiasurf = bob.pad.face.config.casiasurf:database', ], # registered configurations: diff --git a/test-requirements.txt b/test-requirements.txt index 306a471cc175f084f79b51a77535a02ff7b7f44c..a2ec5b9057eaa8e301cd821942792508b416f054 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -3,3 +3,5 @@ bob.db.replay bob.db.replaymobile bob.db.msu_mfsd_mod bob.db.mobio +bob.db.maskattack +bob.db.casiasurf