import logging

from sklearn.base import BaseEstimator, TransformerMixin

from bob.pipelines.wrappers import _check_n_input_output, _frmt

from . import utils

logger = logging.getLogger(__name__)


class VideoWrapper(TransformerMixin, BaseEstimator):
    """Wrapper class to run image preprocessing algorithms on video data.

    Each video is expected to be a ``VideoLikeContainer``-style object: an
    iterable of frames with a matching ``indices`` attribute. Frames that are
    ``None`` (because a previous transformer in the pipeline failed on them)
    are filtered out before calling the wrapped estimator and re-inserted as
    ``None`` at their original positions afterwards.

    **Parameters:**

    estimator : str or ``sklearn.base.BaseEstimator`` instance
      The transformer to be used to preprocess the frames.
    """

    def __init__(
        self,
        estimator,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.estimator = estimator

    def transform(self, videos, **kwargs):
        """Apply the wrapped estimator on every frame of every video.

        Parameters
        ----------
        videos : list
            A list of video containers; each must expose ``indices``.
        **kwargs
            Per-video extra arguments; each value is a list aligned with
            ``videos`` (``v[i]`` belongs to ``videos[i]``). An ``annotations``
            entry is expected to be a dict keyed by frame index.

        Returns
        -------
        list
            One ``utils.VideoLikeContainer`` per input video, preserving
            the original frame indices.

        Raises
        ------
        ValueError
            If a video does not have an ``indices`` attribute.
        """
        transformed_videos = []
        for i, video in enumerate(videos):

            if not hasattr(video, "indices"):
                raise ValueError(
                    f"The input video: {video}\n does not have indices.\n "
                    f"Processing failed in {self}"
                )

            # Select this video's slice of every extra argument.
            kw = {}
            if kwargs:
                kw = {k: v[i] for k, v in kwargs.items()}
            if "annotations" in kw:
                # Align annotations with the frame indices; keys may be
                # stored as int or as str, so try both.
                kw["annotations"] = [
                    kw["annotations"].get(index, kw["annotations"].get(str(index)))
                    for index in video.indices
                ]

            # remove None's before calling and add them back in data later
            # Isolate invalid samples (when previous transformers returned None)
            invalid_ids = [j for j, frame in enumerate(video) if frame is None]
            invalid_id_set = set(invalid_ids)  # O(1) membership checks below
            valid_frames = [frame for frame in video if frame is not None]

            # remove invalid kw args as well
            for k, v in kw.items():
                kw[k] = [vv for j, vv in enumerate(v) if j not in invalid_id_set]

            # Process only the valid samples
            output = None
            if len(valid_frames) > 0:
                output = self.estimator.transform(valid_frames, **kw)
                _check_n_input_output(
                    valid_frames, output, f"{_frmt(self.estimator)}.transform"
                )

            if output is None:
                output = [None] * len(valid_frames)

            # Rebuild the full batch of samples (include the previously failed)
            if len(invalid_ids) > 0:
                output = list(output)
                # invalid_ids is ascending, so each insert lands the None at
                # its original frame position.
                for j in invalid_ids:
                    output.insert(j, None)

            data = utils.VideoLikeContainer(output, video.indices)
            transformed_videos.append(data)
        return transformed_videos

    def _more_tags(self):
        # Forward the wrapped estimator's tags and register video-aware
        # save/load functions used by bob.pipelines checkpointing.
        tags = self.estimator._get_tags()
        tags["bob_features_save_fn"] = utils.VideoLikeContainer.save_function
        tags["bob_features_load_fn"] = utils.VideoLikeContainer.load
        return tags

    def fit(self, X, y=None, **fit_params):
        """Does nothing; this wrapper is stateless."""
        return self