From b7afdc2a66c686b3c7f814def59c4a0ecf196767 Mon Sep 17 00:00:00 2001
From: Vincent <vincent.pollet@idiap.ch>
Date: Fri, 26 Nov 2021 17:54:23 +0100
Subject: [PATCH] Add example of refactoring bob.io.stream and bob.ip.stereo
 stream filters as transformers for bob.pipelines

---
 bob/io/stream/transformers_example.py | 369 ++++++++++++++++++++++++++
 1 file changed, 369 insertions(+)
 create mode 100644 bob/io/stream/transformers_example.py

diff --git a/bob/io/stream/transformers_example.py b/bob/io/stream/transformers_example.py
new file mode 100644
index 0000000..a081cf4
--- /dev/null
+++ b/bob/io/stream/transformers_example.py
@@ -0,0 +1,369 @@
+from pathlib import Path
+from functools import partial
+from typing import Iterable
+from bob.ip.stereo.stereo import reproject_image
+
+import h5py
+import cv2
+import numpy as np
+from scipy.spatial import cKDTree
+import matplotlib.pyplot as plt
+
+from sklearn.base import TransformerMixin, BaseEstimator
+from sklearn.pipeline import make_pipeline
+import bob.pipelines as bpip
+
+from bob.io.image.utils import opencvbgr_to_bob, to_bob, to_matplotlib
+from bob.ip.stereo import CameraPair, stereo_match, load_camera_config, StereoParameters
+
+from utils import get_index_list
+
+
+class NormalizeTransformer(TransformerMixin, BaseEstimator):
+    """Normalize data to [0, 1] using tmin/tmax (min/max over the sample axis if None); "uint8"/"uint16" rescale it."""
+
+    def __init__(self, tmin=None, tmax=None, dtype="uint8") -> None:
+        super().__init__()
+        self.tmin, self.tmax, self.dtype = tmin, tmax, dtype
+
+    def _more_tags(self):
+        return {"requires_fit": False, "stateless": True}
+
+    def fit(self, data, y=None):
+        return self
+
+    def transform(self, data) -> list:
+        """Return a list with one normalized array per sample (float64 unless dtype is "uint8" or "uint16")."""
+        data = np.asarray(data)  # make array with dim 0 indexing samples
+        # np.asarray lets scalar tmin/tmax work; float64 replaces `np.float` (removed in NumPy 1.24) and is applied
+        # before the subtraction so unsigned inputs cannot wrap around below tmin.
+        tmin = np.asarray(np.min(data, axis=0, keepdims=True) if self.tmin is None else self.tmin, dtype=np.float64)
+        tmax = np.asarray(np.max(data, axis=0, keepdims=True) if self.tmax is None else self.tmax, dtype=np.float64)
+        data = np.clip((data.astype(np.float64) - tmin) / (tmax - tmin), a_min=0.0, a_max=1.0)
+        if self.dtype == "uint8":
+            data = (data * 255.0).astype("uint8")
+        elif self.dtype == "uint16":
+            data = (data * 65535.0).astype("uint16")
+        return list(data)  # go back to list for sample dimension
+
+
+class ColorMapTransformer(TransformerMixin, BaseEstimator):
+    """Colorize single-channel samples with the "gray", "jet", "bone" or "hsv" colormap."""
+
+    def __init__(self, colormap="gray") -> None:
+        super().__init__()
+        self.colormap = colormap
+
+    def _more_tags(self):
+        return {"requires_fit": False, "stateless": True}
+
+    def fit(self, data, y=None):
+        return self
+
+    def transform(self, data):
+        """Rescale data to uint8 and colorize it (one array for "gray", a list of per-sample arrays otherwise)."""
+        data = np.asarray(data)  # make array with dim 0 indexing samples
+
+        if data.shape[2] != 1:  # need channel dimension to be 1
+            # Report the axis that is actually checked (the message used to print shape[1]).
+            raise ValueError("Can not apply colormap on array with channel dimension " + str(data.shape[2]))
+
+        tmin = np.min(data, axis=0, keepdims=True)
+        tmax = np.max(data, axis=0, keepdims=True)
+        # np.float64 replaces the `np.float` alias, which was removed in NumPy 1.24.
+        data = (data - tmin).astype(np.float64)
+        data = (data * 255.0 / (tmax.astype(np.float64) - tmin.astype(np.float64))).astype("uint8")
+        if self.colormap == "gray":
+            return np.concatenate([data, data, data], axis=2).astype("uint8")
+        maps = {"jet": cv2.COLORMAP_JET, "bone": cv2.COLORMAP_BONE, "hsv": cv2.COLORMAP_HSV}
+        return [
+            np.stack(
+                [opencvbgr_to_bob(cv2.applyColorMap(image.squeeze(0), maps[self.colormap])) for image in sample],
+                axis=0,
+            )
+            for sample in data
+        ]
+
+
+class StereoMatchTransformer(TransformerMixin, BaseEstimator):
+    """Stateless transformer running `stereo_match` on every left/right frame pair of every sample."""
+
+    def __init__(self, camera_left, camera_right, stereo_parameters) -> None:
+        super().__init__()
+        self.camera_left = camera_left
+        self.camera_right = camera_right
+        self.stereo_parameters = stereo_parameters
+
+    def _more_tags(self):
+        return {"requires_fit": False, "stateless": True}
+
+    def fit(self, data, y=None):
+        return self
+
+    def transform(self, data, left_data, right_data):
+        """Return one stacked array of `stereo_match` outputs per sample; `data` itself is not used."""
+        pair = CameraPair(self.camera_left, self.camera_right)
+        results = []
+        for left_frames, right_frames in zip(left_data, right_data):  # loop over samples
+            matched = [
+                stereo_match(left, right, pair, stereo_parameters=self.stereo_parameters)
+                for left, right in zip(left_frames, right_frames)  # loop over frames
+            ]
+            results.append(np.stack(matched, axis=0))
+        return results
+
+
+class StereoReprojectTransformer(TransformerMixin, BaseEstimator):
+    """Stateless transformer calling `reproject_image` on each (stream frame, 3D map frame) pair of every sample."""
+
+    def __init__(self, stream_camera, camera_left, camera_right) -> None:
+        super().__init__()
+        self.stream_camera = stream_camera
+        self.camera_left = camera_left
+        self.camera_right = camera_right
+
+    def _more_tags(self):
+        return {"requires_fit": False, "stateless": True}
+
+    def fit(self, data, y=None):
+        return self
+
+    def transform(self, data, stream_data, map3D_data):
+        """Return one stacked array of reprojected frames per sample; `data` itself is not used."""
+        # input arguments are list over sample attributes (when this transformer is wrapped with SampleWrapper)
+        pair = CameraPair(self.camera_left, self.camera_right)
+
+        def reproject_sample(stream_frames, map3D_frames):
+            # Should we parallelize this one ?
+            reprojected = [
+                reproject_image(stream_frame, map3D_frame, self.stream_camera, pair)
+                for stream_frame, map3D_frame in zip(stream_frames, map3D_frames)
+            ]
+            return np.stack(reprojected, axis=0)
+
+        # this loop indexes on a "dask bag size" (dask bag not seen here but handled by wrapper)
+        return [
+            reproject_sample(stream_frames, map3D_frames) for stream_frames, map3D_frames in zip(stream_data, map3D_data)
+        ]
+
+class AdjustTransformer(TransformerMixin, BaseEstimator):
+    """Stateless transformer picking, for each reference timestamp, the stream frame closest in time."""
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _more_tags(self):
+        return {"requires_fit": False, "stateless": True}
+
+    def fit(self, data, y=None):
+        return self
+
+    def transform(self, data, stream_data, stream_timestamps, adjust_to_timestamps):
+        """Return, per sample, the frames of `stream_data` nearest in time to `adjust_to_timestamps`.
+
+        The three stream arguments are lists with one entry per sample; `data` itself is not used.
+        """
+
+        adjusted = []
+        for frames, timestamps, targets in zip(stream_data, stream_timestamps, adjust_to_timestamps):
+            # Nearest-neighbour search over this sample's own timestamps (cKDTree expects 2D points).
+            kdtree = cKDTree(np.asarray(timestamps)[:, np.newaxis])
+            _, nearest = kdtree.query(np.asarray(targets)[:, np.newaxis], k=1)
+            # Select inside the loop so every sample is adjusted, not only the last one.
+            adjusted.append(frames[nearest])
+
+        return adjusted
+
+
+
+def load_from_hdf5(filepath, dataset, attribute=None, indices=None):
+    """Read a dataset, or one attribute of a dataset, from a hdf5 file.
+
+    Parameters
+    ----------
+    filepath : :obj:`pathlib.Path`
+        Location of the hdf5 file; it is opened read-only.
+    dataset : str
+        Key of the dataset inside the file.
+    attribute : str, optional
+        Name of an attribute of `dataset` to read instead of the dataset's own data, by default None
+    indices : slice or list of int, optional
+        Selection applied to whatever is read, by default None which selects everything.
+
+    Returns
+    -------
+    :obj:`numpy.ndarray`
+        The selected part of the dataset's data or of the dataset's attribute.
+    """
+    # None means "no restriction": slice(None) selects every element
+    selection = slice(None) if indices is None else indices
+
+    with h5py.File(str(filepath), "r") as hdf5_file:
+        source = hdf5_file[dataset]
+        if attribute is not None:
+            # attributes are read through the dataset's `.attrs` mapping
+            loaded = source.attrs[attribute][selection]
+        else:
+            loaded = source[selection]
+    return loaded
+
+
+def candy_file_2_delayed_sample(filepath, streams, indices=None):
+    """Return a DelayedSample whose delayed attributes load the requested `streams` and their timestamps from `filepath`."""
+    if indices is None:
+        indices = slice(None)  # load everything
+        # TODO: This is ignored for photogram, which only have 1 frame
+    # photogram entries below therefore never apply `indices`, for frames and timestamps alike
+    # candy_frame_list and delayed_attributes_dict should probably be in candy somewhere (database ?)
+
+    candy_frames_list = {  # after trimming
+        "stereo": [0, 5, 10, 15],
+        "850": [1, 6, 11],
+        "950": [2, 7, 12],
+        "white": [3, 8, 13],
+        "dark": [4, 9, 14],
+        "photogram_0": [16],
+        "photogram_1": [17],
+        "photogram_2": [18],
+        "photogram_3": [19],
+    }
+
+    streams_load_fcts = {
+        "color": partial(load_from_hdf5, filepath, "color", indices=candy_frames_list["white"][indices]),
+        "color_dark": partial(load_from_hdf5, filepath, "color", indices=candy_frames_list["dark"][indices]),
+        "color_stereo": partial(load_from_hdf5, filepath, "color", indices=candy_frames_list["stereo"][indices]),
+        "color_photogram_0": partial(load_from_hdf5, filepath, "color", indices=candy_frames_list["photogram_0"]),
+        "color_photogram_1": partial(load_from_hdf5, filepath, "color", indices=candy_frames_list["photogram_1"]),
+        "color_photogram_2": partial(load_from_hdf5, filepath, "color", indices=candy_frames_list["photogram_2"]),
+        "color_photogram_3": partial(load_from_hdf5, filepath, "color", indices=candy_frames_list["photogram_3"]),
+
+        "left_850": partial(load_from_hdf5, filepath, "left", indices=candy_frames_list["850"][indices]),
+        "left_950": partial(load_from_hdf5, filepath, "left", indices=candy_frames_list["950"][indices]),
+        "left_dark": partial(load_from_hdf5, filepath, "left", indices=candy_frames_list["dark"][indices]),
+        "left_stereo": partial(load_from_hdf5, filepath, "left", indices=candy_frames_list["stereo"][indices]),
+        "left_photogram_0": partial(load_from_hdf5, filepath, "left", indices=candy_frames_list["photogram_0"]),
+        "left_photogram_1": partial(load_from_hdf5, filepath, "left", indices=candy_frames_list["photogram_1"]),
+        "left_photogram_2": partial(load_from_hdf5, filepath, "left", indices=candy_frames_list["photogram_2"]),
+        "left_photogram_3": partial(load_from_hdf5, filepath, "left", indices=candy_frames_list["photogram_3"]),
+
+        "right_850": partial(load_from_hdf5, filepath, "right", indices=candy_frames_list["850"][indices]),
+        "right_950": partial(load_from_hdf5, filepath, "right", indices=candy_frames_list["950"][indices]),
+        "right_dark": partial(load_from_hdf5, filepath, "right", indices=candy_frames_list["dark"][indices]),
+        "right_stereo": partial(load_from_hdf5, filepath, "right", indices=candy_frames_list["stereo"][indices]),
+        "right_photogram_0": partial(load_from_hdf5, filepath, "right", indices=candy_frames_list["photogram_0"]),
+        "right_photogram_1": partial(load_from_hdf5, filepath, "right", indices=candy_frames_list["photogram_1"]),
+        "right_photogram_2": partial(load_from_hdf5, filepath, "right", indices=candy_frames_list["photogram_2"]),
+        "right_photogram_3": partial(load_from_hdf5, filepath, "right", indices=candy_frames_list["photogram_3"]),
+    }
+
+    attributes_load_fcts = {
+        "color_timestamps": partial(load_from_hdf5, filepath, "color", attribute="timestamps", indices=candy_frames_list["white"][indices]),
+        "color_dark_timestamps": partial(load_from_hdf5, filepath, "color", attribute="timestamps", indices=candy_frames_list["dark"][indices]),
+        "color_stereo_timestamps": partial(load_from_hdf5, filepath, "color", attribute="timestamps", indices=candy_frames_list["stereo"][indices]),
+        "color_photogram_0_timestamps": partial(load_from_hdf5, filepath, "color", attribute="timestamps", indices=candy_frames_list["photogram_0"]),
+        "color_photogram_1_timestamps": partial(load_from_hdf5, filepath, "color", attribute="timestamps", indices=candy_frames_list["photogram_1"]),
+        "color_photogram_2_timestamps": partial(load_from_hdf5, filepath, "color", attribute="timestamps", indices=candy_frames_list["photogram_2"]),
+        "color_photogram_3_timestamps": partial(load_from_hdf5, filepath, "color", attribute="timestamps", indices=candy_frames_list["photogram_3"]),
+
+        "left_850_timestamps": partial(load_from_hdf5, filepath, "left", attribute="timestamps", indices=candy_frames_list["850"][indices]),
+        "left_950_timestamps": partial(load_from_hdf5, filepath, "left", attribute="timestamps", indices=candy_frames_list["950"][indices]),
+        "left_dark_timestamps": partial(load_from_hdf5, filepath, "left", attribute="timestamps", indices=candy_frames_list["dark"][indices]),
+        "left_stereo_timestamps": partial(load_from_hdf5, filepath, "left", attribute="timestamps", indices=candy_frames_list["stereo"][indices]),
+        "left_photogram_0_timestamps": partial(load_from_hdf5, filepath, "left", attribute="timestamps", indices=candy_frames_list["photogram_0"]),
+        "left_photogram_1_timestamps": partial(load_from_hdf5, filepath, "left", attribute="timestamps", indices=candy_frames_list["photogram_1"]),
+        "left_photogram_2_timestamps": partial(load_from_hdf5, filepath, "left", attribute="timestamps", indices=candy_frames_list["photogram_2"]),
+        "left_photogram_3_timestamps": partial(load_from_hdf5, filepath, "left", attribute="timestamps", indices=candy_frames_list["photogram_3"]),
+
+        "right_850_timestamps": partial(load_from_hdf5, filepath, "right", attribute="timestamps", indices=candy_frames_list["850"][indices]),
+        "right_950_timestamps": partial(load_from_hdf5, filepath, "right", attribute="timestamps", indices=candy_frames_list["950"][indices]),
+        "right_dark_timestamps": partial(load_from_hdf5, filepath, "right", attribute="timestamps", indices=candy_frames_list["dark"][indices]),
+        "right_stereo_timestamps": partial(load_from_hdf5, filepath, "right", attribute="timestamps", indices=candy_frames_list["stereo"][indices]),
+        "right_photogram_0_timestamps": partial(load_from_hdf5, filepath, "right", attribute="timestamps", indices=candy_frames_list["photogram_0"]),
+        "right_photogram_1_timestamps": partial(load_from_hdf5, filepath, "right", attribute="timestamps", indices=candy_frames_list["photogram_1"]),
+        "right_photogram_2_timestamps": partial(load_from_hdf5, filepath, "right", attribute="timestamps", indices=candy_frames_list["photogram_2"]),
+        "right_photogram_3_timestamps": partial(load_from_hdf5, filepath, "right", attribute="timestamps", indices=candy_frames_list["photogram_3"]),
+    }
+
+    delayed_attributes_dict = {stream: streams_load_fcts[stream] for stream in streams}
+    delayed_attributes_dict.update(
+        {stream + "_timestamps": attributes_load_fcts[stream + "_timestamps"] for stream in streams}
+    )
+
+    return bpip.DelayedSample(
+        lambda: 42,  # .data of sample will not be used, all data is put in delayed attributes, but if .data is set to None the sample will be ignored
+        delayed_attributes=delayed_attributes_dict,
+    )
+
+
+def main():
+    """Stereo-match two example recordings, align and reproject their color stream, then plot the first result."""
+    recordings_dir = Path(
+        "/idiap/temp/vpollet/projects/candy/bob.ip.stereo/bob/ip/stereo/calibration_2021_11_8_16x16_checker_size_12_marker_size_9_trim_demosaiced_16_bits_scaled"
+    )
+    calibration_file = "/idiap/temp/vpollet/projects/candy/bob.ip.stereo/bob/ip/stereo/calib_2021_11_8.json"
+
+    # every regular file matching *.h5 in the folder is treated as a recording
+    recordings = [entry for entry in recordings_dir.iterdir() if entry.is_file() and entry.match("*.h5")]
+
+    # only the first two recordings are used, each called with slice(0, 2) as frame selection
+    samples = []
+    for recording in recordings[:2]:
+        sample = candy_file_2_delayed_sample(
+            recording, ["left_stereo", "right_stereo", "color"], slice(0, 2)
+        )
+        samples.append(sample)
+
+    print(samples)
+
+    # configurations of the three cameras, all read from the same calibration file
+    left_camera = load_camera_config(calibration_file, "left")
+    right_camera = load_camera_config(calibration_file, "right")
+    color_camera = load_camera_config(calibration_file, "color")
+
+    # step 1: stereo matching of "left_stereo"/"right_stereo", output stored as "map3D"
+    stereo_parameters = StereoParameters()
+    matcher = StereoMatchTransformer(left_camera, right_camera, stereo_parameters)
+    stereo_step = bpip.wrap(
+        ["sample"],
+        matcher,
+        transform_extra_arguments=[("left_data", "left_stereo"), ("right_data", "right_stereo")],
+        output_attribute="map3D",
+    )
+
+    # step 2: adjust "color" to the timestamps of "left_stereo", output stored as "color"
+    adjuster = AdjustTransformer()
+    adjust_step = bpip.wrap(
+        ["sample"],
+        adjuster,
+        transform_extra_arguments=[
+            ("stream_data", "color"),
+            ("stream_timestamps", "color_timestamps"),
+            ("adjust_to_timestamps", "left_stereo_timestamps"),
+        ],
+        output_attribute="color",
+    )
+
+    # step 3: reproject "color" using "map3D", output stored as "rep_color"
+    reprojector = StereoReprojectTransformer(color_camera, left_camera, right_camera)
+    reproject_step = bpip.wrap(
+        ["sample"],
+        reprojector,
+        transform_extra_arguments=[("stream_data", "color"), ("map3D_data", "map3D")],
+        output_attribute="rep_color",
+    )
+
+    # chain the three steps, then wrap the whole pipeline for dask
+    pipeline = make_pipeline(stereo_step, adjust_step, reproject_step)
+    dask_pipeline = bpip.wrap(["dask"], pipeline)
+
+    results = dask_pipeline.transform(samples).compute(scheduler="single-threaded")
+
+    # left panel: entry [0, 2] of the first sample's map3D; right panel: its first color frame
+    figure, (map_axis, color_axis) = plt.subplots(1, 2, figsize=(15, 15))
+    map_axis.imshow(to_matplotlib(results[0].map3D[0, 2]), cmap="jet")
+    color_axis.imshow(to_matplotlib(results[0].color[0]))
+    plt.show()
+
+
+if __name__ == "__main__":  # run the example only when this file is executed as a script
+    main()
-- 
GitLab