Commit 75aa314c authored by Vincent POLLET's avatar Vincent POLLET
Browse files

Solve circular dependencies between Stream, StreamFile and StreamFilter...

Solve circular dependencies between Stream, StreamFile and StreamFilter classes and refactor them in separated files
parent d5c77162
Pipeline #42316 failed with stage
in 13 minutes and 44 seconds
from .stream import Stream, StreamFile
from .streamfile import StreamFile
from .stream import Stream
from .streamfilters import StreamFilter
def get_config():
"""Returns a string containing the configuration information.
"""Returns a string containing the configuration information.
"""
import bob.extension
return bob.extension.get_config(__name__)
import bob.extension
return bob.extension.get_config(__name__)
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
__all__ = [_ for _ in dir() if not _.startswith("_")]
This diff is collapsed.
import json
import numpy as np
from bob.io.base import HDF5File
from bob.ip.stereo import load_camera_config
from .config import load_data_config
class StreamFile:
def __init__(self, hdf5_file_path=None, data_format_config_file_path=None, camera_config_file_path=None, mode="r"):
if hdf5_file_path is not None:
self.hdf5_file = HDF5File(hdf5_file_path, mode)
else:
self.hdf5_file = None
if data_format_config_file_path is not None:
self.data_format_config = load_data_config(data_format_config_file_path)
else:
self.data_format_config = None
if camera_config_file_path is not None:
self.camera_config = load_camera_config(camera_config_file_path)
else:
self.camera_config = None
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
if self.hdf5_file is not None:
self.hdf5_file.close()
def set_source(
self, hdf5_file_path=None, data_format_config_file_path=None, camera_config_file_path=None, mode="r"
):
self.__init__(
hdf5_file_path=hdf5_file_path,
data_format_config_file_path=data_format_config_file_path,
camera_config_file_path=camera_config_file_path,
mode=mode,
)
print("set set_source", hdf5_file_path, self.hdf5_file)
def get_available_streams(self):
return list(self.data_format_config.keys())
def get_stream_config(self, stream_name):
data_config = self.data_format_config[stream_name]
return data_config
def get_stream_shape(self, stream_name):
data_config = self.get_stream_config(stream_name)
data_path = data_config["path"]
descriptor = self.hdf5_file.describe(data_path)
# @TODO check fo other arrays types..
shape = descriptor[1][0][1]
return shape
def get_stream_timestamps(self, stream_name):
data_config = self.get_stream_config(stream_name)
data_path = data_config["path"]
if not self.hdf5_file.has_attribute("timestamps", data_path):
return None
timestamps = self.hdf5_file.get_attribute("timestamps", data_path)
if isinstance(timestamps, np.ndarray) and len(timestamps) == 1 and isinstance(timestamps[0], np.bytes_):
timestamps = timestamps[0]
if isinstance(timestamps, bytes):
timestamps = timestamps.decode("utf-8")
if isinstance(timestamps, str):
return np.array(json.loads("[" + timestamps.strip().strip("[").strip("]") + "]"))
else:
return timestamps
def get_stream_camera(self, stream_name):
# TODO cache camera objects
data_config = self.get_stream_config(stream_name)
if "use_config_from" in data_config:
data_config = self.get_stream_config(data_config["use_config_from"])
camera_name = data_config["camera"]
return self.camera_config[camera_name]
def load_stream_data(self, stream_name, index):
data_config = self.get_stream_config(stream_name)
data_path = data_config["path"]
if "use_config_from" in data_config:
data_config = self.get_stream_config(data_config["use_config_from"])
array_format = data_config["array_format"]
if "flip" in array_format:
array_flip = array_format["flip"]
else:
array_flip = None
def flip_axes(data, axes):
if axes is not None:
for axis_name in axes:
data = np.flip(data, axis=int(array_format[axis_name]))
return data
# TODO load only relevant data if cropped
data = None
if isinstance(index, tuple):
index = index[0]
print("WARNING: cropping not yet implemented")
if isinstance(index, int):
data = np.stack([self.hdf5_file.lread(data_path, index)])
elif isinstance(index, slice):
if index.step == None:
indices = list(range(index.start, index.stop))
else:
indices = list(range(index.start, index.stop, index.step))
data = np.stack([self.hdf5_file.lread(data_path, i) for i in indices])
elif isinstance(index, list):
data = np.stack([self.hdf5_file.lread(data_path, i) for i in index])
else:
raise Exception("index can only be int, slice, tuple or list")
data = flip_axes(data, array_flip)
# TODO rotate
return data
import numpy as np
from scipy.spatial import cKDTree
from skimage import transform
import cv2 as cv
from .utils import convert_cv_to_bob, StreamArray
from bob.ip.stereo import StereoParameters
from bob.ip.stereo import stereo_match, reproject_image, CameraPair
from .stream import stream_filter, Stream
################################################################################
# Stream Filters
### default ###
@stream_filter("nop")
class StreamFilter(Stream):
def __init__(self, name, parent):
super().__init__(name=name, parent=parent)
pass
def get_indices(self, index):
return super().get_indices(index)
def process(self, data, indices):
assert isinstance(indices, list)
return np.stack([self.process_frame(data[i], i, indices[i]) for i in range(data.shape[0])])
def process_frame(self, data, data_index, stream_index):
return data
# load one or several frames
def load(self, index):
indices = self.get_indices(index)
# return buffered data OR load from file OR process data
if self._Stream__loaded == indices and self._Stream__data is not None:
# print('loaded', self.name)
pass
else:
self.__data = self.process(self.parent.load(indices), indices)
# buffer and return data
self.__loaded = indices
return self.__data
### channel ###
@stream_filter("select")
class Select(StreamFilter):
def __init__(self, name, parent, channel=None):
super().__init__(name=name, parent=parent)
if channel is not None:
self.channel = channel
else:
raise Exception("channel parameter not set")
@property
def shape(self):
return (self.parent.shape[0], 1, self.parent.shape[2], self.parent.shape[3])
def process(self, data, indices):
return np.expand_dims(data[:, self.channel, :, :], axis=1)
### astype ###
@stream_filter("astype")
class Select(StreamFilter):
def __init__(self, name, parent, dtype=None):
super().__init__(name=name, parent=parent)
if dtype is not None:
self.dtype = dtype
else:
raise Exception("dtype parameter not set")
def process(self, data, indices):
return data.astype(self.dtype)
### to_rgb ###
@stream_filter("colormap")
class StreamColorMap(StreamFilter):
def __init__(self, name, parent, colormap="gray"):
super().__init__(name=name, parent=parent)
self.colormap = colormap
@property
def shape(self):
return (self.parent.shape[0], 3, self.parent.shape[2], self.parent.shape[3])
def process_frame(self, data, data_index, stream_index):
if data.shape[0] == 1:
# normalise
tmin = np.amin(data)
tmax = np.amax(data)
data = data[0, :, :]
data = (data - tmin).astype("float")
data = (data * 255.0 / float(tmax - tmin)).astype("uint8")
if self.colormap == "gray":
data = (np.stack([data, data, data])).astype("uint8")
return data
else:
# todo add all colormaps
maps = {
"jet": cv.COLORMAP_JET,
"bone": cv.COLORMAP_BONE,
"hsv": cv.COLORMAP_HSV,
}
data = cv.applyColorMap(data, maps[self.colormap])
data = convert_cv_to_bob(data)
return data
else:
raise Exception("cannot convert multichannel stream")
### normalize ###
@stream_filter("normalize")
class StreamNormalize(StreamFilter):
def __init__(self, name, parent, tmin=None, tmax=None, dtype="uint8"):
self.tmin = tmin
self.tmax = tmax
self.dtype = dtype
super().__init__(name=name, parent=parent)
def process(self, data, indices):
tmin = np.amin(data) if self.tmin is None else self.tmin
tmax = np.amax(data) if self.tmax is None else self.tmax
data = (data - tmin).astype("float64")
data = data / float(tmax - tmin)
data = np.clip(data, a_min=0.0, a_max=1.0)
if self.dtype == "uint8":
data = (data * 255.0).astype("uint8")
elif self.dtype == "uint16":
data = (data * 65535.0).astype("uint16")
return data
### clean_dead ###
@stream_filter("clean")
class StreamClean(StreamFilter):
def __init__(self, name, parent, min_value=None, max_value=None, median_filter=None):
super().__init__(name=name, parent=parent)
def process_frame(self, data, data_index, stream_index):
data = data[0]
dtype = data.dtype
data = data.astype(np.float32)
tmin = np.amin(data)
tmax = np.amax(data)
trange = tmax - tmin
threshold = tmax - trange * 0.10
mask = np.where(data == 0, 1, 0).astype(np.uint8)
# print('fix dead', cv.countNonZero(mask))
data = cv.inpaint(data, mask, 3, cv.INPAINT_NS)
data = cv.medianBlur(data, 3)
data = np.stack([data]).astype(dtype)
return data
### stack_with ###
@stream_filter("stack")
class StreamStacked(StreamFilter):
def __init__(self, stack_stream, name, parent):
super().__init__(name=name, parent=parent)
self.stack_stream = stack_stream
def set_source(self, src):
super().set_source(src)
self.stack_stream.set_source(src)
@property
def shape(self):
return (
self.parent.shape[0],
self.parent.shape[1] + self.stack_stream.shape[1],
self.parent.shape[2],
self.parent.shape[3],
)
def process(self, data, indices):
self.data2 = self.stack_stream.load(indices)
return super().process(data, indices)
def process_frame(self, data, data_index, stream_index):
return np.concatenate((data, self.data2[data_index]), axis=0)
### ajust ###
@stream_filter("adjust")
class StreamAdjust(StreamFilter):
def __init__(self, adjust_to, name, parent):
super().__init__(name=name, parent=parent)
self.adjust_to = adjust_to
def set_source(self, src):
super().set_source(src)
self.adjust_to.set_source(src)
@property
def shape(self):
return (self.adjust_to.shape[0], self.parent.shape[1], self.parent.shape[2], self.parent.shape[3])
@property
def timestamps(self):
return self.adjust_to.timestamps
def get_indices(self, index):
return super().get_indices(index)
def load(self, index):
# TODO load only relevant data if cropped
if isinstance(index, tuple):
index = index[0]
print("WARNING: cropping not yet implemented")
# original stream indices
old_indices = self.get_indices(index)
selected_timestamps = [self.adjust_to.timestamps[i] for i in old_indices]
kdtree = cKDTree(self.parent.timestamps[:, np.newaxis])
def get_index(val, kdtree):
_, i = kdtree.query(val, k=1)
return i
new_indices = [get_index(ts[np.newaxis], kdtree) for ts in selected_timestamps]
if False:
print("DEBUG: indices alignement:")
print(old_indices)
print(new_indices)
print(self.parent.timestamps)
print(self.adjust_to.timestamps)
return super().load(new_indices)
### warp_with ###
@stream_filter("warp")
class StreamWarp(StreamFilter):
def __init__(self, warp_to, name, parent):
super().__init__(name=name, parent=parent)
self.warp_to = warp_to
def set_source(self, src):
super().set_source(src)
self.warp_to.set_source(src)
@property
def shape(self):
return (self.parent.shape[0], self.parent.shape[1], self.warp_to.shape[2], self.warp_to.shape[3])
def process(self, data, indices):
self.markers = (self.warp_to.camera.markers, self.camera.markers)
self.warp_transform = transform.ProjectiveTransform()
self.warp_transform.estimate(*self.markers)
self.output_shape = (self.warp_to.shape[2], self.warp_to.shape[3])
return super().process(data, indices)
def process_frame(self, data, data_index, stream_index):
output = []
num_chan = data.shape[0]
for c in range(num_chan):
output.append(
transform.warp(
data[c], self.warp_transform, output_shape=self.output_shape, preserve_range=True
).astype(data.dtype)
)
output = np.stack(output)
return output
### stereo ###
@stream_filter("stereo")
class StreamStereo(StreamFilter):
def __init__(self, match_with_stream, name, parent, stereo_parameters=StereoParameters()):
super().__init__(name=name, parent=parent)
self.match_with_stream = match_with_stream
self.stereo_parameters = stereo_parameters
def set_source(self, src):
super().set_source(src)
self.match_with_stream.set_source(src)
def process(self, data, indices):
self.camera_pair = CameraPair(self.camera, self.match_with_stream.camera)
self.right_data = self.match_with_stream.load(indices)
return super().process(data, indices)
def process_frame(self, data, data_index, stream_index):
return stereo_match(
data, self.right_data[data_index], self.camera_pair, stereo_parameters=self.stereo_parameters
)
### reproject ###
@stream_filter("reproject")
class StreamReproject(StreamFilter):
def __init__(self, left_stream, right_stream, map_3d, name, parent):
super().__init__(name=name, parent=parent)
self.left_stream = left_stream
self.right_stream = right_stream
self.map_3d = map_3d
self.__bounding_box = StreamArray(self)
self.__image_points = StreamArray(self)
def set_source(self, src):
super().set_source(src)
self.left_stream.set_source(src)
self.right_stream.set_source(src)
self.map_3d.set_source(src)
@property
def shape(self):
return (self.parent.shape[0], self.parent.shape[1], self.map_3d.shape[2], self.map_3d.shape[3])
@property
def bounding_box(self):
return self.__bounding_box
@property
def image_points(self):
return self.__image_points
def process(self, data, indices):
self.map_3d_data = self.map_3d.load(indices)
self.camera_pair = CameraPair(self.left_stream.camera, self.right_stream.camera)
return super().process(data, indices)
def process_frame(self, data, data_index, stream_index):
# copy parent's bounding box
bounding_box = self.parent.bounding_box[stream_index]
if bounding_box is not None:
# TODO do type checking in StreamArray
assert isinstance(bounding_box, np.ndarray)
assert bounding_box.shape[0] == 2
assert bounding_box.shape[1] == 2
bounding_box = np.copy(bounding_box)
self.__bounding_box[stream_index] = bounding_box
# copy parent's image points
image_points = self.parent.image_points[stream_index]
if image_points is not None:
assert isinstance(image_points, np.ndarray)
assert image_points.shape[1] == 2
image_points = np.copy(image_points)
self.__image_points[stream_index] = image_points
# reproject
return reproject_image(
data,
self.map_3d_data[data_index],
self.camera,
self.camera_pair,
bounding_box=bounding_box,
image_points=image_points,
)
### subtract dark frame ###
@stream_filter("subtract")
class StreamSubtract(StreamFilter):
def __init__(self, dark, name, parent):
self.dark = dark
super().__init__(name=name, parent=parent)
def set_source(self, src):
super().set_source(src)
self.dark.set_source(src)
def process(self, data, indices):
dark_data = self.dark.load(indices)
assert data.shape == dark_data.shape
# if data > dark_data: return data - dark-data, else return 0 (substraction of uint can be problematic)
return np.where(data > dark_data, data - dark_data, 0)
......@@ -53,17 +53,17 @@ def test_stream():
)
# stream for stereo and projection tests
color = f.stream("color")
color = Stream("color", f)
nir_left = Stream("nir_left_stereo").adjust(color)
nir_right = f.stream("nir_right_stereo").adjust(color)
nir_right = Stream("nir_right_stereo", f).adjust(color)
# streams for subtract tests
swir_dark = f.stream("swir")
swir_940 = f.stream("swir_940nm")
swir_dark = Stream("swir", f)
swir_940 = Stream("swir_940nm", f)
# streams for stack, normalize and warp tests
swir_1050 = f.stream("swir_1050nm")
swir_1300 = f.stream("swir_1300nm")
swir_1550 = f.stream("swir_1550nm")
thermal = f.stream("thermal")
swir_1050 = Stream("swir_1050nm", f)
swir_1300 = Stream("swir_1300nm", f)
swir_1550 = Stream("swir_1550nm", f)
thermal = Stream("thermal", f)
# reproject operations
map_3d = nir_left.stereo(nir_right)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment