Lightning acc

Merged Gokhan OZBULAK requested to merge lightning-acc into main
@@ -79,19 +79,18 @@ def reusable_options(f):
         cls=ResourceOption,
     )
     @click.option(
-        "--batch-chunk-count",
-        "-c",
-        help="Number of chunks in every batch (this parameter affects "
-        "memory requirements for the network). The number of samples "
-        "loaded for every iteration will be batch-size/batch-chunk-count. "
-        "batch-size needs to be divisible by batch-chunk-count, otherwise an "
-        "error will be raised. This parameter is used to reduce the number of "
-        "samples loaded in each iteration, in order to reduce the memory usage "
-        "in exchange for processing time (more iterations). This is especially "
-        "interesting when one is training on GPUs with limited RAM. The "
-        "default of 1 forces the whole batch to be processed at once. Otherwise "
-        "the batch is broken into batch-chunk-count pieces, and gradients are "
-        "accumulated to complete each batch.",
+        "--accumulate-grad-batches",
+        "-a",
+        help="Number of batches over which to accumulate gradients before "
+        "stepping the optimizer. This parameter, used in conjunction "
+        "with the batch-size, may be used to reduce the number of "
+        "samples loaded in each iteration, to affect memory usage in "
+        "exchange for processing time (more iterations). This is "
+        "useful when one is training on GPUs with a limited amount "
+        "of onboard RAM. The default of 1 forces the whole batch to "
+        "be processed at once. Otherwise, the batch is spread over "
+        "accumulate-grad-batches iterations, and gradients are "
+        "accumulated to complete each training step.",
         required=True,
         show_default=True,
         default=1,
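For reference, setting accumulate-grad-batches to k makes the optimizer step only once every k batches, so the effective batch size becomes batch-size × k. A minimal sketch of the equivalent manual loop in plain PyTorch (the model, data, and hyper-parameters below are illustrative, not taken from this repository):

```python
# Illustrative sketch of gradient accumulation; none of these names
# come from the repository under review.
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 10), torch.randint(0, 2, (64,))
)
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

accumulate_grad_batches = 4  # effective batch size: 8 * 4 = 32

for step, (x, y) in enumerate(loader):
    loss = torch.nn.functional.cross_entropy(model(x), y)
    # scale so the accumulated gradient matches one large-batch step
    (loss / accumulate_grad_batches).backward()
    if (step + 1) % accumulate_grad_batches == 0:
        optimizer.step()
        optimizer.zero_grad()
```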
@@ -236,7 +235,7 @@ def train(
     output_folder,
     epochs,
     batch_size,
-    batch_chunk_count,
+    accumulate_grad_batches,
     drop_incomplete_batch,
     datamodule,
     validation_period,
@@ -283,7 +282,6 @@ def train(
seed_everything(seed)
# reset datamodule with user configurable options
datamodule.set_chunk_size(batch_size, batch_chunk_count)
datamodule.drop_incomplete_batch = drop_incomplete_batch
datamodule.cache_samples = cache_samples
datamodule.parallel = parallel
@@ -342,7 +340,7 @@ def train(
         split_name=datamodule.split_name,
         epochs=epochs,
         batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         drop_incomplete_batch=drop_incomplete_batch,
         validation_period=validation_period,
         cache_samples=cache_samples,
@@ -365,6 +363,6 @@ def train(
         max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         checkpoint=checkpoint_file,
     )
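On the Lightning side, the value forwarded here presumably ends up on the Trainer, which supports accumulate_grad_batches natively; that is what makes the manual batch chunking removed above unnecessary. A minimal, self-contained sketch of that hand-off, assuming a stock lightning.pytorch Trainer (the module and data below are placeholders, not code from this MR):

```python
# Placeholder module and data; only the accumulate_grad_batches wiring
# reflects what this MR forwards to Lightning.
import torch
from lightning.pytorch import LightningModule, Trainer


class TinyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 10), torch.randint(0, 2, (64,))
)
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

trainer = Trainer(
    max_epochs=1,
    accumulate_grad_batches=4,  # optimizer steps once every 4 batches
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(TinyModule(), loader)
```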