Commit f3a00b6a authored by André Anjos

Merge branch 'lightning-acc' into 'main'

Lightning acc

Closes #25

See merge request biosignal/software/mednet!40
parents d354388c 96b8c28e
Pipeline #87621 canceled
@@ -492,18 +492,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         the last batch will be smaller than the first, unless
         ``drop_incomplete_batch`` is set to ``true``, in which case this batch
         is not used.
-    batch_chunk_count
-        Number of chunks in every batch (this parameter affects memory
-        requirements for the network). The number of samples loaded for every
-        iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
-        needs to be divisible by ``batch_chunk_count``, otherwise an error will
-        be raised. This parameter is used to reduce the number of samples loaded in
-        each iteration, in order to reduce the memory usage in exchange for
-        processing time (more iterations). This is especially interesting when
-        one is running on GPUs with limited RAM. The default of 1 forces the
-        whole batch to be processed at once. Otherwise the batch is broken into
-        batch-chunk-count pieces, and gradients are accumulated to complete
-        each batch.
     drop_incomplete_batch
         If set, then may drop the last batch in an epoch in case it is
         incomplete. If you set this option, you should also consider
@@ -526,14 +514,11 @@ class ConcatDataModule(lightning.LightningDataModule):
         split_name: str = "",
         cache_samples: bool = False,
         batch_size: int = 1,
-        batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
         parallel: int = -1,
     ):
         super().__init__()
 
-        self.set_chunk_size(batch_size, batch_chunk_count)
-
         self.splits = splits
         self.database_name = database_name
         self.split_name = split_name
@@ -550,6 +535,8 @@ class ConcatDataModule(lightning.LightningDataModule):
         self._model_transforms: list[Transform] | None = None
 
+        self.batch_size = batch_size
+
         self.drop_incomplete_batch = drop_incomplete_batch
         self.parallel = parallel  # immutable, otherwise would need to call
@@ -661,46 +648,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         )
         self._datasets = {}
 
-    def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
-        """Coherently set the batch-chunk-size after validation.
-
-        Parameters
-        ----------
-        batch_size
-            Number of samples in every **training** batch (this parameter affects
-            memory requirements for the network). If the number of samples in the
-            batch is larger than the total number of samples available for
-            training, this value is truncated. If this number is smaller, then
-            batches of the specified size are created and fed to the network until
-            there are no more new samples to feed (epoch is finished). If the
-            total number of training samples is not a multiple of the batch-size,
-            the last batch will be smaller than the first, unless
-            ``drop_incomplete_batch`` is set to ``true``, in which case this batch
-            is not used.
-        batch_chunk_count
-            Number of chunks in every batch (this parameter affects memory
-            requirements for the network). The number of samples loaded for every
-            iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
-            needs to be divisible by ``batch_chunk_count``, otherwise an error will
-            be raised. This parameter is used to reduce number of samples loaded in
-            each iteration, in order to reduce the memory usage in exchange for
-            processing time (more iterations). This is especially interesting when
-            one is running on GPUs with limited RAM. The default of 1 forces the
-            whole batch to be processed at once. Otherwise the batch is broken into
-            batch-chunk-count pieces, and gradients are accumulated to complete
-            each batch.
-        """
-
-        # validation
-        if batch_size % batch_chunk_count != 0:
-            raise RuntimeError(
-                f"batch_size ({batch_size}) must be divisible by "
-                f"batch_chunk_size ({batch_chunk_count}).",
-            )
-
-        self._batch_size = batch_size
-        self._batch_chunk_count = batch_chunk_count
-        self._chunk_size = self._batch_size // self._batch_chunk_count
-
     def _setup_dataset(self, name: str) -> None:
         """Set up a single dataset from the input data split.
@@ -845,7 +792,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         return torch.utils.data.DataLoader(
             self._datasets["train"],
             shuffle=(self._train_sampler is None),
-            batch_size=self._chunk_size,
+            batch_size=self.batch_size,
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=self._train_sampler,
@@ -863,7 +810,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         return torch.utils.data.DataLoader(
             self._datasets["train"],
             shuffle=False,
-            batch_size=self._chunk_size,
+            batch_size=self.batch_size,
             drop_last=False,
             **self._dataloader_multiproc,
         )
@@ -877,7 +824,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         """
         validation_loader_opts = {
-            "batch_size": self._chunk_size,
+            "batch_size": self.batch_size,
             "shuffle": False,
             "drop_last": self.drop_incomplete_batch,
             "pin_memory": self.pin_memory,
@@ -903,7 +850,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         return dict(
             test=torch.utils.data.DataLoader(
                 self._datasets["test"],
-                batch_size=self._chunk_size,
+                batch_size=self.batch_size,
                 shuffle=False,
                 drop_last=self.drop_incomplete_batch,
                 pin_memory=self.pin_memory,
@@ -922,7 +869,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         return {
             k: torch.utils.data.DataLoader(
                 self._datasets[k],
-                batch_size=self._chunk_size,
+                batch_size=self.batch_size,
                 shuffle=False,
                 drop_last=self.drop_incomplete_batch,
                 pin_memory=self.pin_memory,
...
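The docstrings removed above describe the old contract: data loaders yielded ``batch_size/batch_chunk_count`` samples per iteration while the effective optimization batch stayed at ``batch_size``. After this merge the loaders yield full ``batch_size`` batches and accumulation is delegated to Lightning, so per-step memory scales with ``batch_size`` alone. A back-of-the-envelope comparison (illustrative numbers only, not taken from the repository):

    # Old contract (removed above): the DataLoader yielded chunks of a batch.
    batch_size, batch_chunk_count = 32, 4
    samples_per_step = batch_size // batch_chunk_count        # 8 samples resident in memory
    effective_batch = batch_size                               # 32 samples per optimizer step

    # New contract: the DataLoader yields full batches; Lightning accumulates them.
    batch_size, accumulate_grad_batches = 8, 4
    samples_per_step = batch_size                              # 8 samples resident in memory
    effective_batch = batch_size * accumulate_grad_batches     # 32 samples per optimizer step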
@@ -26,7 +26,7 @@ def run(
     max_epochs: int,
     output_folder: pathlib.Path,
     monitoring_interval: int | float,
-    batch_chunk_count: int,
+    accumulate_grad_batches: int,
     checkpoint: pathlib.Path | None,
 ):
     """Fit a CNN model using supervised learning and save it to disk.
@@ -60,12 +60,13 @@ def run(
     monitoring_interval
         Interval, in seconds (or fractions), through which we should monitor
         resources during training.
-    batch_chunk_count
-        If this number is different than 1, then each batch will be divided in
-        this number of chunks. Gradients will be accumulated to perform each
-        mini-batch. This is particularly interesting when one has limited RAM
-        on the GPU, but would like to keep training with larger batches. One
-        exchanges for longer processing times in this case.
+    accumulate_grad_batches
+        Number of accumulations for backward propagation to accumulate gradients
+        over k batches before stepping the optimizer. The default of 1 forces
+        the whole batch to be processed at once. Otherwise the batch is multiplied
+        by accumulate-grad-batches pieces, and gradients are accumulated to complete
+        each step. This is especially interesting when one is training on GPUs with
+        a limited amount of onboard RAM.
     checkpoint
         Path to an optional checkpoint file to load.
     """
@@ -118,7 +119,7 @@ def run(
         accelerator=accelerator,
         devices=devices,
         max_epochs=max_epochs,
-        accumulate_grad_batches=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         logger=tensorboard_logger,
         check_val_every_n_epoch=validation_period,
         log_every_n_steps=len(datamodule.train_dataloader()),
...
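The engine change above simply forwards the new argument to ``lightning.Trainer``, which performs the accumulation itself. For readers unfamiliar with the mechanism, the pattern being delegated to is roughly the following (a generic, simplified sketch, not Lightning's or mednet's actual code; the linear model and random batches are toy stand-ins):

    import torch

    accumulate_grad_batches = 4                    # hypothetical value
    model = torch.nn.Linear(10, 1)                 # toy stand-in for the real CNN
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    dataloader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]  # toy batches

    for i, (x, y) in enumerate(dataloader):
        # scale the loss so the summed gradients match one large-batch gradient
        loss = torch.nn.functional.mse_loss(model(x), y) / accumulate_grad_batches
        loss.backward()                            # gradients add up in .grad
        if (i + 1) % accumulate_grad_batches == 0: # step only every k-th batch
            optimizer.step()
            optimizer.zero_grad()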
@@ -40,7 +40,7 @@ def experiment(
     output_folder,
     epochs,
     batch_size,
-    batch_chunk_count,
+    accumulate_grad_batches,
     drop_incomplete_batch,
     datamodule,
     validation_period,
@@ -79,7 +79,7 @@ def experiment(
         output_folder=train_output_folder,
         epochs=epochs,
         batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         drop_incomplete_batch=drop_incomplete_batch,
         datamodule=datamodule,
         validation_period=validation_period,
...
@@ -142,7 +142,6 @@ def predict(
         save_json_with_backup,
     )
 
-    datamodule.set_chunk_size(batch_size, 1)
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
...
@@ -216,9 +216,6 @@ def completeness(
     device_manager = DeviceManager(device)
 
-    # batch_size must be == 1 for now (underlying code is NOT prepared to
-    # treat multiple samples at once).
-    datamodule.set_chunk_size(1, 1)
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
...
@@ -177,9 +177,6 @@ def generate(
     device_manager = DeviceManager(device)
 
-    # batch_size must be == 1 for now (underlying code is NOT prepared to
-    # treat multiple samples at once).
-    datamodule.set_chunk_size(1, 1)
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
...
@@ -110,7 +110,6 @@ def view(
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
 
-    datamodule.set_chunk_size(1, 1)
     datamodule.drop_incomplete_batch = False
     # datamodule.cache_samples = cache_samples
     # datamodule.parallel = parallel
...
@@ -79,19 +79,18 @@ def reusable_options(f):
         cls=ResourceOption,
     )
     @click.option(
-        "--batch-chunk-count",
-        "-c",
-        help="Number of chunks in every batch (this parameter affects "
-        "memory requirements for the network). The number of samples "
-        "loaded for every iteration will be batch-size/batch-chunk-count. "
-        "batch-size needs to be divisible by batch-chunk-count, otherwise an "
-        "error will be raised. This parameter is used to reduce the number of "
-        "samples loaded in each iteration, in order to reduce the memory usage "
-        "in exchange for processing time (more iterations). This is especially "
-        "interesting when one is training on GPUs with limited RAM. The "
-        "default of 1 forces the whole batch to be processed at once. Otherwise "
-        "the batch is broken into batch-chunk-count pieces, and gradients are "
-        "accumulated to complete each batch.",
+        "--accumulate-grad-batches",
+        "-a",
+        help="Number of accumulations for backward propagation to accumulate "
+        "gradients over k batches before stepping the optimizer. This "
+        "parameter, used in conjunction with the batch-size, may be used to "
+        "reduce the number of samples loaded in each iteration, to affect memory "
+        "usage in exchange for processing time (more iterations). This is "
+        "useful interesting when one is training on GPUs with a limited amount "
+        "of onboard RAM. The default of 1 forces the whole batch to be "
+        "processed at once. Otherwise the batch is multiplied by "
+        "accumulate-grad-batches pieces, and gradients are accumulated "
+        "to complete each training step.",
         required=True,
         show_default=True,
         default=1,
@@ -236,7 +235,7 @@ def train(
     output_folder,
     epochs,
     batch_size,
-    batch_chunk_count,
+    accumulate_grad_batches,
     drop_incomplete_batch,
     datamodule,
     validation_period,
@@ -283,7 +282,6 @@ def train(
     seed_everything(seed)
 
     # reset datamodule with user configurable options
-    datamodule.set_chunk_size(batch_size, batch_chunk_count)
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
@@ -342,7 +340,7 @@ def train(
         split_name=datamodule.split_name,
         epochs=epochs,
         batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         drop_incomplete_batch=drop_incomplete_batch,
         validation_period=validation_period,
         cache_samples=cache_samples,
@@ -365,6 +363,6 @@ def train(
         max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         checkpoint=checkpoint_file,
     )
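Taken together, the new ``--accumulate-grad-batches`` option travels unchanged from the command line through the ``experiment`` and ``train`` commands down to the trainer engine. A minimal sketch of the resulting wiring (``model`` and ``my_splits`` are placeholders, and the real commands pass many more options than shown here):

    import lightning

    datamodule = ConcatDataModule(
        splits=my_splits,             # placeholder for an actual split definition
        batch_size=8,                 # loaders now yield full batches of this size
        drop_incomplete_batch=False,
    )

    trainer = lightning.Trainer(
        max_epochs=10,
        accumulate_grad_batches=4,    # replaces the removed batch_chunk_count plumbing
    )
    trainer.fit(model, datamodule=datamodule)   # model: a LightningModule instance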