From a9b50665d52f0889c361a3477ff8e4702910db5c Mon Sep 17 00:00:00 2001 From: Guillaume HEUSCH <guillaume.heusch@idiap.ch> Date: Wed, 16 Jan 2019 11:25:03 +0100 Subject: [PATCH] fixed casiasurf high-level implementation, put it in __init__ and as an entry point --- bob/pad/face/database/__init__.py | 1 + bob/pad/face/database/casiasurf.py | 44 ++++++++++++++++++++---------- setup.py | 1 + 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/bob/pad/face/database/__init__.py b/bob/pad/face/database/__init__.py index 21744207..eb70097d 100644 --- a/bob/pad/face/database/__init__.py +++ b/bob/pad/face/database/__init__.py @@ -7,6 +7,7 @@ from .mifs import MIFSPadDatabase from .batl import BatlPadDatabase from .celeb_a import CELEBAPadDatabase from .maskattack import MaskAttackPadDatabase +from .casiasurf import CasiaSurfPadDatabase # gets sphinx autodoc done right - don't remove it def __appropriate__(*args): diff --git a/bob/pad/face/database/casiasurf.py b/bob/pad/face/database/casiasurf.py index b285fc63..9c609fb4 100644 --- a/bob/pad/face/database/casiasurf.py +++ b/bob/pad/face/database/casiasurf.py @@ -24,7 +24,7 @@ class CasiaSurfPadFile(VideoPadFile): """ - def __init__(self, f, stream_type): + def __init__(self, f, stream_type, attack_type): """ Init Parameters @@ -63,7 +63,7 @@ class CasiaSurfPadFile(VideoPadFile): """ # get the dict of numpy array - data = self.f.load(directory, extension, modality=self.modality) + data = self.f.load(directory, extension, modality=self.stream_type) # convert that to dict of FrameContainer data_to_return = {} @@ -93,21 +93,35 @@ class CasiaSurfPadDatabase(PadDatabase): the group names in the high-level interface (train, dev, eval) """ + + def __init__(self, protocol='all', original_directory=None, original_extension='.jpg', **kwargs): + """Init function + + Parameters + ---------- + protocol : :py:class:`str` + The name of the protocol that defines the default experimental setup for this database. + original_directory : :py:class:`str` + The directory where the original data of the database are stored. + original_extension : :py:class:`str` + The file name extension of the original data. - from bob.db.casiasurf import Database as LowLevelDatabase - self.db = LowLevelDatabase() + """ + + from bob.db.casiasurf import Database as LowLevelDatabase + self.db = LowLevelDatabase() - self.low_level_group_names = ('train', 'validation', 'test') - self.high_level_group_names = ('train', 'dev', 'eval') + self.low_level_group_names = ('train', 'validation', 'test') + self.high_level_group_names = ('train', 'dev', 'eval') - super(CasiaSurfPadDatabase, self).__init__( - name='casiasurf', - protocol=protocol, - original_directory=original_directory, - original_extension=original_extension, - **kwargs) + super(CasiaSurfPadDatabase, self).__init__( + name='casiasurf', + protocol=protocol, + original_directory=original_directory, + original_extension=original_extension, + **kwargs) - @property + @property def original_directory(self): return self.db.original_directory @@ -160,8 +174,8 @@ class CasiaSurfPadDatabase(PadDatabase): if ('dev' in groups or 'test' in groups) and purposes == 'attack': lowlevel_purposes.append('unknown') - samples = self.db.objects(sets=groups, purposes=lowlevel_purposes, **kwargs) - samples = [CasiaSurfPadFile(s) for s in samples] + samples = self.db.objects(groups=groups, purposes=lowlevel_purposes, **kwargs) + samples = [CasiaSurfPadFile(s, stream_type=protocol, attack_type=s.attack_type) for s in samples] return samples diff --git a/setup.py b/setup.py index 89ae3f14..0f6ffa62 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,7 @@ setup( 'batl-db-thermal = bob.pad.face.config.batl_db_thermal:database', 'celeb-a = bob.pad.face.config.celeb_a:database', 'maskattack = bob.pad.face.config.maskattack:database', + 'casiasurf = bob.pad.face.config.casiasurf:database', ], # registered configurations: -- GitLab