Skip to content
Snippets Groups Projects
image_utils.py 2.89 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later


"""Data loading code."""

import pathlib

import numpy
import PIL.Image


class SingleAutoLevel16to8:
    """Converts a 16-bit image to 8-bit representation using "auto-level".

    This transform assumes that the input image is gray-scaled (single band),
    so that ``img.getextrema()`` yields a single ``(min, max)`` pair.

    To auto-level, we calculate the maximum and the minimum of the image, and
    linearly map that range onto the [0, 255] range of the destination image,
    rounding to the nearest integer.  A constant image (``min == max``) has
    no range to stretch and is mapped to an all-zero image.
    """

    def __call__(self, img):
        """Auto-level ``img`` and return it as an 8-bit grayscale PIL image.

        Parameters
        ----------

        img
            A single-band PIL image (e.g. 16-bit grayscale).


        Returns
        -------

        image
            A PIL image in mode ``"L"``.
        """
        imin, imax = img.getextrema()
        data = numpy.asarray(img, dtype=float)
        irange = imax - imin

        if irange == 0:
            # A constant image would otherwise produce 0/0 == NaN, whose
            # cast to uint8 is undefined (and emits a RuntimeWarning).
            scaled = numpy.zeros(data.shape, dtype="uint8")
        else:
            scaled = numpy.round(255.0 * (data - imin) / irange).astype(
                "uint8"
            )

        return PIL.Image.fromarray(scaled).convert("L")


def remove_black_borders(
    img: PIL.Image.Image, threshold: int = 0
) -> PIL.Image.Image:
    """Remove black borders of CXR.

    The image is cropped to the smallest axis-aligned bounding box that
    contains every pixel brighter than ``threshold``. Only rows and columns
    on the outside of that box are removed. Fully black rows or columns
    *inside* the box are kept, so the image content is not distorted.

    Parameters
    ----------

    img
        A PIL image, either single-channel (2-D array) or 3-channel
        (e.g. RGB).
    threshold
        Pixels whose value is less than or equal to this threshold are
        considered black. For 3-channel images, a pixel is non-black if any
        of its channels exceeds the threshold. Defaults to 0.


    Returns
    -------

    image
        A PIL image with black borders removed. If no pixel exceeds
        ``threshold``, the whole image is returned uncropped (as a new
        image).


    Raises
    ------

    NotImplementedError
        If the image is neither single-channel nor 3-channel.
    """
    arr = numpy.asarray(img)

    if arr.ndim == 2:  # single channel
        mask = arr > threshold
    elif arr.ndim == 3 and arr.shape[2] == 3:
        # A pixel is "not black" if any of its three channels is.
        mask = (arr > threshold).any(axis=2)
    else:
        raise NotImplementedError(
            f"remove_black_borders() supports single-channel or 3-channel "
            f"images only, got array of shape {arr.shape}"
        )

    rows = numpy.flatnonzero(mask.any(axis=1))
    cols = numpy.flatnonzero(mask.any(axis=0))

    if rows.size == 0:
        # Entirely black image: cropping would yield a 0x0 image, so keep
        # the full frame instead.
        return PIL.Image.fromarray(arr)

    cropped = arr[rows[0] : rows[-1] + 1, cols[0] : cols[-1] + 1]
    return PIL.Image.fromarray(cropped)


class RemoveBlackBorders:
    """Callable transform form of the ``remove_black_borders`` function.

    Parameters
    ----------

    threshold
        Value forwarded as ``threshold`` to ``remove_black_borders`` each
        time the transform is applied. Defaults to 0.
    """

    def __init__(self, threshold=0):
        self.threshold = threshold

    def __call__(self, img):
        cropped = remove_black_borders(img, threshold=self.threshold)
        return cropped


def load_pil(path: str | pathlib.Path) -> PIL.Image.Image:
    """Open an image file with Pillow.

    The image is returned exactly as ``PIL.Image.open`` produces it, without
    any mode conversion.

    Parameters
    ----------

    path
        The full path leading to the image to be loaded


    Returns
    -------

    image
        A PIL image
    """
    image = PIL.Image.open(path)
    return image


def load_pil_grayscale(path: str | pathlib.Path) -> PIL.Image.Image:
    """Loads an image file and converts it to grayscale mode ("L").

    The file opened to read the image is closed before returning, so no file
    handle is left for the garbage collector to release.

    Parameters
    ----------

    path
        The full path leading to the image to be loaded


    Returns
    -------

    image
        A PIL image in grayscale mode
    """
    with load_pil(path) as source:
        # convert() returns a new, fully loaded copy, which stays valid after
        # the source image's file handle is closed when the block exits.
        return source.convert("L")


def load_pil_rgb(path: str | pathlib.Path) -> PIL.Image.Image:
    """Loads an image file and converts it to RGB mode ("RGB").

    The file opened to read the image is closed before returning, so no file
    handle is left for the garbage collector to release.

    Parameters
    ----------

    path
        The full path leading to the image to be loaded


    Returns
    -------

    image
        A PIL image in RGB mode
    """
    with load_pil(path) as source:
        # convert() returns a new, fully loaded copy, which stays valid after
        # the source image's file handle is closed when the block exits.
        return source.convert("RGB")