Skip to content
Snippets Groups Projects
Commit 79b338ca authored by Yannick DAYER's avatar Yannick DAYER
Browse files

SampleWrapper to accept multiple output in sample.

Ensure that output is delayed regardless of if it is "data".
parent c7e82189
Branches multi-output
No related tags found
No related merge requests found
Pipeline #61623 passed
"""Base definition of sample."""
from collections.abc import MutableSequence, Sequence
from typing import Any
from typing import Any, Callable
import numpy as np
......@@ -182,6 +182,25 @@ class DelayedSample(Sample):
super().__setattr__(name, value)
def set_delayed_attribute(self, name: str, value: Callable) -> None:
"""Sets a delayed attribute.
Parameters
----------
name
Name of the attribute to set
value
Callable that returns the attribute when getattribute is called
"""
delayed_attributes = getattr(self, "_delayed_attributes", None)
if delayed_attributes is None:
super().__setattr__("_delayed_attributes", {name: value})
else:
delayed_attributes[name] = value
super().__setattr__(name, None)
@property
def data(self):
"""Loads the data from the disk file."""
......
......@@ -7,6 +7,7 @@ import traceback
from functools import partial
from pathlib import Path
from typing import Callable
import cloudpickle
import dask
......@@ -247,6 +248,13 @@ class DelayedSamplesCall:
return self.output[index]
def _delayed_call_multiple_output(
delayed: Callable, sample_index: int, attr_index: int
):
"""Handles delayed calls returning a tuple of elements for each sample."""
return delayed(sample_index)[attr_index]
class SampleWrapper(BaseWrapper, TransformerMixin):
"""Wraps scikit-learn estimators to work with :any:`Sample`-based
pipelines.
......@@ -323,53 +331,92 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
samples,
sample_attribute=self.input_attribute,
)
if self.output_attribute == "data": # Normal case
if self.output_attribute == "data": # Normal case, output is data
new_samples = [
DelayedSample(partial(delayed, index=i), parent=s)
for i, s in enumerate(samples)
]
elif isinstance(
self.output_attribute, str
): # Single attribute but not data
elif isinstance(self.output_attribute, str):
# Single attribute but output is not data
if not isinstance(samples[0], DelayedSample):
# Convert to a delayed sample
new_samples = [
DelayedSample(
partial(lambda: s.data),
delayed_attributes={
self.output_attribute: partial(delayed, index=i)
},
parent=s,
)
for i, s in enumerate(samples)
]
else:
for i, s in enumerate(samples):
setattr(s, self.output_attribute, None)
samples[i]._delayed_attributes.update(
{
self.output_attribute: partial(
delayed, index=i
),
}
s.set_delayed_attribute(
self.output_attribute, partial(delayed, index=i)
)
new_samples = samples
elif "data" in self.output_attribute: # TODO YD20220525
# Special case where the output is a tuple and contains "data"
elif "data" in self.output_attribute:
# Special case where the output is multiple and contains "data"
data_idx = self.output_attribute.index("data")
new_samples = [
DelayedSample(partial(delayed(i), index=i), parent=s)
DelayedSample(
partial(
_delayed_call_multiple_output,
delayed,
sample_index=i,
attr_index=data_idx,
),
parent=s,
delayed_attributes={
self.output_attribute[attr_idx]: partial(
_delayed_call_multiple_output,
delayed,
sample_index=i,
attr_index=attr_idx,
)
for attr_idx in range(len(self.output_attribute))
if attr_idx != data_idx
},
)
for i, s in enumerate(samples)
]
for i, s in enumerate(new_samples):
if i != data_idx:
else: # Multiple output attributes
if not isinstance(samples[0], DelayedSample):
# Convert to a delayed sample
new_samples = [
DelayedSample(
partial(lambda: s.data),
delayed_attributes={
self.output_attribute[attr_idx]: partial(
_delayed_call_multiple_output,
delayed,
sample_index=i,
attribute_index=attr_idx,
)
for attr_idx in range(
len(self.output_attribute)
)
},
parent=s,
)
for i, s in enumerate(samples)
]
else:
for i, s in enumerate(samples):
for attr_idx, attr_name in enumerate(
self.output_attribute
):
setattr(s, attr_name, delayed(i)[attr_idx])
else: # TODO YD20220525
for i, s in enumerate(samples):
for attr_idx, attr_name in enumerate(self.output_attribute):
setattr(s, attr_name, delayed(i)[attr_idx])
new_samples = samples
s.set_delayed_attribute(
attr_name,
partial(
_delayed_call_multiple_output,
delayed,
sample_index=i,
attribute_index=attr_idx,
),
)
new_samples = samples
return new_samples
def transform(self, samples):
......@@ -452,8 +499,8 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
If None, will use the ``bob_feature_load_fn`` tag in the estimator, or default
to ``bob.io.base.load``.
sample_attribute: str
Defines the payload attribute of the sample.
sample_attribute: str or tuple[str]
Defines the payload attribute(s) of the sample.
If None, will use the ``bob_output`` tag in the estimator, or default to
``data``.
......@@ -497,11 +544,22 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
if not bob_tags["bob_checkpoint_features"]:
logger.info(
"Checkpointing is disabled for %s beacuse the bob_checkpoint_features tag is False.",
"Checkpointing is disabled for %s because the bob_checkpoint_features tag is False.",
estimator,
)
features_dir = None
if (
not isinstance(self.sample_attribute, str)
and features_dir is not None
):
raise (
NotImplementedError(
"CheckpointWrapper only supports single output attributes. "
f"Please set the bob_checkpoint_features tag to False for {estimator}."
)
)
self.force = force
self.estimator = estimator
self.model_path = model_path
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment