Skip to content
Snippets Groups Projects
Commit a9b50665 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

fixed casiasurf high-level implementation, put it in __init__ and as an entry point

parent d40132d2
No related branches found
No related tags found
1 merge request!77CASIA-SURF database
Pipeline #26002 passed
......@@ -7,6 +7,7 @@ from .mifs import MIFSPadDatabase
from .batl import BatlPadDatabase
from .celeb_a import CELEBAPadDatabase
from .maskattack import MaskAttackPadDatabase
from .casiasurf import CasiaSurfPadDatabase
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
......
......@@ -24,7 +24,7 @@ class CasiaSurfPadFile(VideoPadFile):
"""
def __init__(self, f, stream_type):
def __init__(self, f, stream_type, attack_type):
""" Init
Parameters
......@@ -63,7 +63,7 @@ class CasiaSurfPadFile(VideoPadFile):
"""
# get the dict of numpy array
data = self.f.load(directory, extension, modality=self.modality)
data = self.f.load(directory, extension, modality=self.stream_type)
# convert that to dict of FrameContainer
data_to_return = {}
......@@ -93,21 +93,35 @@ class CasiaSurfPadDatabase(PadDatabase):
the group names in the high-level interface (train, dev, eval)
"""
def __init__(self, protocol='all', original_directory=None, original_extension='.jpg', **kwargs):
"""Init function
Parameters
----------
protocol : :py:class:`str`
The name of the protocol that defines the default experimental setup for this database.
original_directory : :py:class:`str`
The directory where the original data of the database are stored.
original_extension : :py:class:`str`
The file name extension of the original data.
from bob.db.casiasurf import Database as LowLevelDatabase
self.db = LowLevelDatabase()
"""
from bob.db.casiasurf import Database as LowLevelDatabase
self.db = LowLevelDatabase()
self.low_level_group_names = ('train', 'validation', 'test')
self.high_level_group_names = ('train', 'dev', 'eval')
self.low_level_group_names = ('train', 'validation', 'test')
self.high_level_group_names = ('train', 'dev', 'eval')
super(CasiaSurfPadDatabase, self).__init__(
name='casiasurf',
protocol=protocol,
original_directory=original_directory,
original_extension=original_extension,
**kwargs)
super(CasiaSurfPadDatabase, self).__init__(
name='casiasurf',
protocol=protocol,
original_directory=original_directory,
original_extension=original_extension,
**kwargs)
@property
@property
def original_directory(self):
return self.db.original_directory
......@@ -160,8 +174,8 @@ class CasiaSurfPadDatabase(PadDatabase):
if ('dev' in groups or 'test' in groups) and purposes == 'attack':
lowlevel_purposes.append('unknown')
samples = self.db.objects(sets=groups, purposes=lowlevel_purposes, **kwargs)
samples = [CasiaSurfPadFile(s) for s in samples]
samples = self.db.objects(groups=groups, purposes=lowlevel_purposes, **kwargs)
samples = [CasiaSurfPadFile(s, stream_type=protocol, attack_type=s.attack_type) for s in samples]
return samples
......
......@@ -74,6 +74,7 @@ setup(
'batl-db-thermal = bob.pad.face.config.batl_db_thermal:database',
'celeb-a = bob.pad.face.config.celeb_a:database',
'maskattack = bob.pad.face.config.maskattack:database',
'casiasurf = bob.pad.face.config.casiasurf:database',
],
# registered configurations:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment