Commit 663e9d4d authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

* Add a transform function for VideoAsArray

* Don't load VideoLikeContainer checkpoints into memory
* More documentation
parent 30cff2d8
...@@ -7,7 +7,9 @@ import numpy as np ...@@ -7,7 +7,9 @@ import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def select_frames(count, max_number_of_frames=None, selection_style=None, step_size=None): def select_frames(
count, max_number_of_frames=None, selection_style=None, step_size=None
):
"""Returns indices of the frames to be selected given the parameters. """Returns indices of the frames to be selected given the parameters.
Different selection styles are supported: Different selection styles are supported:
...@@ -67,14 +69,35 @@ def select_frames(count, max_number_of_frames=None, selection_style=None, step_s ...@@ -67,14 +69,35 @@ def select_frames(count, max_number_of_frames=None, selection_style=None, step_s
class VideoAsArray: class VideoAsArray:
"""A memory efficient class to load only select video frames
It also supports efficient conversion to dask arrays.
"""
def __init__( def __init__(
self, self,
path, path,
selection_style=None, selection_style=None,
max_number_of_frames=None, max_number_of_frames=None,
step_size=None, step_size=None,
transform=None,
**kwargs, **kwargs,
): ):
"""init
Parameters
----------
path : str
Path to the video file
selection_style : str, optional
See :any:`select_frames`, by default None
max_number_of_frames : int, optional
See :any:`select_frames`, by default None
step_size : int, optional
See :any:`select_frames`, by default None
transform : callable, optional
A function that transforms the loaded video. This function should
not change the video shape or its dtype. For example, you may flip
the frames horizontally using this function, by default None
"""
super().__init__(**kwargs) super().__init__(**kwargs)
self.path = path self.path = path
self.reader = bob.io.video.reader(self.path) self.reader = bob.io.video.reader(self.path)
...@@ -89,6 +112,10 @@ class VideoAsArray: ...@@ -89,6 +112,10 @@ class VideoAsArray:
) )
self.indices = indices self.indices = indices
self.shape = (len(indices),) + shape[1:] self.shape = (len(indices),) + shape[1:]
if transform is None:
def transform(x):
return x
self.transform = transform
def __getstate__(self): def __getstate__(self):
d = self.__dict__.copy() d = self.__dict__.copy()
...@@ -106,7 +133,7 @@ class VideoAsArray: ...@@ -106,7 +133,7 @@ class VideoAsArray:
# logger.debug("Getting frame %s from %s", index, self.path) # logger.debug("Getting frame %s from %s", index, self.path)
if isinstance(index, int): if isinstance(index, int):
idx = self.indices[index] idx = self.indices[index]
return self.reader[idx] return self.transform([self.reader[idx]])[0]
if not (isinstance(index, tuple) and len(index) == self.ndim): if not (isinstance(index, tuple) and len(index) == self.ndim):
raise NotImplementedError(f"Indxing like {index} is not supported yet!") raise NotImplementedError(f"Indxing like {index} is not supported yet!")
...@@ -115,7 +142,7 @@ class VideoAsArray: ...@@ -115,7 +142,7 @@ class VideoAsArray:
return np.array([], dtype=self.dtype) return np.array([], dtype=self.dtype)
if self.selection_style == "all": if self.selection_style == "all":
return np.asarray(self.reader.load())[index] return self.transform(np.asarray(self.reader.load())[index])
idx = self.indices[index[0]] idx = self.indices[index[0]]
video = [] video = []
...@@ -127,7 +154,7 @@ class VideoAsArray: ...@@ -127,7 +154,7 @@ class VideoAsArray:
break break
index = (slice(len(video)),) + index[1:] index = (slice(len(video)),) + index[1:]
return np.asarray(video)[index] return self.transform(np.asarray(video)[index])
def __repr__(self): def __repr__(self):
return f"{self.reader!r} {self.dtype!r} {self.ndim!r} {self.shape!r} {self.indices!r}" return f"{self.reader!r} {self.dtype!r} {self.ndim!r} {self.shape!r} {self.indices!r}"
...@@ -156,8 +183,10 @@ class VideoLikeContainer: ...@@ -156,8 +183,10 @@ class VideoLikeContainer:
@classmethod @classmethod
def load(cls, file): def load(cls, file):
with h5py.File(file, mode="r") as f: # weak closing of the hdf5 file so we don't load all the data into
data = np.array(f["data"]) # memory https://docs.h5py.org/en/stable/high/file.html#closing-files
indices = np.array(f["indices"]) f = h5py.File(file, mode="r")
data = f["data"]
indices = f["indices"]
self = cls(data=data, indices=indices) self = cls(data=data, indices=indices)
return self return self
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment