Skip to content
Snippets Groups Projects
Commit 2e79aae0 authored by Gokhan OZBULAK's avatar Gokhan OZBULAK
Browse files

Change flag for batch accumulation. #25

parent cf742317
No related branches found
No related tags found
1 merge request: !40 Lightning acc
Pipeline #87608 passed
......@@ -26,7 +26,7 @@ def run(
max_epochs: int,
output_folder: pathlib.Path,
monitoring_interval: int | float,
batch_chunk_count: int,
accumulate_grad_batches: int,
checkpoint: pathlib.Path | None,
):
"""Fit a CNN model using supervised learning and save it to disk.
......@@ -60,12 +60,13 @@ def run(
monitoring_interval
Interval, in seconds (or fractions), through which we should monitor
resources during training.
batch_chunk_count
If this number is different than 1, then each batch will be divided in
this number of chunks. Gradients will be accumulated to perform each
mini-batch. This is particularly interesting when one has limited RAM
on the GPU, but would like to keep training with larger batches. One
exchanges for longer processing times in this case.
accumulate_grad_batches
Number of batches over which gradients are accumulated before the
optimizer takes a step. The default of 1 steps the optimizer after every
batch (no accumulation). With a larger value, gradients from that many
consecutive batches are accumulated to complete each step, so the
effective batch size is the batch size multiplied by this number. This is
especially interesting when one is training on GPUs with a limited amount
of onboard RAM.
checkpoint
Path to an optional checkpoint file to load.
"""
......@@ -118,7 +119,7 @@ def run(
accelerator=accelerator,
devices=devices,
max_epochs=max_epochs,
accumulate_grad_batches=batch_chunk_count,
accumulate_grad_batches=accumulate_grad_batches,
logger=tensorboard_logger,
check_val_every_n_epoch=validation_period,
log_every_n_steps=len(datamodule.train_dataloader()),
......
......@@ -40,7 +40,7 @@ def experiment(
output_folder,
epochs,
batch_size,
batch_chunk_count,
accumulate_grad_batches,
drop_incomplete_batch,
datamodule,
validation_period,
......@@ -79,7 +79,7 @@ def experiment(
output_folder=train_output_folder,
epochs=epochs,
batch_size=batch_size,
batch_chunk_count=batch_chunk_count,
accumulate_grad_batches=accumulate_grad_batches,
drop_incomplete_batch=drop_incomplete_batch,
datamodule=datamodule,
validation_period=validation_period,
......
......@@ -79,18 +79,19 @@ def reusable_options(f):
cls=ResourceOption,
)
@click.option(
"--batch-chunk-count",
"-c",
help="Number of chunks in every batch (this parameter affects "
"memory requirements for the network). The number of samples "
"loaded for every iteration will be batch-size*batch-chunk-count. "
"This parameter is used to reduce the number of samples loaded in each "
"iteration, in order to reduce the memory usage in exchange for "
"processing time (more iterations). This is especially interesting "
"--accumulate-grad-batches",
"-a",
help="Number of accumulations for backward propagation to accumulate "
"gradients over k batches before stepping the optimizer. This "
"parameter, used in conjunction with the batch-size, may be used to "
"reduce the number of samples loaded in each iteration, to affect memory "
"usage in exchange for processing time (more iterations). This is "
"especially interesting when one is training on GPUs with a limited amount "
"of onboard RAM. processing time (more iterations). This is especially interesting "
"when one is training on GPUs with limited RAM. The default of 1 forces "
"the whole batch to be processed at once. Otherwise the batch is "
"multiplied by batch-chunk-count pieces, and gradients are accumulated "
"to complete each batch.",
"multiplied by accumulate-grad-batches pieces, and gradients are accumulated "
"to complete each step.",
required=True,
show_default=True,
default=1,
......@@ -235,7 +236,7 @@ def train(
output_folder,
epochs,
batch_size,
batch_chunk_count,
accumulate_grad_batches,
drop_incomplete_batch,
datamodule,
validation_period,
......@@ -340,7 +341,7 @@ def train(
split_name=datamodule.split_name,
epochs=epochs,
batch_size=batch_size,
batch_chunk_count=batch_chunk_count,
accumulate_grad_batches=accumulate_grad_batches,
drop_incomplete_batch=drop_incomplete_batch,
validation_period=validation_period,
cache_samples=cache_samples,
......@@ -363,6 +364,6 @@ def train(
max_epochs=epochs,
output_folder=output_folder,
monitoring_interval=monitoring_interval,
batch_chunk_count=batch_chunk_count,
accumulate_grad_batches=accumulate_grad_batches,
checkpoint=checkpoint_file,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment