diff --git a/bob/bio/face/preprocessor/Base.py b/bob/bio/face/preprocessor/Base.py index b743431d269c72d055d390a32f6a16bf4ed039a0..9a939368e8d8ea27c5fe10cd54696b99c60fff7e 100644 --- a/bob/bio/face/preprocessor/Base.py +++ b/bob/bio/face/preprocessor/Base.py @@ -12,7 +12,7 @@ class Base (Preprocessor): dtype : :py:class:`numpy.dtype` or convertible or ``None`` The data type that the resulting image will have. - color_channel : one of ``('gray', 'red', 'gren', 'blue')`` + color_channel : one of ``('gray', 'red', 'gren', 'blue', 'rgb')`` The specific color channel, which should be extracted from the image. """ @@ -35,14 +35,18 @@ class Base (Preprocessor): **Returns:** - channel : 2D :py:class:`numpy.ndarray` + channel : 2D or 3D :py:class:`numpy.ndarray` The extracted color channel. """ if image.ndim == 2: + if self.channel == 'rgb': + return bob.ip.color.gray_to_rgb(image) if self.channel != 'gray': raise ValueError("There is no rule to extract a " + channel + " image from a gray level image!") return image + if self.channel == 'rgb': + return image if self.channel == 'gray': return bob.ip.color.rgb_to_gray(image) if self.channel == 'red': diff --git a/bob/bio/face/test/test_preprocessors.py b/bob/bio/face/test/test_preprocessors.py index 815b74bfffd0487bfcedc0733abd9d29b0b3ce8c..534e7e4c51cf0989a8a9dbfc39a23d9a1c8e1406 100644 --- a/bob/bio/face/test/test_preprocessors.py +++ b/bob/bio/face/test/test_preprocessors.py @@ -64,6 +64,23 @@ def test_base(): assert preprocessed.dtype == numpy.float64 assert numpy.allclose(preprocessed, bob.ip.color.rgb_to_gray(image)) + # color output + base = bob.bio.face.preprocessor.Base(color_channel="rgb", dtype=numpy.uint8) + colored = base(bob.ip.color.rgb_to_gray(image)) + + assert colored.ndim == 3 + assert colored.dtype == numpy.uint8 + assert all(numpy.allclose(colored[c], bob.ip.color.rgb_to_gray(image)) for c in range(3)) + + colored = base(image) + assert colored.ndim == 3 + assert colored.dtype == numpy.uint8 + assert numpy.all(colored == image) + + + + + def test_face_crop(): # read input