From a9b50665d52f0889c361a3477ff8e4702910db5c Mon Sep 17 00:00:00 2001
From: Guillaume HEUSCH <guillaume.heusch@idiap.ch>
Date: Wed, 16 Jan 2019 11:25:03 +0100
Subject: [PATCH] fixed casiasurf high-level implementation, put it in __init__
 and as an entry point

---
 bob/pad/face/database/__init__.py  |  1 +
 bob/pad/face/database/casiasurf.py | 44 ++++++++++++++++++++----------
 setup.py                           |  1 +
 3 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/bob/pad/face/database/__init__.py b/bob/pad/face/database/__init__.py
index 21744207..eb70097d 100644
--- a/bob/pad/face/database/__init__.py
+++ b/bob/pad/face/database/__init__.py
@@ -7,6 +7,7 @@ from .mifs import MIFSPadDatabase
 from .batl import BatlPadDatabase
 from .celeb_a import CELEBAPadDatabase
 from .maskattack import MaskAttackPadDatabase
+from .casiasurf import CasiaSurfPadDatabase
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
diff --git a/bob/pad/face/database/casiasurf.py b/bob/pad/face/database/casiasurf.py
index b285fc63..9c609fb4 100644
--- a/bob/pad/face/database/casiasurf.py
+++ b/bob/pad/face/database/casiasurf.py
@@ -24,7 +24,7 @@ class CasiaSurfPadFile(VideoPadFile):
     
     """
 
-    def __init__(self, f, stream_type):
+    def __init__(self, f, stream_type, attack_type):
       """ Init
 
       Parameters
@@ -63,7 +63,7 @@ class CasiaSurfPadFile(VideoPadFile):
         """
         
         # get the dict of numpy array
-        data = self.f.load(directory, extension, modality=self.modality)
+        data = self.f.load(directory, extension, modality=self.stream_type)
       
         # convert that to dict of FrameContainer
         data_to_return = {}
@@ -93,21 +93,35 @@ class CasiaSurfPadDatabase(PadDatabase):
       the group names in the high-level interface (train, dev, eval)
 
     """
+       
+    def __init__(self, protocol='all', original_directory=None, original_extension='.jpg', **kwargs):
+      """Init function
+
+        Parameters
+        ----------
+        protocol : :py:class:`str`
+          The name of the protocol that defines the default experimental setup for this database.
+        original_directory : :py:class:`str`
+          The directory where the original data of the database are stored.
+        original_extension : :py:class:`str`
+          The file name extension of the original data.
         
-    from bob.db.casiasurf import Database as LowLevelDatabase
-    self.db = LowLevelDatabase()
+      """
+
+      from bob.db.casiasurf import Database as LowLevelDatabase
+      self.db = LowLevelDatabase()
 
-    self.low_level_group_names = ('train', 'validation', 'test')  
-    self.high_level_group_names = ('train', 'dev', 'eval')
+      self.low_level_group_names = ('train', 'validation', 'test')  
+      self.high_level_group_names = ('train', 'dev', 'eval')
 
-    super(CasiaSurfPadDatabase, self).__init__(
-        name='casiasurf',
-        protocol=protocol,
-        original_directory=original_directory,
-        original_extension=original_extension,
-        **kwargs)
+      super(CasiaSurfPadDatabase, self).__init__(
+          name='casiasurf',
+          protocol=protocol,
+          original_directory=original_directory,
+          original_extension=original_extension,
+          **kwargs)
 
-     @property
+    @property
     def original_directory(self):
         return self.db.original_directory
 
@@ -160,8 +174,8 @@ class CasiaSurfPadDatabase(PadDatabase):
           if ('dev' in groups or 'test' in groups) and purposes == 'attack':
             lowlevel_purposes.append('unknown')
 
-        samples = self.db.objects(sets=groups, purposes=lowlevel_purposes, **kwargs)
-        samples = [CasiaSurfPadFile(s) for s in samples]
+        samples = self.db.objects(groups=groups, purposes=lowlevel_purposes, **kwargs)
+        samples = [CasiaSurfPadFile(s, stream_type=protocol, attack_type=s.attack_type) for s in samples]
 
         return samples
 
diff --git a/setup.py b/setup.py
index 89ae3f14..0f6ffa62 100644
--- a/setup.py
+++ b/setup.py
@@ -74,6 +74,7 @@ setup(
             'batl-db-thermal = bob.pad.face.config.batl_db_thermal:database',
             'celeb-a = bob.pad.face.config.celeb_a:database',
             'maskattack = bob.pad.face.config.maskattack:database',
+            'casiasurf = bob.pad.face.config.casiasurf:database',
         ],
 
         # registered configurations:
-- 
GitLab