Skip to content
Snippets Groups Projects
Commit 28d836f2 authored by Manuel Günther's avatar Manuel Günther
Browse files

Added bob.io.image.load function that is able to automatically detect the...

Added bob.io.image.load function that is able to automatically detect the image type (using imghdr.what)
parent 10a67d72
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,45 @@ def get_config(): ...@@ -18,6 +18,45 @@ def get_config():
return bob.extension.get_config(__name__, version.externals) return bob.extension.get_config(__name__, version.externals)
def load(filename, extension=None):
"""load(filename) -> image
This function loads and image from the file with the specified ``filename``.
The type of the image will be determined based on the ``extension`` parameter, which can have the following values:
- ``None``: The file name extension of the ``filename`` is used to determine the image type.
- ``'auto'``: The type of the image will be detected automatically, using the :py:mod:`imghdr` module.
- ``'.xxx`'': The image type is determined by the given extension.
For a list of possible extensions, see :py:func:`bob.io.base.extensions` (only the image extensions are valid here).
**Parameters:**
``filename`` : str
The name of the image file to load.
``extension`` : str
[Default: ``None``] If given, the given extension will determine the type of the image.
Use ``'auto'`` to automatically determine the extension (this might take slightly more time).
**Returns**
``image`` : 2D or 3D :py:class:`numpy.ndarray` of type ``uint8``
The image read from the specified file.
"""
# check the extension
if extension is None:
f = bob.io.base.File(filename, 'r')
else:
if extension == 'auto':
import imghdr
extension = "." + imghdr.what(filename)
f = bob.io.base.File(filename, 'r', extension)
return f.read()
# use the same alias as for bob.io.base.load
read = load
def get_include_directories(): def get_include_directories():
"""get_include_directories() -> includes """get_include_directories() -> includes
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
import os import os
import numpy import numpy
from bob.io.base import load, write, test_utils from bob.io.base import load, write, test_utils
import bob.io.image
import nose
# These are some global parameters for the test. # These are some global parameters for the test.
PNG_INDEXED_COLOR = test_utils.datafile('img_indexed_color.png', __name__) PNG_INDEXED_COLOR = test_utils.datafile('img_indexed_color.png', __name__)
...@@ -67,6 +69,24 @@ def test_netpbm(): ...@@ -67,6 +69,24 @@ def test_netpbm():
# because of re-compression # because of re-compression
def test_image_load():
# test that the generic bob.io.image.load function works as expected
for filename in ('test.jpg', 'test.pbm', 'test.pgm', 'test.ppm', 'img_rgba_color.png'):
full_file = test_utils.datafile(filename, __name__)
# load with just image name
i1 = bob.io.image.load(full_file)
# load with image name and extension
i2 = bob.io.image.load(full_file, os.path.splitext(full_file)[1])
assert numpy.array_equal(i1,i2)
# load with image name and automatically estimated extension
i3 = bob.io.image.load(full_file, 'auto')
assert numpy.array_equal(i1,i3)
# assert that unknown extensions raise exceptions
nose.tools.assert_raises(RuntimeError, lambda x: bob.io.image.load(x, ".unknown"), full_file)
def test_cpp_interface(): def test_cpp_interface():
from ._library import _test_io from ._library import _test_io
import tempfile import tempfile
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment