From dafa65b07d2f1fa96ee60a77dd17c8a4fb8be45b Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Mon, 4 Sep 2017 14:38:46 +0200
Subject: [PATCH] Load only what is annotated (allow flexibility)

---
 bob/bio/vein/script/markdet.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/bob/bio/vein/script/markdet.py b/bob/bio/vein/script/markdet.py
index a59b48c..c15bb50 100644
--- a/bob/bio/vein/script/markdet.py
+++ b/bob/bio/vein/script/markdet.py
@@ -99,7 +99,7 @@ def validate(args):
     '--batch': schema.Use(int),
     '--iterations': schema.Use(int),
     '<database>': lambda n: n in ('fv3d',),
-    '<protocol>': lambda n: n in ('central',),
+    '<protocol>': lambda n: n in ('central', 'left', 'right'),
     '<group>': lambda n: n in ('dev',),
     str: object, #ignores strings we don't care about
     }, ignore_extra_keys=True)
@@ -142,17 +142,20 @@ def main(user_input=None):
   from ..preprocessor.utils import poly_to_mask
   features = None
   target = None
+  loaded = 0
   for k, sample in enumerate(objects):
 
-    if args['--samples'] is not None and k >= args['--samples']: break
+    if args['--samples'] is not None and loaded >= args['--samples']:
+      break
     path = sample.make_path(directory=db.original_directory,
         extension=db.original_extension)
-    logger.info('Loading sample %d/%d (%s)...', k, len(objects), path)
+    logger.info('Loading sample %d/%d (%s)...', loaded, len(objects), path)
     image = sample.load(directory=db.original_directory,
         extension=db.original_extension)
     if not (hasattr(image, 'metadata') and 'roi' in image.metadata):
       logger.info('Skipping sample (no ROI)')
       continue
+    loaded += 1
 
     # copy() required by skimage.util.shape.view_as_windows()
     image = image.copy().astype('float64') / 255.
@@ -169,12 +172,17 @@ def main(user_input=None):
     mask = mask[1:-1, 1:-1]
     for y in range(windows.shape[0]):
       for x in range(windows.shape[1]):
-        idx = (k*windows.shape[0]*windows.shape[1]) + (y*windows.shape[1]) + x
+        idx = ((loaded-1)*windows.shape[0]*windows.shape[1]) + \
+            (y*windows.shape[1]) + x
         features[idx,:-2] = windows[y,x].flatten()
         features[idx,-2] = y+1
         features[idx,-1] = x+1
         target[idx] = mask[y,x]
 
+  # if number of loaded samples is smaller than expected, clip features array
+  features = features[:loaded*windows.shape[0]*windows.shape[1]]
+  target = target[:loaded*windows.shape[0]*windows.shape[1]]
+
   # normalize w.r.t. dimensions
   features[:,-2] /= image.shape[0]
   features[:,-1] /= image.shape[1]
@@ -214,7 +222,7 @@ def main(user_input=None):
     except KeyboardInterrupt:
       print() #avoids the ^C line
       logger.info('Gracefully stopping training before limit (%d iterations)',
-          args['--batch']
+          args['--batch'])
       break
 
   # describe errors
-- 
GitLab