From 285b51b4d5a3a68857624735bc328e53008aa12e Mon Sep 17 00:00:00 2001
From: ageorge <anjith.george@idiap.ch>
Date: Tue, 14 Jul 2020 13:46:56 +0200
Subject: [PATCH] Updated configs for bob.io.stream

---
 .../SwirDiffPreprocessor.py                   | 50 ++++++++++++++++---
 configuration/preprocessing_mccnn_swirdiff.py |  4 +-
 2 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/bob/paper/pad_mccnns_swirdiff/SwirDiffPreprocessor.py b/bob/paper/pad_mccnns_swirdiff/SwirDiffPreprocessor.py
index 82e7afb..e4c80ba 100644
--- a/bob/paper/pad_mccnns_swirdiff/SwirDiffPreprocessor.py
+++ b/bob/paper/pad_mccnns_swirdiff/SwirDiffPreprocessor.py
@@ -104,7 +104,7 @@ class SwirDiffPreprocessor(Preprocessor):
     return result
 
 
-  def __init__(self, face_size=224, absolute=False, debug=False, loose_crop=False, **kwargs):
+  def __init__(self, face_size=224, absolute=False, debug=True, loose_crop=False, **kwargs):
     """
     Init function
 
@@ -198,6 +198,8 @@ class SwirDiffPreprocessorBRSU(SwirDiffPreprocessor):
     faces = {}
     annotations = {}
 
+    print('images',images.keys())
+
     # first detect and crop the face in the color image
     if self.loose_crop:
 
@@ -273,7 +275,7 @@ class SwirDiffPreprocessorBRSU(SwirDiffPreprocessor):
 
 class SwirDiffPreprocessorHQWMCA(SwirDiffPreprocessor):
 
-  def __call__(self, frames, annotations):
+  def __call__(self, oframes, annotations):
     """
 
     This class processes HQWMCA data. In this case, the data
@@ -296,7 +298,7 @@ class SwirDiffPreprocessorHQWMCA(SwirDiffPreprocessor):
     
     """
     try:
-      n_frames = len(frames['_color'])
+      n_frames = len(oframes['color'])
     except KeyError:
       logger.error("There is no color stream")
       import sys
@@ -304,17 +306,48 @@ class SwirDiffPreprocessorHQWMCA(SwirDiffPreprocessor):
 
     # the result
     fc = bob.bio.video.FrameContainer()
+
+
+
+    # makes frames compatible with older one
+
+    frames={}
+
+    frames['color']=oframes['color']
+
+
+    # print('oframes',oframes['swir'].as_array().shape)
+
+    frame_keys=['color','swir_940nm','swir_1050nm','swir_1200nm','swir_1300nm','swir_1450nm','swir_1550nm','swir_1650nm']
+
+    # for idx,key in enumerate(frame_keys):
+    #   if key!='color':
+    #     for i in range(n_frames):
+    #       frames[key]=oframes['swir']
     
     # loop on all frames
+
+    swir=oframes['swir'].as_array()
     for i in range(n_frames):
      
       # flag to check if the current frame has annotations
       ok = True
       
       faces = {} 
-      for k in frames.keys():
 
-        _, image, _ = frames[k][i]
+      # print('frames.keys()',frames.keys())
+      for idx,k in enumerate(frame_keys):
+
+        if k!='color':
+
+          image=swir[i,idx-1,:,:].squeeze()
+
+        else:
+          _, image, _ = oframes[k][i]
+
+        # print('image in swirdiff',image.shape)
+
+        # print(annotations[str(i)])
         try:
           eyes = {'leye': annotations[str(i)]['leye'], 'reye': annotations[str(i)]['reye']}
         
@@ -349,11 +382,16 @@ class SwirDiffPreprocessorHQWMCA(SwirDiffPreprocessor):
         pyplot.show()
 
       swir_faces = dict((k, faces[k]) for k in faces.keys() if 'swir' in k)
+
+      print('swir_faces',swir_faces)
+
       swir_diffs = self._swir_diff(swir_faces)
+      print('swir_diffs',swir_diffs.shape)
       if ok: 
         final_array = numpy.zeros((1 + swir_diffs.shape[0], self.face_size, self.face_size), dtype='uint8')
-        final_array[0] = faces['_color'].astype('uint8')
+        final_array[0] = faces['color'].astype('uint8')
         final_array[1:] = swir_diffs.astype('uint8')
+        print('final_array',final_array.shape)
         fc.add(str(i), final_array)
    
     return fc
diff --git a/configuration/preprocessing_mccnn_swirdiff.py b/configuration/preprocessing_mccnn_swirdiff.py
index dd210cd..b204791 100644
--- a/configuration/preprocessing_mccnn_swirdiff.py
+++ b/configuration/preprocessing_mccnn_swirdiff.py
@@ -5,6 +5,7 @@
 # === DATABASE ===
 # ================
 
+from bob.io.stream import Stream
 
 
 dark_for_stereo     = False
@@ -64,6 +65,7 @@ swir = swir.warp(color)
 
 thermal = Stream('thermal').adjust(color).warp(color) 
 
+
 streams = { 'color'     : color,
             'swir'      : swir}
 
@@ -74,7 +76,7 @@ from bob.pad.face.database import HQWMCAPadDatabase_warp as HQWMCAPadDatabase
 database = HQWMCAPadDatabase(protocol='grand_test',
                              original_directory=_rc['bob.db.hqwmca.directory'],
                              original_extension='.h5',
-                             annotations_dir = './landmarks-full',
+                             annotations_dir = '/idiap/user/ageorge/WORK/COMMON_ENV_PAD_BATL_DB/src/bob.paper.pad_mccnns_swirdiff/landmarks-full',
                              streams=streams,
                              n_frames=10)
 
-- 
GitLab