Commit 3975182c authored by Tiago de Freitas Pereira

Merge branch 'add-annotations-to-wrappers' into 'master'

Allow setting specific attributes of sample

See merge request !43
parents 57bd3218 90f0443d
Pipeline #45000 failed
@@ -101,6 +101,12 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
         "subject")]`` as the value for this attribute.
     transform_extra_arguments : [tuple]
         Similar to ``fit_extra_arguments`` but for the transform and other similar methods.
+    output_attribute : str
+        The name of the sample attribute to which the output of the estimator
+        will be saved. [Default is ``data``.]
+
+        Example: if ``output_attribute`` is ``"annotations"``, then
+        ``sample.annotations`` will contain the output of the estimator.
     """

     def __init__(
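For context, a minimal usage sketch of the new ``output_attribute`` option (assuming ``bob.pipelines`` exposes ``Sample`` and ``SampleWrapper`` at the package level; the ``FunctionTransformer``-based annotator is purely illustrative):

    from sklearn.preprocessing import FunctionTransformer
    from bob.pipelines import Sample, SampleWrapper

    # Illustrative estimator: produces one annotation dict per input
    annotator = FunctionTransformer(lambda X: [{"topleft": (0, 0)} for _ in X])

    # Route the estimator output to sample.annotations instead of sample.data
    wrapper = SampleWrapper(annotator, output_attribute="annotations")

    samples = [Sample([1, 2, 3]), Sample([4, 5, 6])]
    transformed = wrapper.transform(samples)

    assert transformed[0].annotations == {"topleft": (0, 0)}
    assert transformed[0].data == [1, 2, 3]  # .data is left untouched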
@@ -108,12 +114,14 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
         estimator,
         transform_extra_arguments=None,
         fit_extra_arguments=None,
+        output_attribute="data",
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.estimator = estimator
         self.transform_extra_arguments = transform_extra_arguments or tuple()
         self.fit_extra_arguments = fit_extra_arguments or tuple()
+        self.output_attribute = output_attribute

     def _samples_transform(self, samples, method_name):
         # Transform either samples or samplesets
@@ -131,10 +139,16 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
         else:
             kwargs = _make_kwargs_from_samples(samples, self.transform_extra_arguments)
         delayed = DelayedSamplesCall(partial(method, **kwargs), func_name, samples,)
-        new_samples = [
-            DelayedSample(partial(delayed, index=i), parent=s)
-            for i, s in enumerate(samples)
-        ]
+        if self.output_attribute != "data":
+            # Edit the sample.<output_attribute> instead of data
+            for i, s in enumerate(samples):
+                setattr(s, self.output_attribute, delayed(i))
+            new_samples = samples
+        else:
+            new_samples = [
+                DelayedSample(partial(delayed, index=i), parent=s)
+                for i, s in enumerate(samples)
+            ]
         return new_samples

     def transform(self, samples):
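Note the behavioral difference between the two branches above: with the default ``data`` attribute the result stays lazy (each ``DelayedSample`` defers the call until ``.data`` is first accessed), while for any other attribute ``delayed(i)`` is invoked immediately, so the estimator runs eagerly inside ``transform``. A stripped-down illustration of the two strategies, independent of the library code:

    from functools import partial

    class LazyResult:
        """Defers a computation until first access (DelayedSample-like)."""
        def __init__(self, load):
            self._load = load

        @property
        def data(self):
            return self._load()

    def expensive(i):
        print(f"computing {i}")
        return i * 2

    lazy = [LazyResult(partial(expensive, i)) for i in range(2)]  # nothing runs yet
    eager = [expensive(i) for i in range(2)]  # runs immediately
    print(lazy[0].data)  # the lazy computation happens here, on first access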
@@ -176,7 +190,15 @@ class SampleWrapper(BaseWrapper, TransformerMixin):

 class CheckpointWrapper(BaseWrapper, TransformerMixin):
     """Wraps :any:`Sample`-based estimators so the results are saved to
-    disk."""
+    disk.
+
+    Parameters
+    ----------
+    sample_attribute : str
+        The attribute of the Sample object that needs to be saved to disk.
+        [Default is ``data``.]
+
+    """

     def __init__(
         self,
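A sketch of checkpointing a non-default attribute with the new ``sample_attribute`` parameter, reusing the ``wrapper`` from the sketch above. The JSON helpers and the ``/tmp`` path are illustrative (the defaults, ``bob.io.base.save``/``bob.io.base.load``, target array-like data), and checkpointed samples are expected to carry a ``key`` attribute from which the file path is built:

    import json
    from bob.pipelines import CheckpointWrapper

    def save_json(obj, path):
        with open(path, "w") as f:
            json.dump(obj, f)

    def load_json(path):
        with open(path) as f:
            return json.load(f)

    # Persist sample.annotations instead of sample.data
    checkpointer = CheckpointWrapper(
        wrapper,  # the SampleWrapper from the sketch above
        features_dir="/tmp/annotations",
        extension=".json",
        save_func=save_json,
        load_func=load_json,
        sample_attribute="annotations",
    )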
@@ -186,6 +208,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         extension=".h5",
         save_func=None,
         load_func=None,
+        sample_attribute="data",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -195,6 +218,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         self.extension = extension
         self.save_func = save_func or bob.io.base.save
         self.load_func = load_func or bob.io.base.load
+        self.sample_attribute = sample_attribute
         if model_path is None and features_dir is None:
             logger.warning(
                 "Both model_path and features_dir are None. "
@@ -290,14 +314,21 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
     def save(self, sample):
         path = self.make_path(sample)
         os.makedirs(os.path.dirname(path), exist_ok=True)
-        return self.save_func(sample.data, path)
+        # Gets sample.data or sample.<sample_attribute> if specified
+        to_save = getattr(sample, self.sample_attribute)
+        return self.save_func(to_save, path)

     def load(self, sample, path):
         # Because we are checkpointing, we return a DelayedSample
         # instead of a normal (preloaded) sample. This allows the next
         # phase to avoid loading it when that would be unnecessary
         # (e.g. the next phase is already checkpointed)
-        return DelayedSample(partial(self.load_func, path), parent=sample)
+        if self.sample_attribute == "data":
+            return DelayedSample(partial(self.load_func, path), parent=sample)
+        else:
+            loaded = self.load_func(path)
+            setattr(sample, self.sample_attribute, loaded)
+            return sample

     def load_model(self):
         if is_estimator_stateless(self.estimator):
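One consequence of the new ``load`` branch: for the default ``data`` attribute the checkpoint is wrapped in a ``DelayedSample`` and read lazily, while for any other attribute the file is read immediately and the existing sample is mutated and returned as-is. Illustratively (hypothetical path, continuing the checkpointer sketch above):

    # Non-default attribute: load_func runs right away, sample.annotations
    # is set in place, and the very same sample object is returned.
    loaded = checkpointer.load(sample, "/tmp/annotations/key.json")
    assert loaded is sample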