Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bob.pipelines
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
bob
bob.pipelines
Commits
79b338ca
Commit
79b338ca
authored
3 years ago
by
Yannick DAYER
Browse files
Options
Downloads
Patches
Plain Diff
SampleWrapper to accept multiple output in sample.
Ensure that output is delayed regardless of if it is "data".
parent
c7e82189
Branches
multi-output
No related tags found
No related merge requests found
Pipeline
#61623
passed
3 years ago
Stage: build
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
bob/pipelines/sample.py
+20
-1
20 additions, 1 deletion
bob/pipelines/sample.py
bob/pipelines/wrappers.py
+83
-25
83 additions, 25 deletions
bob/pipelines/wrappers.py
with
103 additions
and
26 deletions
bob/pipelines/sample.py
+
20
−
1
View file @
79b338ca
"""
Base definition of sample.
"""
from
collections.abc
import
MutableSequence
,
Sequence
from
typing
import
Any
from
typing
import
Any
,
Callable
import
numpy
as
np
...
...
@@ -182,6 +182,25 @@ class DelayedSample(Sample):
super
().
__setattr__
(
name
,
value
)
def
set_delayed_attribute
(
self
,
name
:
str
,
value
:
Callable
)
->
None
:
"""
Sets a delayed attribute.
Parameters
----------
name
Name of the attribute to set
value
Callable that returns the attribute when getattribute is called
"""
delayed_attributes
=
getattr
(
self
,
"
_delayed_attributes
"
,
None
)
if
delayed_attributes
is
None
:
super
().
__setattr__
(
"
_delayed_attributes
"
,
{
name
:
value
})
else
:
delayed_attributes
[
name
]
=
value
super
().
__setattr__
(
name
,
None
)
@property
def
data
(
self
):
"""
Loads the data from the disk file.
"""
...
...
This diff is collapsed.
Click to expand it.
bob/pipelines/wrappers.py
+
83
−
25
View file @
79b338ca
...
...
@@ -7,6 +7,7 @@ import traceback
from
functools
import
partial
from
pathlib
import
Path
from
typing
import
Callable
import
cloudpickle
import
dask
...
...
@@ -247,6 +248,13 @@ class DelayedSamplesCall:
return
self
.
output
[
index
]
def
_delayed_call_multiple_output
(
delayed
:
Callable
,
sample_index
:
int
,
attr_index
:
int
):
"""
Handles delayed calls returning a tuple of elements for each sample.
"""
return
delayed
(
sample_index
)[
attr_index
]
class
SampleWrapper
(
BaseWrapper
,
TransformerMixin
):
"""
Wraps scikit-learn estimators to work with :any:`Sample`-based
pipelines.
...
...
@@ -323,53 +331,92 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
samples
,
sample_attribute
=
self
.
input_attribute
,
)
if
self
.
output_attribute
==
"
data
"
:
# Normal case
if
self
.
output_attribute
==
"
data
"
:
# Normal case
, output is data
new_samples
=
[
DelayedSample
(
partial
(
delayed
,
index
=
i
),
parent
=
s
)
for
i
,
s
in
enumerate
(
samples
)
]
elif
isinstance
(
self
.
output_attribute
,
str
):
# Single attribute but not data
elif
isinstance
(
self
.
output_attribute
,
str
):
# Single attribute but output is not data
if
not
isinstance
(
samples
[
0
],
DelayedSample
):
# Convert to a delayed sample
new_samples
=
[
DelayedSample
(
partial
(
lambda
:
s
.
data
),
delayed_attributes
=
{
self
.
output_attribute
:
partial
(
delayed
,
index
=
i
)
},
parent
=
s
,
)
for
i
,
s
in
enumerate
(
samples
)
]
else
:
for
i
,
s
in
enumerate
(
samples
):
setattr
(
s
,
self
.
output_attribute
,
None
)
samples
[
i
].
_delayed_attributes
.
update
(
{
self
.
output_attribute
:
partial
(
delayed
,
index
=
i
),
}
s
.
set_delayed_attribute
(
self
.
output_attribute
,
partial
(
delayed
,
index
=
i
)
)
new_samples
=
samples
elif
"
data
"
in
self
.
output_attribute
:
# TODO YD20220525
# Special case where the output is
a tu
ple and contains "data"
elif
"
data
"
in
self
.
output_attribute
:
# Special case where the output is
multi
ple and contains "data"
data_idx
=
self
.
output_attribute
.
index
(
"
data
"
)
new_samples
=
[
DelayedSample
(
partial
(
delayed
(
i
),
index
=
i
),
parent
=
s
)
DelayedSample
(
partial
(
_delayed_call_multiple_output
,
delayed
,
sample_index
=
i
,
attr_index
=
data_idx
,
),
parent
=
s
,
delayed_attributes
=
{
self
.
output_attribute
[
attr_idx
]:
partial
(
_delayed_call_multiple_output
,
delayed
,
sample_index
=
i
,
attr_index
=
attr_idx
,
)
for
attr_idx
in
range
(
len
(
self
.
output_attribute
))
if
attr_idx
!=
data_idx
},
)
for
i
,
s
in
enumerate
(
samples
)
]
for
i
,
s
in
enumerate
(
new_samples
):
if
i
!=
data_idx
:
else
:
# Multiple output attributes
if
not
isinstance
(
samples
[
0
],
DelayedSample
):
# Convert to a delayed sample
new_samples
=
[
DelayedSample
(
partial
(
lambda
:
s
.
data
),
delayed_attributes
=
{
self
.
output_attribute
[
attr_idx
]:
partial
(
_delayed_call_multiple_output
,
delayed
,
sample_index
=
i
,
attribute_index
=
attr_idx
,
)
for
attr_idx
in
range
(
len
(
self
.
output_attribute
)
)
},
parent
=
s
,
)
for
i
,
s
in
enumerate
(
samples
)
]
else
:
for
i
,
s
in
enumerate
(
samples
):
for
attr_idx
,
attr_name
in
enumerate
(
self
.
output_attribute
):
setattr
(
s
,
attr_name
,
delayed
(
i
)[
attr_idx
])
else
:
# TODO YD20220525
for
i
,
s
in
enumerate
(
samples
):
for
attr_idx
,
attr_name
in
enumerate
(
self
.
output_attribute
):
setattr
(
s
,
attr_name
,
delayed
(
i
)[
attr_idx
])
new_samples
=
samples
s
.
set_delayed_attribute
(
attr_name
,
partial
(
_delayed_call_multiple_output
,
delayed
,
sample_index
=
i
,
attribute_index
=
attr_idx
,
),
)
new_samples
=
samples
return
new_samples
def
transform
(
self
,
samples
):
...
...
@@ -452,8 +499,8 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
If None, will use the ``bob_feature_load_fn`` tag in the estimator, or default
to ``bob.io.base.load``.
sample_attribute: str
Defines the payload attribute of the sample.
sample_attribute: str
or tuple[str]
Defines the payload attribute
(s)
of the sample.
If None, will use the ``bob_output`` tag in the estimator, or default to
``data``.
...
...
@@ -497,11 +544,22 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
if
not
bob_tags
[
"
bob_checkpoint_features
"
]:
logger
.
info
(
"
Checkpointing is disabled for %s be
a
cuse the bob_checkpoint_features tag is False.
"
,
"
Checkpointing is disabled for %s bec
a
use the bob_checkpoint_features tag is False.
"
,
estimator
,
)
features_dir
=
None
if
(
not
isinstance
(
self
.
sample_attribute
,
str
)
and
features_dir
is
not
None
):
raise
(
NotImplementedError
(
"
CheckpointWrapper only supports single output attributes.
"
f
"
Please set the bob_checkpoint_features tag to False for
{
estimator
}
.
"
)
)
self
.
force
=
force
self
.
estimator
=
estimator
self
.
model_path
=
model_path
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment