Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.boosting
Commits
7e011d68
Commit
7e011d68
authored
Sep 16, 2013
by
Manuel Günther
Browse files
Added possibility to get the extracted feature indices; improved IO; fixed small bug.
parent
1dd770ee
Changes
2
Hide whitespace changes
Inline
Side-by-side
xbob/boosting/core/boosting.py
View file @
7e011d68
...
...
@@ -235,10 +235,12 @@ class BoostMachine():
""" The class to perform the classification using the set of weak trainer """
def
__init__
(
self
):
def
__init__
(
self
,
number_of_outputs
=
1
):
""" Initialize the set of weak trainers and the alpha values (scale)"""
self
.
alpha
=
[]
self
.
weak_trainer
=
[]
self
.
number_of_outputs
=
number_of_outputs
self
.
selected_indices
=
set
()
...
...
@@ -252,8 +254,13 @@ class BoostMachine():
"""
self
.
alpha
.
append
(
curr_alpha
)
self
.
weak_trainer
.
append
(
curr_trainer
)
self
.
selected_indices
|=
set
([
curr_trainer
.
selected_indices
[
i
]
for
i
in
range
(
self
.
number_of_outputs
)])
def
feature_indices
(
self
):
"""Returns the indices of the features that are selected by the weak classifiers."""
return
sorted
(
list
(
self
.
selected_indices
))
def
__call__
(
self
,
feature
):
"""Returns the predicted score for the given single feature, assuming only single output.
...
...
@@ -296,9 +303,8 @@ class BoostMachine():
# Initialization
num_trainer
=
len
(
self
.
weak_trainer
)
num_samp
=
test_features
.
shape
[
0
]
num_op
=
test_features
.
shape
[
1
]
pred_labels
=
-
numpy
.
ones
([
num_samp
,
num_op
])
pred_scores
=
numpy
.
zeros
([
num_samp
,
num_op
])
pred_labels
=
-
numpy
.
ones
([
num_samp
,
self
.
number_of_outputs
])
pred_scores
=
numpy
.
zeros
([
num_samp
,
self
.
number_of_outputs
])
# For each round of boosting calculate the weak scores for that round and add to the total
...
...
@@ -308,7 +314,7 @@ class BoostMachine():
pred_scores
=
pred_scores
+
self
.
alpha
[
i
]
*
weak_scores
# predict the labels for test features based on score sign (for binary case) and score value (multivariate case)
if
(
num_op
==
1
):
if
(
self
.
number_of_outputs
==
1
):
pred_labels
[
pred_scores
>=
0
]
=
1
pred_labels
=
numpy
.
squeeze
(
pred_labels
)
else
:
...
...
@@ -321,6 +327,7 @@ class BoostMachine():
# hdf5File.set_attribute("MachineType", self.weak_trainer_type)
hdf5File
.
set_attribute
(
"version"
,
0
)
hdf5File
.
set
(
"Weights"
,
self
.
alpha
)
hdf5File
.
set
(
"Outputs"
,
self
.
number_of_outputs
)
for
i
in
range
(
len
(
self
.
weak_trainer
)):
dir_name
=
"WeakMachine%d"
%
i
hdf5File
.
create_group
(
dir_name
)
...
...
@@ -333,7 +340,9 @@ class BoostMachine():
def
load
(
self
,
hdf5File
):
# self.weak_trainer_type = hdf5File.get_attribute("MachineType")
self
.
alpha
=
hdf5File
.
read
(
"Weights"
)
self
.
number_of_outputs
=
hdf5File
.
read
(
"Outputs"
)
self
.
weak_trainer
=
[]
self
.
selected_indices
=
set
()
for
i
in
range
(
len
(
self
.
alpha
)):
dir_name
=
"WeakMachine%d"
%
i
hdf5File
.
cd
(
dir_name
)
...
...
@@ -344,6 +353,7 @@ class BoostMachine():
}
[
weak_machine_type
]
weak_machine
.
load
(
hdf5File
)
self
.
weak_trainer
.
append
(
weak_machine
)
self
.
selected_indices
|=
set
([
weak_machine
.
selected_indices
[
i
]
for
i
in
range
(
self
.
number_of_outputs
)])
hdf5File
.
cd
(
'..'
)
xbob/boosting/core/trainers.py
View file @
7e011d68
...
...
@@ -73,6 +73,8 @@ class StumpMachine():
self
.
selected_indices
=
hdf5File
.
read
(
"Indices"
)
self
.
threshold
=
hdf5File
.
read
(
"Threshold"
)
self
.
polarity
=
hdf5File
.
read
(
"Polarity"
)
if
isinstance
(
self
.
selected_indices
,
int
):
self
.
selected_indices
=
numpy
.
array
([
self
.
selected_indices
],
dtype
=
numpy
.
int
)
...
...
@@ -199,7 +201,7 @@ class LutMachine():
"""
self
.
luts
=
numpy
.
ones
((
num_entries
,
num_outputs
),
dtype
=
numpy
.
int
)
self
.
selected_indices
=
numpy
.
zeros
(
[
num_outputs
,
1
]
,
'int16'
)
self
.
selected_indices
=
numpy
.
zeros
(
(
num_outputs
,
)
,
'int16'
)
...
...
@@ -245,6 +247,8 @@ class LutMachine():
"""Reads the state of this machine from the given HDF5File."""
self
.
luts
=
hdf5File
.
read
(
"LUT"
)
self
.
selected_indices
=
hdf5File
.
read
(
"Indices"
)
if
isinstance
(
self
.
selected_indices
,
int
):
self
.
selected_indices
=
numpy
.
array
([
self
.
selected_indices
],
dtype
=
numpy
.
int
)
class
LutTrainer
():
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment