Commit 3975182c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'add-annotations-to-wrappers' into 'master'

Allow setting specific attributes of sample

See merge request !43
parents 57bd3218 90f0443d
Pipeline #45000 failed with stages
in 3 minutes and 32 seconds
......@@ -101,6 +101,12 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
"subject")]`` as the value for this attribute.
transform_extra_arguments : [tuple]
Similar to ``fit_extra_arguments`` but for the transform and other similar methods.
output_attribute: str
The name of a Sample attribute where the output of the estimator will be
saved to. [Default is ``data``]
Example:
if ``output_attribute`` is ``"annotations"``, then
``sample.annotations`` will contain the output of the estimator.
"""
def __init__(
......@@ -108,12 +114,14 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
estimator,
transform_extra_arguments=None,
fit_extra_arguments=None,
output_attribute="data",
**kwargs,
):
super().__init__(**kwargs)
self.estimator = estimator
self.transform_extra_arguments = transform_extra_arguments or tuple()
self.fit_extra_arguments = fit_extra_arguments or tuple()
self.output_attribute = output_attribute
def _samples_transform(self, samples, method_name):
# Transform either samples or samplesets
......@@ -131,10 +139,16 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
else:
kwargs = _make_kwargs_from_samples(samples, self.transform_extra_arguments)
delayed = DelayedSamplesCall(partial(method, **kwargs), func_name, samples,)
new_samples = [
DelayedSample(partial(delayed, index=i), parent=s)
for i, s in enumerate(samples)
]
if self.output_attribute != "data":
# Edit the sample.<output_attribute> instead of data
for i, s in enumerate(samples):
setattr(s, self.output_attribute, delayed(i))
new_samples = samples
else:
new_samples = [
DelayedSample(partial(delayed, index=i), parent=s)
for i, s in enumerate(samples)
]
return new_samples
def transform(self, samples):
......@@ -176,7 +190,15 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
class CheckpointWrapper(BaseWrapper, TransformerMixin):
"""Wraps :any:`Sample`-based estimators so the results are saved in
disk."""
disk.
Parameters
----------
sample_attribute: str
The attribute of the Sample object that needs to be saved to disk.
[Default is ``data``].
"""
def __init__(
self,
......@@ -186,6 +208,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
extension=".h5",
save_func=None,
load_func=None,
sample_attribute="data",
**kwargs,
):
super().__init__(**kwargs)
......@@ -195,6 +218,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
self.extension = extension
self.save_func = save_func or bob.io.base.save
self.load_func = load_func or bob.io.base.load
self.sample_attribute = sample_attribute
if model_path is None and features_dir is None:
logger.warning(
"Both model_path and features_dir are None. "
......@@ -290,14 +314,21 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
def save(self, sample):
path = self.make_path(sample)
os.makedirs(os.path.dirname(path), exist_ok=True)
return self.save_func(sample.data, path)
# Gets sample.data or sample.<sample_attribute> if specified
to_save = getattr(sample, self.sample_attribute)
return self.save_func(to_save, path)
def load(self, sample, path):
# because we are checkpointing, we return a DelayedSample
# instead of a normal (preloaded) sample. This allows the next
# phase to avoid loading it would it be unnecessary (e.g. next
# phase is already check-pointed)
return DelayedSample(partial(self.load_func, path), parent=sample)
if self.sample_attribute == "data":
return DelayedSample(partial(self.load_func, path), parent=sample)
else:
loaded = self.load_func(path)
setattr(sample, self.sample_attribute, loaded)
return sample
def load_model(self):
if is_estimator_stateless(self.estimator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment