Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bob.bio.base
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
bob
bob.bio.base
Commits
6fd872d2
There was a problem fetching the pipeline summary.
Commit
6fd872d2
authored
7 years ago
by
Amir MOHAMMADI
Browse files
Options
Downloads
Patches
Plain Diff
Make the code more WET
parent
31bdc2e9
No related branches found
No related tags found
No related merge requests found
Pipeline
#
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bob/bio/base/extractor/stacks.py
+41
-28
41 additions, 28 deletions
bob/bio/base/extractor/stacks.py
with
41 additions
and
28 deletions
bob/bio/base/extractor/stacks.py
+
41
−
28
View file @
6fd872d2
...
@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
...
@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
"""
Base class for SequentialExtractor and ParallelExtractor. This class is
"""
Base class for SequentialExtractor and ParallelExtractor. This class is
not meant to be used directly.
"""
not meant to be used directly.
"""
def
get_attributes
(
self
,
processors
):
@staticmethod
def
get_attributes
(
processors
):
requires_training
=
any
(
p
.
requires_training
for
p
in
processors
)
requires_training
=
any
(
p
.
requires_training
for
p
in
processors
)
split_training_data_by_client
=
any
(
p
.
split_training_data_by_client
for
split_training_data_by_client
=
any
(
p
.
split_training_data_by_client
for
p
in
processors
)
p
in
processors
)
min_extractor_file_size
=
min
(
p
.
min_extractor_file_size
for
p
in
min_extractor_file_size
=
min
(
p
.
min_extractor_file_size
for
p
in
processors
)
processors
)
min_feature_file_size
=
min
(
min_feature_file_size
=
min
(
p
.
min_feature_file_size
for
p
in
p
.
min_feature_file_size
for
p
in
processors
)
processors
)
return
(
requires_training
,
split_training_data_by_client
,
return
(
requires_training
,
split_training_data_by_client
,
min_extractor_file_size
,
min_feature_file_size
)
min_extractor_file_size
,
min_feature_file_size
)
...
@@ -23,38 +24,54 @@ class MultipleExtractor(Extractor):
...
@@ -23,38 +24,54 @@ class MultipleExtractor(Extractor):
return
groups
return
groups
def
train_one
(
self
,
e
,
training_data
,
extractor_file
,
apply
=
False
):
def
train_one
(
self
,
e
,
training_data
,
extractor_file
,
apply
=
False
):
"""
Trains one extractor and optionally applies the extractor on the
training data after training.
Parameters
----------
e : :any:`Extractor`
The extractor to train. The extractor should be able to save itself
in an opened hdf5 file.
training_data : [object] or [[object]]
The data to be used for training.
extractor_file : :any:`bob.io.base.HDF5File`
The opened hdf5 file to save the trained extractor inside.
apply : :obj:`bool`, optional
If ``True``, the extractor is applied to the training data after it
is trained and the data is returned.
Returns
-------
None or [object] or [[object]]
Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the
transformed ``training_data``.
"""
if
not
e
.
requires_training
:
if
not
e
.
requires_training
:
if
not
apply
:
# do nothing since e does not require training!
return
pass
if
self
.
split_training_data_by_client
:
training_data
=
[[
e
(
d
)
for
d
in
datalist
]
for
datalist
in
training_data
]
else
:
training_data
=
[
e
(
d
)
for
d
in
training_data
]
# if any of the extractors require splitting the data, the
# if any of the extractors require splitting the data, the
# split_training_data_by_client is True.
# split_training_data_by_client is True.
elif
e
.
split_training_data_by_client
:
elif
e
.
split_training_data_by_client
:
e
.
train
(
training_data
,
extractor_file
)
e
.
train
(
training_data
,
extractor_file
)
if
not
apply
:
return
training_data
=
[[
e
(
d
)
for
d
in
datalist
]
for
datalist
in
training_data
]
# when no extractor needs splitting
# when no extractor needs splitting
elif
not
self
.
split_training_data_by_client
:
elif
not
self
.
split_training_data_by_client
:
e
.
train
(
training_data
,
extractor_file
)
e
.
train
(
training_data
,
extractor_file
)
if
not
apply
:
return
training_data
=
[
e
(
d
)
for
d
in
training_data
]
# when e here wants it flat but the data is split
# when e here wants it flat but the data is split
else
:
else
:
# make training_data flat
# make training_data flat
aligned_training_data
=
[
d
for
datalist
in
training_data
for
d
in
flat_training_data
=
[
d
for
datalist
in
training_data
for
d
in
datalist
]
datalist
]
e
.
train
(
aligned_training_data
,
extractor_file
)
e
.
train
(
flat_training_data
,
extractor_file
)
if
not
apply
:
return
if
not
apply
:
return
# prepare the training data for the next extractor
if
self
.
split_training_data_by_client
:
training_data
=
[[
e
(
d
)
for
d
in
datalist
]
training_data
=
[[
e
(
d
)
for
d
in
datalist
]
for
datalist
in
training_data
]
for
datalist
in
training_data
]
else
:
training_data
=
[
e
(
d
)
for
d
in
training_data
]
return
training_data
return
training_data
def
load
(
self
,
extractor_file
):
def
load
(
self
,
extractor_file
):
...
@@ -62,8 +79,7 @@ class MultipleExtractor(Extractor):
...
@@ -62,8 +79,7 @@ class MultipleExtractor(Extractor):
groups
=
self
.
get_extractor_groups
()
groups
=
self
.
get_extractor_groups
()
for
e
,
group
in
zip
(
self
.
processors
,
groups
):
for
e
,
group
in
zip
(
self
.
processors
,
groups
):
f
.
cd
(
group
)
f
.
cd
(
group
)
if
e
.
requires_training
:
e
.
load
(
f
)
e
.
load
(
f
)
f
.
cd
(
'
..
'
)
f
.
cd
(
'
..
'
)
...
@@ -112,10 +128,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
...
@@ -112,10 +128,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
with
HDF5File
(
extractor_file
,
'
w
'
)
as
f
:
with
HDF5File
(
extractor_file
,
'
w
'
)
as
f
:
groups
=
self
.
get_extractor_groups
()
groups
=
self
.
get_extractor_groups
()
for
i
,
(
e
,
group
)
in
enumerate
(
zip
(
self
.
processors
,
groups
)):
for
i
,
(
e
,
group
)
in
enumerate
(
zip
(
self
.
processors
,
groups
)):
if
i
==
len
(
self
.
processors
)
-
1
:
apply
=
i
!=
len
(
self
.
processors
)
-
1
apply
=
False
else
:
apply
=
True
f
.
create_group
(
group
)
f
.
create_group
(
group
)
f
.
cd
(
group
)
f
.
cd
(
group
)
training_data
=
self
.
train_one
(
e
,
training_data
,
f
,
training_data
=
self
.
train_one
(
e
,
training_data
,
f
,
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment