Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bob.learn.em
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
bob
bob.learn.em
Commits
6e804b8a
Commit
6e804b8a
authored
5 years ago
by
Amir MOHAMMADI
Browse files
Options
Downloads
Patches
Plain Diff
Improvements to the em train script
parent
c898cbf3
No related branches found
No related tags found
1 merge request
!36
WIP: Add a bob em train script which works on SGE
Pipeline
#32687
passed
5 years ago
Stage: build
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bob/learn/em/script/train.py
+39
-11
39 additions, 11 deletions
bob/learn/em/script/train.py
with
39 additions
and
11 deletions
bob/learn/em/script/train.py
+
39
−
11
View file @
6e804b8a
...
@@ -32,7 +32,7 @@ SLEEP = 5
...
@@ -32,7 +32,7 @@ SLEEP = 5
cls
=
ConfigCommand
,
cls
=
ConfigCommand
,
epilog
=
"""
\b
epilog
=
"""
\b
Examples:
Examples:
$ bob em train -vvv config.py -o /tmp/gmm -- --array 64 -q q1d -m 4G ...
$ bob em train -vvv config.py -o /tmp/gmm -- --array 64
--jman-options
'
-q q1d
-i
-m 4G
'
...
Note: samples must be sorted!
Note: samples must be sorted!
"""
,
"""
,
...
@@ -49,8 +49,9 @@ Note: samples must be sorted!
...
@@ -49,8 +49,9 @@ Note: samples must be sorted!
required
=
True
,
required
=
True
,
cls
=
ResourceOption
,
cls
=
ResourceOption
,
help
=
"
A list of samples to be loaded with reader. The samples must be stable.
"
help
=
"
A list of samples to be loaded with reader. The samples must be stable.
"
"
The script will be called several times in separate
"
"
The script will be called several times in separate processes. Each time the
"
"
processes. Each time the samples should be the same! It
'
s best to sort them!
"
,
"
config file is loaded the samples should have the same order and must be exactly
"
"
the same! It
'
s best to sort them!
"
,
)
)
@click.option
(
@click.option
(
"
--output-dir
"
,
"
--output-dir
"
,
...
@@ -79,23 +80,40 @@ Note: samples must be sorted!
...
@@ -79,23 +80,40 @@ Note: samples must be sorted!
type
=
click
.
INT
,
type
=
click
.
INT
,
default
=
50
,
default
=
50
,
cls
=
ResourceOption
,
cls
=
ResourceOption
,
show_default
=
True
,
help
=
"
The maximum number of iterations to train a machine.
"
,
help
=
"
The maximum number of iterations to train a machine.
"
,
)
)
@click.option
(
@click.option
(
"
--convergence-threshold
"
,
"
--convergence-threshold
"
,
type
=
click
.
FLOAT
,
type
=
click
.
FLOAT
,
default
=
4e-5
,
show_default
=
True
,
cls
=
ResourceOption
,
cls
=
ResourceOption
,
help
=
"
The convergence threshold to train a machine. If None, the training
"
help
=
"
The convergence threshold to train a machine. If None, the training
"
"
procedure will stop with the iterations criteria.
"
,
"
procedure will stop with the iterations criteria.
"
,
)
)
@click.option
(
"
--initialization-stride
"
,
type
=
click
.
INT
,
default
=
1
,
show_default
=
True
,
cls
=
ResourceOption
,
help
=
"
The stride to use for selecting a subset of samples to initialize the
"
"
machine. Must be 1 or greater.
"
,
)
@click.option
(
@click.option
(
"
--jman-options
"
,
"
--jman-options
"
,
default
=
"
"
,
default
=
"
"
,
show_default
=
True
,
cls
=
ResourceOption
,
cls
=
ResourceOption
,
help
=
"
Additional options to be given to jman
"
,
help
=
"
Additional options to be given to jman
"
,
)
)
@click.option
(
@click.option
(
"
--jman
"
,
default
=
"
jman
"
,
cls
=
ResourceOption
,
help
=
"
Path to the jman script.
"
"
--jman
"
,
default
=
"
jman
"
,
show_default
=
True
,
cls
=
ResourceOption
,
help
=
"
Path to the jman script.
"
,
)
)
@click.option
(
@click.option
(
"
--step
"
,
"
--step
"
,
...
@@ -114,6 +132,7 @@ def train(
...
@@ -114,6 +132,7 @@ def train(
machine
,
machine
,
max_iterations
,
max_iterations
,
convergence_threshold
,
convergence_threshold
,
initialization_stride
,
jman_options
,
jman_options
,
jman
,
jman
,
step
,
step
,
...
@@ -145,15 +164,20 @@ def train(
...
@@ -145,15 +164,20 @@ def train(
)
)
raise
click
.
Abort
raise
click
.
Abort
# sanity check
assert
len
(
samples
)
//
array
>
machine
.
shape
[
0
],
"
Please reduce array number!
"
n_samples
=
len
(
samples
)
n_samples
=
len
(
samples
)
n_jobs
=
array
# some array jobs may not get any samples
# for example if n_samples is 241 and array is 64,
# each worker gets 4 samples and that means only 61 workers would get samples to
# work with
n_jobs
=
int
(
np
.
ceil
(
n_samples
/
np
.
ceil
(
n_samples
/
array
)))
# initialize
# initialize
if
trainer_type
in
(
"
KMeansTrainer
"
,
"
ML_GMMTrainer
"
):
if
trainer_type
in
(
"
KMeansTrainer
"
,
"
ML_GMMTrainer
"
):
logger
.
info
(
"
Loading %d samples to initialize the machine
"
,
len
(
samples
))
initilization_samples
=
samples
[::
initialization_stride
]
data
=
read_samples
(
reader
,
samples
)
logger
.
info
(
"
Loading %d samples to initialize the machine
"
,
len
(
initilization_samples
)
)
data
=
read_samples
(
reader
,
initilization_samples
)
logger
.
info
(
"
Initializing the trainer (and maybe machine)
"
)
logger
.
info
(
"
Initializing the trainer (and maybe machine)
"
)
trainer
.
initialize
(
machine
,
data
)
trainer
.
initialize
(
machine
,
data
)
...
@@ -349,7 +373,8 @@ def load_statistics(trainer, machine, path):
...
@@ -349,7 +373,8 @@ def load_statistics(trainer, machine, path):
with
HDF5File
(
path
,
"
r
"
)
as
f
:
with
HDF5File
(
path
,
"
r
"
)
as
f
:
if
trainer_type
==
"
KMeansTrainer
"
:
if
trainer_type
==
"
KMeansTrainer
"
:
trainer
.
zeroeth_order_statistics
=
f
[
"
zeroeth_order_statistics
"
]
zeros
=
f
[
"
zeroeth_order_statistics
"
]
trainer
.
zeroeth_order_statistics
=
np
.
array
(
zeros
).
reshape
((
-
1
,))
trainer
.
first_order_statistics
=
f
[
"
first_order_statistics
"
]
trainer
.
first_order_statistics
=
f
[
"
first_order_statistics
"
]
trainer
.
average_min_distance
=
f
[
"
average_min_distance
"
]
trainer
.
average_min_distance
=
f
[
"
average_min_distance
"
]
...
@@ -366,6 +391,9 @@ def load_statistics(trainer, machine, path):
...
@@ -366,6 +391,9 @@ def load_statistics(trainer, machine, path):
def
e_step
(
samples
,
reader
,
output_dir
,
trainer
,
machine
):
def
e_step
(
samples
,
reader
,
output_dir
,
trainer
,
machine
):
if
len
(
samples
)
==
0
:
print
(
"
This worker did not get any samples.
"
)
return
logger
.
info
(
"
Loading %d samples
"
,
len
(
samples
))
logger
.
info
(
"
Loading %d samples
"
,
len
(
samples
))
data
=
read_samples
(
reader
,
samples
)
data
=
read_samples
(
reader
,
samples
)
logger
.
info
(
"
Loaded all samples
"
)
logger
.
info
(
"
Loaded all samples
"
)
...
@@ -394,7 +422,7 @@ def read_samples(reader, samples):
...
@@ -394,7 +422,7 @@ def read_samples(reader, samples):
# read one sample to see if data is numpy arrays
# read one sample to see if data is numpy arrays
data
=
reader
(
samples
[
0
])
data
=
reader
(
samples
[
0
])
if
isinstance
(
data
,
np
.
ndarray
):
if
isinstance
(
data
,
np
.
ndarray
):
samples
=
vstack_features
(
reader
,
samples
,
same_size
=
Tru
e
)
samples
=
vstack_features
(
reader
,
samples
,
same_size
=
Fals
e
)
else
:
else
:
samples
=
[
reader
(
s
)
for
s
in
samples
]
samples
=
[
reader
(
s
)
for
s
in
samples
]
return
samples
return
samples
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment