Commit 3b4f0a95 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch 'handle-none-data' into 'master'

[transformer][wrapper] handle None data so that it is saved in score files at the end

See merge request !44
parents cbbd86a3 824e4d1f
Pipeline #51441 passed with stages
in 9 minutes and 37 seconds
import logging import logging
from sklearn.base import BaseEstimator, TransformerMixin from sklearn.base import BaseEstimator, TransformerMixin
from bob.pipelines.wrappers import _check_n_input_output, _frmt
from . import utils from . import utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -43,20 +43,33 @@ class VideoWrapper(TransformerMixin, BaseEstimator): ...@@ -43,20 +43,33 @@ class VideoWrapper(TransformerMixin, BaseEstimator):
for index in video.indices for index in video.indices
] ]
data = self.estimator.transform(video, **kw) # remove None's before calling and add them back in data later
# Isolate invalid samples (when previous transformers returned None)
dl, vl = len(data), len(video) invalid_ids = [i for i, frame in enumerate(video) if frame is None]
if dl != vl: valid_frames = [frame for frame in video if frame is not None]
raise RuntimeError(
f"Length of transformed data ({dl}) using {self.estimator}" # remove invalid kw args as well
f" is different from the length of input video: {vl}" for k, v in kw.items():
kw[k] = [vv for j, vv in enumerate(v) if j not in invalid_ids]
# Process only the valid samples
output = None
if len(valid_frames) > 0:
output = self.estimator.transform(valid_frames, **kw)
valid_frames, output, f"{_frmt(self.estimator)}.transform"
) )
# handle None's if output is None:
indices = [idx for d, idx in zip(data, video.indices) if d is not None] output = [None] * len(valid_frames)
data = [d for d in data if d is not None]
# Rebuild the full batch of samples (include the previously failed)
if len(invalid_ids) > 0:
output = list(output)
for j in invalid_ids:
output.insert(j, None)
data = utils.VideoLikeContainer(data, indices) data = utils.VideoLikeContainer(output, video.indices)
transformed_videos.append(data) transformed_videos.append(data)
return transformed_videos return transformed_videos
...@@ -229,8 +229,8 @@ class VideoLikeContainer: ...@@ -229,8 +229,8 @@ class VideoLikeContainer:
def save(self, file): def save(self, file):
self.save_function(self, file) self.save_function(self, file)
@classmethod @staticmethod
def save_function(cls, other, file): def save_function(other, file):
with h5py.File(file, mode="w") as f: with h5py.File(file, mode="w") as f:
f["data"] = f["data"] =
f["indices"] = other.indices f["indices"] = other.indices
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment