diff --git a/bob/ip/binseg/data/stare/__init__.py b/bob/ip/binseg/data/stare/__init__.py index 753c98feb92884f4f3ad11de3cc781a99d1bdb9e..3ef157056b9f3244c83f9579ef982c97d8b20cd3 100644 --- a/bob/ip/binseg/data/stare/__init__.py +++ b/bob/ip/binseg/data/stare/__init__.py @@ -44,21 +44,22 @@ _root_path = bob.extension.rc.get( ) -def _make_loader(root_path): +class _make_loader: #hack to get testing on the CI working fine for this dataset - def _raw_data_loader(sample): + def __init__(self, root_path): + self.root_path = root_path + + def __raw_data_loader__(self, sample): return dict( - data=load_pil_rgb(os.path.join(root_path, sample["data"])), - label=load_pil_1(os.path.join(root_path, sample["label"])), + data=load_pil_rgb(os.path.join(self.root_path, sample["data"])), + label=load_pil_1(os.path.join(self.root_path, sample["label"])), ) - def _loader(context, sample): + def __call__(self, context, sample): # "context" is ignored in this case - database is homogeneous # we returned delayed samples to avoid loading all images at once - return make_delayed(sample, _raw_data_loader) - - return _loader + return make_delayed(sample, self.__raw_data_loader__) def _make_dataset(root_path):