From 320f8a171f3f7ac705b95cca3b923e8af3871e63 Mon Sep 17 00:00:00 2001
From: Manuel Gunther <siebenkopf@googlemail.com>
Date: Fri, 25 Sep 2015 11:35:07 -0600
Subject: [PATCH] Added support for RGB color space for the base preprocessor.

---
 bob/bio/face/preprocessor/Base.py       |  8 ++++++--
 bob/bio/face/test/test_preprocessors.py | 17 +++++++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/bob/bio/face/preprocessor/Base.py b/bob/bio/face/preprocessor/Base.py
index b743431d..9a939368 100644
--- a/bob/bio/face/preprocessor/Base.py
+++ b/bob/bio/face/preprocessor/Base.py
@@ -12,7 +12,7 @@ class Base (Preprocessor):
   dtype : :py:class:`numpy.dtype` or convertible or ``None``
     The data type that the resulting image will have.
 
-  color_channel : one of ``('gray', 'red', 'gren', 'blue')``
+  color_channel : one of ``('gray', 'red', 'gren', 'blue', 'rgb')``
     The specific color channel, which should be extracted from the image.
   """
 
@@ -35,14 +35,18 @@ class Base (Preprocessor):
 
     **Returns:**
 
-    channel : 2D :py:class:`numpy.ndarray`
+    channel : 2D or 3D :py:class:`numpy.ndarray`
       The extracted color channel.
     """
     if image.ndim == 2:
+      if self.channel == 'rgb':
+        return bob.ip.color.gray_to_rgb(image)
       if self.channel != 'gray':
         raise ValueError("There is no rule to extract a " + channel + " image from a gray level image!")
       return image
 
+    if self.channel == 'rgb':
+      return image
     if self.channel == 'gray':
       return bob.ip.color.rgb_to_gray(image)
     if self.channel == 'red':
diff --git a/bob/bio/face/test/test_preprocessors.py b/bob/bio/face/test/test_preprocessors.py
index 815b74bf..534e7e4c 100644
--- a/bob/bio/face/test/test_preprocessors.py
+++ b/bob/bio/face/test/test_preprocessors.py
@@ -64,6 +64,23 @@ def test_base():
   assert preprocessed.dtype == numpy.float64
   assert numpy.allclose(preprocessed, bob.ip.color.rgb_to_gray(image))
 
+  # color output
+  base = bob.bio.face.preprocessor.Base(color_channel="rgb", dtype=numpy.uint8)
+  colored = base(bob.ip.color.rgb_to_gray(image))
+
+  assert colored.ndim == 3
+  assert colored.dtype == numpy.uint8
+  assert all(numpy.allclose(colored[c], bob.ip.color.rgb_to_gray(image)) for c in range(3))
+
+  colored = base(image)
+  assert colored.ndim == 3
+  assert colored.dtype == numpy.uint8
+  assert numpy.all(colored == image)
+
+
+
+
+
 
 def test_face_crop():
   # read input
-- 
GitLab