Skip to content
Snippets Groups Projects
Commit 2e79aae0 authored by Gokhan OZBULAK
Browse files

Change flag for batch accumulation. #25

parent cf742317
No related branches found
No related tags found
1 merge request: !40 Lightning acc
Pipeline #87608 passed
...@@ -26,7 +26,7 @@ def run( ...@@ -26,7 +26,7 @@ def run(
max_epochs: int, max_epochs: int,
output_folder: pathlib.Path, output_folder: pathlib.Path,
monitoring_interval: int | float, monitoring_interval: int | float,
batch_chunk_count: int, accumulate_grad_batches: int,
checkpoint: pathlib.Path | None, checkpoint: pathlib.Path | None,
): ):
"""Fit a CNN model using supervised learning and save it to disk. """Fit a CNN model using supervised learning and save it to disk.
...@@ -60,12 +60,13 @@ def run( ...@@ -60,12 +60,13 @@ def run(
monitoring_interval monitoring_interval
Interval, in seconds (or fractions), through which we should monitor Interval, in seconds (or fractions), through which we should monitor
resources during training. resources during training.
batch_chunk_count accumulate_grad_batches
If this number is different than 1, then each batch will be divided in Number of accumulations for backward propagation to accumulate gradients
this number of chunks. Gradients will be accumulated to perform each over k batches before stepping the optimizer. The default of 1 forces
mini-batch. This is particularly interesting when one has limited RAM the whole batch to be processed at once. Otherwise the effective batch is
on the GPU, but would like to keep training with larger batches. One accumulate-grad-batches times the batch size, and gradients are accumulated to complete
exchanges for longer processing times in this case. each step. This is especially interesting when one is training on GPUs with
a limited amount of onboard RAM.
checkpoint checkpoint
Path to an optional checkpoint file to load. Path to an optional checkpoint file to load.
""" """
...@@ -118,7 +119,7 @@ def run( ...@@ -118,7 +119,7 @@ def run(
accelerator=accelerator, accelerator=accelerator,
devices=devices, devices=devices,
max_epochs=max_epochs, max_epochs=max_epochs,
accumulate_grad_batches=batch_chunk_count, accumulate_grad_batches=accumulate_grad_batches,
logger=tensorboard_logger, logger=tensorboard_logger,
check_val_every_n_epoch=validation_period, check_val_every_n_epoch=validation_period,
log_every_n_steps=len(datamodule.train_dataloader()), log_every_n_steps=len(datamodule.train_dataloader()),
......
...@@ -40,7 +40,7 @@ def experiment( ...@@ -40,7 +40,7 @@ def experiment(
output_folder, output_folder,
epochs, epochs,
batch_size, batch_size,
batch_chunk_count, accumulate_grad_batches,
drop_incomplete_batch, drop_incomplete_batch,
datamodule, datamodule,
validation_period, validation_period,
...@@ -79,7 +79,7 @@ def experiment( ...@@ -79,7 +79,7 @@ def experiment(
output_folder=train_output_folder, output_folder=train_output_folder,
epochs=epochs, epochs=epochs,
batch_size=batch_size, batch_size=batch_size,
batch_chunk_count=batch_chunk_count, accumulate_grad_batches=accumulate_grad_batches,
drop_incomplete_batch=drop_incomplete_batch, drop_incomplete_batch=drop_incomplete_batch,
datamodule=datamodule, datamodule=datamodule,
validation_period=validation_period, validation_period=validation_period,
......
...@@ -79,18 +79,19 @@ def reusable_options(f): ...@@ -79,18 +79,19 @@ def reusable_options(f):
cls=ResourceOption, cls=ResourceOption,
) )
@click.option( @click.option(
"--batch-chunk-count", "--accumulate-grad-batches",
"-c", "-a",
help="Number of chunks in every batch (this parameter affects " help="Number of accumulations for backward propagation to accumulate "
"memory requirements for the network). The number of samples " "gradients over k batches before stepping the optimizer. This "
"loaded for every iteration will be batch-size*batch-chunk-count. " "parameter, used in conjunction with the batch-size, may be used to "
"This parameter is used to reduce the number of samples loaded in each " "reduce the number of samples loaded in each iteration, to affect memory "
"iteration, in order to reduce the memory usage in exchange for " "usage in exchange for processing time (more iterations). This is "
"processing time (more iterations). This is especially interesting " "especially interesting "
"when one is training on GPUs with limited RAM. The default of 1 forces " "when one is training on GPUs with limited RAM. The default of 1 forces "
"the whole batch to be processed at once. Otherwise the batch is " "the whole batch to be processed at once. Otherwise the batch is "
"multiplied by batch-chunk-count pieces, and gradients are accumulated " "multiplied by accumulate-grad-batches pieces, and gradients are accumulated "
"to complete each batch.", "to complete each step.",
required=True, required=True,
show_default=True, show_default=True,
default=1, default=1,
...@@ -235,7 +236,7 @@ def train( ...@@ -235,7 +236,7 @@ def train(
output_folder, output_folder,
epochs, epochs,
batch_size, batch_size,
batch_chunk_count, accumulate_grad_batches,
drop_incomplete_batch, drop_incomplete_batch,
datamodule, datamodule,
validation_period, validation_period,
...@@ -340,7 +341,7 @@ def train( ...@@ -340,7 +341,7 @@ def train(
split_name=datamodule.split_name, split_name=datamodule.split_name,
epochs=epochs, epochs=epochs,
batch_size=batch_size, batch_size=batch_size,
batch_chunk_count=batch_chunk_count, accumulate_grad_batches=accumulate_grad_batches,
drop_incomplete_batch=drop_incomplete_batch, drop_incomplete_batch=drop_incomplete_batch,
validation_period=validation_period, validation_period=validation_period,
cache_samples=cache_samples, cache_samples=cache_samples,
...@@ -363,6 +364,6 @@ def train( ...@@ -363,6 +364,6 @@ def train(
max_epochs=epochs, max_epochs=epochs,
output_folder=output_folder, output_folder=output_folder,
monitoring_interval=monitoring_interval, monitoring_interval=monitoring_interval,
batch_chunk_count=batch_chunk_count, accumulate_grad_batches=accumulate_grad_batches,
checkpoint=checkpoint_file, checkpoint=checkpoint_file,
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment