Skip to content
Snippets Groups Projects
Commit 6e804b8a authored by Amir MOHAMMADI
Browse files

Improvements to the em train script

parent c898cbf3
No related branches found
No related tags found
1 merge request: !36 "WIP: Add a bob em train script which works on SGE"
Pipeline #32687 passed
...@@ -32,7 +32,7 @@ SLEEP = 5 ...@@ -32,7 +32,7 @@ SLEEP = 5
cls=ConfigCommand, cls=ConfigCommand,
epilog="""\b epilog="""\b
Examples: Examples:
$ bob em train -vvv config.py -o /tmp/gmm -- --array 64 -q q1d -m 4G ... $ bob em train -vvv config.py -o /tmp/gmm -- --array 64 --jman-options '-q q1d -i -m 4G' ...
Note: samples must be sorted! Note: samples must be sorted!
""", """,
...@@ -49,8 +49,9 @@ Note: samples must be sorted! ...@@ -49,8 +49,9 @@ Note: samples must be sorted!
required=True, required=True,
cls=ResourceOption, cls=ResourceOption,
help="A list of samples to be loaded with reader. The samples must be stable. " help="A list of samples to be loaded with reader. The samples must be stable. "
"The script will be called several times in separate " "The script will be called several times in separate processes. Each time the "
"processes. Each time the samples should be the same! It's best to sort them!", "config file is loaded the samples should have the same order and must be exactly "
"the same! It's best to sort them!",
) )
@click.option( @click.option(
"--output-dir", "--output-dir",
...@@ -79,23 +80,40 @@ Note: samples must be sorted! ...@@ -79,23 +80,40 @@ Note: samples must be sorted!
type=click.INT, type=click.INT,
default=50, default=50,
cls=ResourceOption, cls=ResourceOption,
show_default=True,
help="The maximum number of iterations to train a machine.", help="The maximum number of iterations to train a machine.",
) )
@click.option( @click.option(
"--convergence-threshold", "--convergence-threshold",
type=click.FLOAT, type=click.FLOAT,
default=4e-5,
show_default=True,
cls=ResourceOption, cls=ResourceOption,
help="The convergence threshold to train a machine. If None, the training " help="The convergence threshold to train a machine. If None, the training "
"procedure will stop with the iterations criteria.", "procedure will stop with the iterations criteria.",
) )
@click.option(
"--initialization-stride",
type=click.INT,
default=1,
show_default=True,
cls=ResourceOption,
help="The stride to use for selecting a subset of samples to initialize the "
"machine. Must be 1 or greater.",
)
@click.option( @click.option(
"--jman-options", "--jman-options",
default=" ", default=" ",
show_default=True,
cls=ResourceOption, cls=ResourceOption,
help="Additional options to be given to jman", help="Additional options to be given to jman",
) )
@click.option( @click.option(
"--jman", default="jman", cls=ResourceOption, help="Path to the jman script." "--jman",
default="jman",
show_default=True,
cls=ResourceOption,
help="Path to the jman script.",
) )
@click.option( @click.option(
"--step", "--step",
...@@ -114,6 +132,7 @@ def train( ...@@ -114,6 +132,7 @@ def train(
machine, machine,
max_iterations, max_iterations,
convergence_threshold, convergence_threshold,
initialization_stride,
jman_options, jman_options,
jman, jman,
step, step,
...@@ -145,15 +164,20 @@ def train( ...@@ -145,15 +164,20 @@ def train(
) )
raise click.Abort raise click.Abort
# sanity check
assert len(samples) // array > machine.shape[0], "Please reduce array number!"
n_samples = len(samples) n_samples = len(samples)
n_jobs = array # some array jobs may not get any samples
# for example if n_samples is 241 and array is 64,
# each worker gets 4 samples and that means only 61 workers would get samples to
# work with
n_jobs = int(np.ceil(n_samples / np.ceil(n_samples / array)))
# initialize # initialize
if trainer_type in ("KMeansTrainer", "ML_GMMTrainer"): if trainer_type in ("KMeansTrainer", "ML_GMMTrainer"):
logger.info("Loading %d samples to initialize the machine", len(samples)) initilization_samples = samples[::initialization_stride]
data = read_samples(reader, samples) logger.info(
"Loading %d samples to initialize the machine", len(initilization_samples)
)
data = read_samples(reader, initilization_samples)
logger.info("Initializing the trainer (and maybe machine)") logger.info("Initializing the trainer (and maybe machine)")
trainer.initialize(machine, data) trainer.initialize(machine, data)
...@@ -349,7 +373,8 @@ def load_statistics(trainer, machine, path): ...@@ -349,7 +373,8 @@ def load_statistics(trainer, machine, path):
with HDF5File(path, "r") as f: with HDF5File(path, "r") as f:
if trainer_type == "KMeansTrainer": if trainer_type == "KMeansTrainer":
trainer.zeroeth_order_statistics = f["zeroeth_order_statistics"] zeros = f["zeroeth_order_statistics"]
trainer.zeroeth_order_statistics = np.array(zeros).reshape((-1,))
trainer.first_order_statistics = f["first_order_statistics"] trainer.first_order_statistics = f["first_order_statistics"]
trainer.average_min_distance = f["average_min_distance"] trainer.average_min_distance = f["average_min_distance"]
...@@ -366,6 +391,9 @@ def load_statistics(trainer, machine, path): ...@@ -366,6 +391,9 @@ def load_statistics(trainer, machine, path):
def e_step(samples, reader, output_dir, trainer, machine): def e_step(samples, reader, output_dir, trainer, machine):
if len(samples) == 0:
print("This worker did not get any samples.")
return
logger.info("Loading %d samples", len(samples)) logger.info("Loading %d samples", len(samples))
data = read_samples(reader, samples) data = read_samples(reader, samples)
logger.info("Loaded all samples") logger.info("Loaded all samples")
...@@ -394,7 +422,7 @@ def read_samples(reader, samples): ...@@ -394,7 +422,7 @@ def read_samples(reader, samples):
# read one sample to see if data is numpy arrays # read one sample to see if data is numpy arrays
data = reader(samples[0]) data = reader(samples[0])
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):
samples = vstack_features(reader, samples, same_size=True) samples = vstack_features(reader, samples, same_size=False)
else: else:
samples = [reader(s) for s in samples] samples = [reader(s) for s in samples]
return samples return samples
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment