Commit 992e72c0 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[preprocessor][scale] Rewrite the Scale transformer using simpler and more efficient code

parent 39eb03b8
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.preprocessing import FunctionTransformer
from skimage.transform import resize
import numpy as np
from sklearn.utils import check_array
from bob.io.image import to_matplotlib, to_bob
class Scale(TransformerMixin, BaseEstimator):
"""
Simple scales an images
def scale(images, target_img_size):
"""Scales a list of images to the target size
Parameters
-----------
target_img_size: tuple
Target image size
----------
images : array_like
A list of images (in Bob format) to be scaled to the target size
target_img_size : tuple
A tuple of size 2 as (H, W)
Returns
-------
numpy.ndarray
Scaled images
"""
images = check_array(images, allow_nd=True)
images = to_matplotlib(images)
def __init__(self, target_img_size, **kwargs):
self.target_img_size = target_img_size
def _more_tags(self):
return {"stateless": True, "requires_fit": False}
# images are always batched
output_shape = tuple(target_img_size)
output_shape = tuple(images.shape[0:1]) + output_shape
images = resize(images, output_shape=output_shape)
def fit(self, X, y=None):
return self
return to_bob(images)
def transform(self, X, annotations=None):
"""
Resize an image given a shape
Parameters
----------
img:
Input image
target_img_size: tuple
Target image size
"""
def _resize(x):
return resize(x, self.target_img_size, anti_aliasing=True)
X = check_array(X, allow_nd=True)
def Scale(target_img_size):
"""
A transformer that scales images.
It accepts a list of inputs
if X.ndim < 2 or X.ndim > 4:
raise ValueError(f"Invalid image shape {X.shape}")
Parameters
-----------
if X.ndim == 2:
# Checking if it's bob format CxHxW
return _resize(X)
target_img_size: tuple
Target image size, specified as a tuple of (H, W)
if X.ndim == 3:
# Checking if it's bob format CxHxW
if X.shape[0] == 3:
X = np.moveaxis(X, -1, 0)
return _resize(X)
# Batch of images
if X.ndim == 4:
# Checking if it's bob format NxCxHxW
if X.shape[1] == 3:
X = np.moveaxis(X, 1, -1)
return [_resize(x) for x in X]
"""
return FunctionTransformer(
func=scale, validate=False, kw_args=dict(target_img_size=target_img_size)
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment