Commit 824e4d1f authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

[transformer][wrapper] handle None data in a way that they are saved in score files in the end

parent cbbd86a3
Pipeline #51285 passed with stage
in 8 minutes and 28 seconds
import logging
from sklearn.base import BaseEstimator, TransformerMixin
from bob.pipelines.wrappers import _check_n_input_output, _frmt
from . import utils
logger = logging.getLogger(__name__)
......@@ -43,20 +43,33 @@ class VideoWrapper(TransformerMixin, BaseEstimator):
for index in video.indices
]
data = self.estimator.transform(video, **kw)
dl, vl = len(data), len(video)
if dl != vl:
raise RuntimeError(
f"Length of transformed data ({dl}) using {self.estimator}"
f" is different from the length of input video: {vl}"
# remove None's before calling and add them back in data later
# Isolate invalid samples (when previous transformers returned None)
invalid_ids = [i for i, frame in enumerate(video) if frame is None]
valid_frames = [frame for frame in video if frame is not None]
# remove invalid kw args as well
for k, v in kw.items():
kw[k] = [vv for j, vv in enumerate(v) if j not in invalid_ids]
# Process only the valid samples
output = None
if len(valid_frames) > 0:
output = self.estimator.transform(valid_frames, **kw)
_check_n_input_output(
valid_frames, output, f"{_frmt(self.estimator)}.transform"
)
# handle None's
indices = [idx for d, idx in zip(data, video.indices) if d is not None]
data = [d for d in data if d is not None]
if output is None:
output = [None] * len(valid_frames)
# Rebuild the full batch of samples (include the previously failed)
if len(invalid_ids) > 0:
output = list(output)
for j in invalid_ids:
output.insert(j, None)
data = utils.VideoLikeContainer(data, indices)
data = utils.VideoLikeContainer(output, video.indices)
transformed_videos.append(data)
return transformed_videos
......
......@@ -229,8 +229,8 @@ class VideoLikeContainer:
def save(self, file):
self.save_function(self, file)
@classmethod
def save_function(cls, other, file):
@staticmethod
def save_function(other, file):
with h5py.File(file, mode="w") as f:
f["data"] = other.data
f["indices"] = other.indices
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment