Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bob.pipelines
Manage
Activity
Members
Labels
Plan
Issues
2
Issue boards
Milestones
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
bob
bob.pipelines
Commits
84983bac
Commit
84983bac
authored
4 years ago
by
Yannick DAYER
Browse files
Options
Downloads
Patches
Plain Diff
[py] Adds annotations-related wrappers
parent
57bd3218
No related branches found
No related tags found
1 merge request
!42
Adding annotations-related wrappers
Pipeline
#44957
failed with stage
in 1 hour, 34 minutes, and 25 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
bob/pipelines/__init__.py
+2
-0
2 additions, 0 deletions
bob/pipelines/__init__.py
bob/pipelines/wrappers.py
+170
-0
170 additions, 0 deletions
bob/pipelines/wrappers.py
with
172 additions
and
0 deletions
bob/pipelines/__init__.py
+
2
−
0
View file @
84983bac
...
...
@@ -10,9 +10,11 @@ from .sample import hdf5_to_sample # noqa
from
.sample
import
sample_to_hdf5
# noqa
from
.wrappers
import
BaseWrapper
from
.wrappers
import
CheckpointWrapper
from
.wrappers
import
CheckpointAnnotationsWrapper
from
.wrappers
import
DaskWrapper
from
.wrappers
import
DelayedSamplesCall
from
.wrappers
import
SampleWrapper
from
.wrappers
import
AnnotatedSampleWrapper
from
.wrappers
import
ToDaskBag
from
.wrappers
import
dask_tags
# noqa
from
.wrappers
import
wrap
# noqa
...
...
This diff is collapsed.
Click to expand it.
bob/pipelines/wrappers.py
+
170
−
0
View file @
84983bac
"""
Scikit-learn Estimator Wrappers.
"""
import
logging
import
json
import
os
from
functools
import
partial
...
...
@@ -44,6 +45,15 @@ def copy_learned_attributes(from_estimator, to_estimator):
setattr
(
to_estimator
,
k
,
v
)
def json_dump(data, path):
    """Write ``data`` to the file at ``path`` as a JSON document.

    The object is serialized in memory *before* the destination is opened.
    Opening a file in ``"w"`` mode truncates it immediately, so serializing
    first guarantees that a value which cannot be encoded raises without
    creating an empty file or clobbering an existing one.

    Parameters
    ----------
    data
        Any object accepted by :py:func:`json.dumps`.
    path
        Destination file. It is created, or overwritten if it exists.

    Raises
    ------
    TypeError
        If ``data`` contains a value that is not JSON serializable.
    ValueError
        If ``data`` contains a circular reference.
    """
    serialized = json.dumps(data)
    # The default encoder emits ASCII only; the explicit encoding just keeps
    # the file independent of the platform's locale settings.
    with open(path, "w", encoding="utf-8") as f:
        f.write(serialized)
def json_load(path):
    """Read the JSON document stored at ``path`` and return the decoded object.

    The file is decoded as UTF-8 explicitly instead of relying on the
    platform's locale encoding, so files containing non-ASCII text load the
    same way on every system.

    Parameters
    ----------
    path
        Path of the JSON file to read.

    Returns
    -------
    object
        Whatever :py:func:`json.load` decodes from the file (``dict``,
        ``list``, ``str``, number, ``bool`` or ``None``).

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
class
BaseWrapper
(
MetaEstimatorMixin
,
BaseEstimator
):
"""
The base class for all wrappers.
"""
...
...
@@ -174,6 +184,74 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
return
self
class AnnotatedSampleWrapper(SampleWrapper):
    """Sample wrapper for annotators that fills ``sample.annotations``.

    An :py:class:`~bob.bio.base.Annotator` transformer simply returns its
    results; when wrapped with :py:class:`~bob.pipelines.SampleWrapper` those
    results end up in the :py:attr:`~bob.pipelines.Sample.data` attribute.
    This wrapper places them in the ``annotations`` attribute of the new
    samples instead.

    Use this wrapper only with annotators, and INSTEAD of (not in addition
    to) :py:class:`~bob.pipelines.SampleWrapper`.

    Attributes
    ----------
    fit_extra_arguments : [tuple]
        Extra arguments to pass to the ``fit`` method of the wrapped instance,
        given as a list of two-value tuples. The first value of each tuple is
        the name of an argument that ``fit`` accepts, like ``y``; the second
        is the name of the sample attribute that provides it. For example, to
        pass the ``subject`` attribute of the samples as the ``y`` argument
        of ``fit``, use ``[("y", "subject")]``.

    transform_extra_arguments : [tuple]
        Same as ``fit_extra_arguments`` but for ``transform`` and other
        similar methods.
    """

    def __init__(
        self,
        annotator,
        transform_extra_arguments=None,
        fit_extra_arguments=None,
        **kwargs,
    ):
        # NOTE(review): ``annotator`` is only forwarded as ``estimator`` and
        # is not stored under its own name. If the base classes follow the
        # scikit-learn ``get_params`` convention (one attribute per __init__
        # argument), cloning or printing this object may fail -- confirm.
        super().__init__(
            estimator=annotator,
            fit_extra_arguments=fit_extra_arguments,
            transform_extra_arguments=transform_extra_arguments,
            **kwargs,
        )

    def _samples_transform(self, samples, method_name):
        """Apply ``method_name`` of the annotator to samples or sample sets.

        Overrides ``SampleWrapper._samples_transform`` so that the result is
        passed as the ``annotations`` argument of the new samples
        (:py:attr:`~bob.pipelines.Sample.annotations`) instead of the
        :py:attr:`~bob.pipelines.Sample.data` field.

        Parameters
        ----------
        samples
            A non-empty sequence of samples, or of ``SampleSet`` objects
            (the first element decides which).
        method_name : str
            Name of the method to look up on ``self.estimator``.

        Returns
        -------
        list
            One ``SampleSet`` per input set, or one ``DelayedSample`` per
            input sample.
        """
        bound_method = getattr(self.estimator, method_name)
        logger.debug(f"{_frmt(self)}.{method_name}")
        call_label = f"{self}.{method_name}"

        # Sample sets: recurse into each set and rebuild it around the
        # transformed samples, keeping the original set as parent.
        if isinstance(samples[0], SampleSet):
            return [
                SampleSet(
                    self._samples_transform(sset.samples, method_name),
                    parent=sset,
                )
                for sset in samples
            ]

        # Plain samples: a single DelayedSamplesCall is shared by all of them.
        extra = _make_kwargs_from_samples(
            samples, self.transform_extra_arguments
        )
        delayed_call = DelayedSamplesCall(
            partial(bound_method, **extra), call_label, samples
        )
        # The shared call object is invoked right here for every index, and
        # the input sample's own ``load`` is reused for the data.
        return [
            DelayedSample(
                load=sample.load,
                annotations=delayed_call(index=position),
                parent=sample,
            )
            for position, sample in enumerate(samples)
        ]
class
CheckpointWrapper
(
BaseWrapper
,
TransformerMixin
):
"""
Wraps :any:`Sample`-based estimators so the results are saved in
disk.
"""
...
...
@@ -315,6 +393,96 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
cloudpickle
.
dump
(
self
,
f
)
return
self
class CheckpointAnnotationsWrapper(CheckpointWrapper):
    """Wraps :any:`Sample`-based estimators so the annotations are saved to
    disk.

    Parameters
    ----------
    annotator
        The estimator to wrap; forwarded to the parent as ``estimator``.
    annotations_dir
        Directory of the checkpoint files; forwarded to the parent as
        ``features_dir``. When ``None`` nothing is checkpointed.
    extension : str
        Forwarded to the parent as ``extension``. Defaults to ``".json"``.
    save_func
        Callable ``save_func(annotations, path)``. Defaults to ``json_dump``.
    load_func
        Callable ``load_func(path)``. Defaults to ``json_load``.
    force : bool
        When true, annotations are recomputed even if a checkpoint file
        already exists.
    **kwargs
        Forwarded unchanged to the parent constructor.
    """

    def __init__(
        self,
        annotator,
        annotations_dir=None,
        extension=".json",
        save_func=None,
        load_func=None,
        force=False,
        **kwargs,
    ):
        save_func = save_func or json_dump
        load_func = load_func or json_load
        # NOTE(review): ``annotator`` and ``annotations_dir`` are forwarded
        # under other names and not stored as attributes. If the base classes
        # follow the scikit-learn ``get_params`` convention (one attribute
        # per __init__ argument), cloning may fail -- confirm.
        super().__init__(
            estimator=annotator,
            features_dir=annotations_dir,
            extension=extension,
            save_func=save_func,
            load_func=load_func,
            **kwargs,
        )
        self.force = force

    def _checkpoint_transform(self, samples, method_name):
        """Apply ``method_name`` of the estimator, checkpointing annotations.

        Samples whose checkpoint file is missing (or all of them when
        ``self.force`` is true) are passed to the estimator in one call;
        their annotations are saved and read back. The other samples get
        their annotations loaded from the existing file. The input samples
        are updated in place and returned.

        Parameters
        ----------
        samples
            A non-empty sequence of samples, or of ``SampleSet`` objects
            (the first element decides which).
        method_name : str
            Name of the method to look up on ``self.estimator``.

        Returns
        -------
        The estimator's own output when ``self.features_dir`` is ``None``;
        otherwise the input samples with ``annotations`` set, or a list of
        new ``SampleSet`` objects wrapping them.
        """
        # Transform either samples or samplesets
        method = getattr(self.estimator, method_name)
        logger.debug(f"{_frmt(self)}.{method_name}")

        # if features_dir is None, just transform all samples at once
        if self.features_dir is None:
            return method(samples)

        def _transform_samples(samples):
            paths = [self.make_path(s) for s in samples]
            should_compute_list = [
                p is None or not os.path.isfile(p) or self.force
                for p in paths
            ]
            # call method on non-checkpointed samples
            non_existing_samples = [
                s
                for s, should_compute in zip(samples, should_compute_list)
                if should_compute
            ]
            # non_existing_samples could be empty
            computed_features = []
            if non_existing_samples:
                computed_features = method(non_existing_samples)
            _check_n_input_output(
                non_existing_samples, computed_features, method
            )

            # attach computed and checkpointed annotations to the samples
            com_feat_index = 0
            for s, p, should_compute in zip(
                samples, paths, should_compute_list
            ):
                if should_compute:
                    # ``feat`` is the transformed sample: ``save`` reads its
                    # ``annotations`` attribute.
                    feat = computed_features[com_feat_index]
                    com_feat_index += 1
                    if p is not None:
                        # Save, then read back, so the sample carries exactly
                        # what a later checkpointed run would load.
                        self.save(feat)
                        s.annotations = self.load(s, p)
                    else:
                        # No checkpoint path: keep the computed annotations
                        # rather than the sample object that carries them.
                        s.annotations = feat.annotations
                else:
                    s.annotations = self.load(s, p)
            return samples

        if isinstance(samples[0], SampleSet):
            return [
                SampleSet(_transform_samples(s.samples), parent=s)
                for s in samples
            ]
        else:
            return _transform_samples(samples)

    def save(self, sample):
        """Saves a sample's annotations to disk using self.save_func.

        Overrides CheckpointWrapper.save

        The destination is ``self.make_path(sample)``; its parent directory
        is created when needed. Returns whatever ``self.save_func`` returns.
        """
        path = self.make_path(sample)
        directory = os.path.dirname(path)
        # ``os.makedirs("")`` raises, so skip it for a bare file name.
        if directory:
            os.makedirs(directory, exist_ok=True)
        return self.save_func(sample.annotations, path)

    def load(self, sample, path):
        """Loads a sample's annotations from disk using self.load_func.

        Overrides CheckpointWrapper.load

        ``sample`` is not used here; only ``path`` is passed to
        ``self.load_func``, whose result is returned.
        """
        return self.load_func(path)
class
DaskWrapper
(
BaseWrapper
,
TransformerMixin
):
"""
Wraps Scikit estimators to handle Dask Bags as input.
...
...
@@ -459,6 +627,8 @@ def wrap(bases, estimator=None, **kwargs):
"
sample
"
:
SampleWrapper
,
"
checkpoint
"
:
CheckpointWrapper
,
"
dask
"
:
DaskWrapper
,
"
annotated_sample
"
:
AnnotatedSampleWrapper
,
"
checkpoint_annotations
"
:
CheckpointAnnotationsWrapper
,
}[
w
.
lower
()]
def
_wrap
(
estimator
,
**
kwargs
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment