diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py
index 6c7d759f4a8890f576af327450ac89e144f6fbfa..56ab3000a712c99f857f771ab36ecf85bebb6902 100644
--- a/src/mednet/data/datamodule.py
+++ b/src/mednet/data/datamodule.py
@@ -492,18 +492,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         the last batch will be smaller than the first, unless
         ``drop_incomplete_batch`` is set to ``true``, in which case this batch
         is not used.
-    batch_chunk_count
-        Number of chunks in every batch (this parameter affects memory
-        requirements for the network). The number of samples loaded for every
-        iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
-        needs to be divisible by ``batch_chunk_count``, otherwise an error will
-        be raised. This parameter is used to reduce the number of samples loaded in
-        each iteration, in order to reduce the memory usage in exchange for
-        processing time (more iterations). This is especially interesting when
-        one is running on GPUs with limited RAM. The default of 1 forces the
-        whole batch to be processed at once. Otherwise the batch is broken into
-        batch-chunk-count pieces, and gradients are accumulated to complete
-        each batch.
     drop_incomplete_batch
         If set, then may drop the last batch in an epoch in case it is
         incomplete. If you set this option, you should also consider
@@ -526,14 +514,11 @@
         split_name: str = "",
         cache_samples: bool = False,
         batch_size: int = 1,
-        batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
         parallel: int = -1,
     ):
         super().__init__()
 
-        self.set_chunk_size(batch_size, batch_chunk_count)
-
         self.splits = splits
         self.database_name = database_name
         self.split_name = split_name
@@ -550,6 +535,8 @@ class ConcatDataModule(lightning.LightningDataModule):
 
         self._model_transforms: list[Transform] | None = None
 
+        self.batch_size = batch_size
+
         self.drop_incomplete_batch = drop_incomplete_batch
         self.parallel = parallel  # immutable, otherwise would need to call
 
@@ -661,46 +648,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         )
         self._datasets = {}
 
-    def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
-        """Coherently set the batch-chunk-size after validation.
-
-        Parameters
-        ----------
-        batch_size
-            Number of samples in every **training** batch (this parameter affects
-            memory requirements for the network). If the number of samples in the
-            batch is larger than the total number of samples available for
-            training, this value is truncated. If this number is smaller, then
-            batches of the specified size are created and fed to the network until
-            there are no more new samples to feed (epoch is finished). If the
-            total number of training samples is not a multiple of the batch-size,
-            the last batch will be smaller than the first, unless
-            ``drop_incomplete_batch`` is set to ``true``, in which case this batch
-            is not used.
-        batch_chunk_count
-            Number of chunks in every batch (this parameter affects memory
-            requirements for the network). The number of samples loaded for every
-            iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
-            needs to be divisible by ``batch_chunk_count``, otherwise an error will
-            be raised. This parameter is used to reduce number of samples loaded in
-            each iteration, in order to reduce the memory usage in exchange for
-            processing time (more iterations). This is especially interesting when
-            one is running on GPUs with limited RAM. The default of 1 forces the
-            whole batch to be processed at once. Otherwise the batch is broken into
-            batch-chunk-count pieces, and gradients are accumulated to complete
-            each batch.
-        """
-        # validation
-        if batch_size % batch_chunk_count != 0:
-            raise RuntimeError(
-                f"batch_size ({batch_size}) must be divisible by "
-                f"batch_chunk_size ({batch_chunk_count}).",
-            )
-
-        self._batch_size = batch_size
-        self._batch_chunk_count = batch_chunk_count
-        self._chunk_size = self._batch_size // self._batch_chunk_count
-
     def _setup_dataset(self, name: str) -> None:
         """Set up a single dataset from the input data split.
 
@@ -845,7 +792,7 @@
         return torch.utils.data.DataLoader(
             self._datasets["train"],
             shuffle=(self._train_sampler is None),
-            batch_size=self._chunk_size,
+            batch_size=self.batch_size,
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=self._train_sampler,
@@ -863,7 +810,7 @@
         return torch.utils.data.DataLoader(
             self._datasets["train"],
             shuffle=False,
-            batch_size=self._chunk_size,
+            batch_size=self.batch_size,
             drop_last=False,
             **self._dataloader_multiproc,
         )
@@ -877,7 +824,7 @@
         """
 
         validation_loader_opts = {
-            "batch_size": self._chunk_size,
+            "batch_size": self.batch_size,
             "shuffle": False,
             "drop_last": self.drop_incomplete_batch,
             "pin_memory": self.pin_memory,
@@ -903,7 +850,7 @@
         return dict(
             test=torch.utils.data.DataLoader(
                 self._datasets["test"],
-                batch_size=self._chunk_size,
+                batch_size=self.batch_size,
                 shuffle=False,
                 drop_last=self.drop_incomplete_batch,
                 pin_memory=self.pin_memory,
@@ -922,7 +869,7 @@
         return {
             k: torch.utils.data.DataLoader(
                 self._datasets[k],
-                batch_size=self._chunk_size,
+                batch_size=self.batch_size,
                 shuffle=False,
                 drop_last=self.drop_incomplete_batch,
                 pin_memory=self.pin_memory,
diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py
index 5ea8ccae4cfecb6987fb7affca1274e6bb474a4e..0fe79c818a565899464ad844b9cd916e4218e264 100644
--- a/src/mednet/engine/trainer.py
+++ b/src/mednet/engine/trainer.py
@@ -26,7 +26,7 @@ def run(
     max_epochs: int,
     output_folder: pathlib.Path,
     monitoring_interval: int | float,
-    batch_chunk_count: int,
+    accumulate_grad_batches: int,
     checkpoint: pathlib.Path | None,
 ):
     """Fit a CNN model using supervised learning and save it to disk.
@@ -60,12 +60,13 @@
     monitoring_interval
         Interval, in seconds (or fractions), through which we should monitor
         resources during training.
-    batch_chunk_count
-        If this number is different than 1, then each batch will be divided in
-        this number of chunks. Gradients will be accumulated to perform each
-        mini-batch. This is particularly interesting when one has limited RAM
-        on the GPU, but would like to keep training with larger batches. One
-        exchanges for longer processing times in this case.
+    accumulate_grad_batches
+        Number of batches over which gradients are accumulated before the
+        optimizer steps. The default of 1 steps the optimizer after every
+        batch. Larger values accumulate the gradients of that many
+        consecutive batches into a single optimizer step, multiplying the
+        effective batch size accordingly. This is especially useful when
+        training on GPUs with a limited amount of onboard RAM.
     checkpoint
         Path to an optional checkpoint file to load.
""" @@ -118,7 +119,7 @@ def run( accelerator=accelerator, devices=devices, max_epochs=max_epochs, - accumulate_grad_batches=batch_chunk_count, + accumulate_grad_batches=accumulate_grad_batches, logger=tensorboard_logger, check_val_every_n_epoch=validation_period, log_every_n_steps=len(datamodule.train_dataloader()), diff --git a/src/mednet/scripts/experiment.py b/src/mednet/scripts/experiment.py index 67e16b64d344df38fec0d305fcbfb2098e6642a5..12a7011569f30dfc1c820b5a718d307a19022538 100644 --- a/src/mednet/scripts/experiment.py +++ b/src/mednet/scripts/experiment.py @@ -40,7 +40,7 @@ def experiment( output_folder, epochs, batch_size, - batch_chunk_count, + accumulate_grad_batches, drop_incomplete_batch, datamodule, validation_period, @@ -79,7 +79,7 @@ def experiment( output_folder=train_output_folder, epochs=epochs, batch_size=batch_size, - batch_chunk_count=batch_chunk_count, + accumulate_grad_batches=accumulate_grad_batches, drop_incomplete_batch=drop_incomplete_batch, datamodule=datamodule, validation_period=validation_period, diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py index b8ecf91c86d5c554337a40ccf7264162692e2d64..a24a98c6d28eb75502fc00565985401c564ec3f1 100644 --- a/src/mednet/scripts/predict.py +++ b/src/mednet/scripts/predict.py @@ -142,7 +142,6 @@ def predict( save_json_with_backup, ) - datamodule.set_chunk_size(batch_size, 1) datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms diff --git a/src/mednet/scripts/saliency/completeness.py b/src/mednet/scripts/saliency/completeness.py index 453618332cd06a4a54570c70b65f7559b0900390..5c93aded47a87728bb72df927dd93e959bdadf0b 100644 --- a/src/mednet/scripts/saliency/completeness.py +++ b/src/mednet/scripts/saliency/completeness.py @@ -216,9 +216,6 @@ def completeness( device_manager = DeviceManager(device) - # batch_size must be == 1 for now (underlying code is NOT prepared to - # treat multiple samples at once). - datamodule.set_chunk_size(1, 1) datamodule.cache_samples = cache_samples datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py index 5a9ca8b6955158355e333e8e178a7a5fdd3b814c..71f248fa807fdcfb90bc23f9a8e9efb0fd635cf7 100644 --- a/src/mednet/scripts/saliency/generate.py +++ b/src/mednet/scripts/saliency/generate.py @@ -177,9 +177,6 @@ def generate( device_manager = DeviceManager(device) - # batch_size must be == 1 for now (underlying code is NOT prepared to - # treat multiple samples at once). 
-    datamodule.set_chunk_size(1, 1)
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
diff --git a/src/mednet/scripts/saliency/view.py b/src/mednet/scripts/saliency/view.py
index bb1a5c9d91f27d5d46d4001d0a3b35cc5fb3dc38..270ccbfebf86418bf458add62e71f0cda6baad67 100644
--- a/src/mednet/scripts/saliency/view.py
+++ b/src/mednet/scripts/saliency/view.py
@@ -110,7 +110,6 @@ def view(
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
 
-    datamodule.set_chunk_size(1, 1)
     datamodule.drop_incomplete_batch = False
     # datamodule.cache_samples = cache_samples
     # datamodule.parallel = parallel
diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py
index e83ae359b0be523565f669c57bf1de41520ccac7..9c25de7ae8124a1fe3ce49462c081486f249c17a 100644
--- a/src/mednet/scripts/train.py
+++ b/src/mednet/scripts/train.py
@@ -79,19 +79,18 @@ def reusable_options(f):
         cls=ResourceOption,
     )
     @click.option(
-        "--batch-chunk-count",
-        "-c",
-        help="Number of chunks in every batch (this parameter affects "
-        "memory requirements for the network). The number of samples "
-        "loaded for every iteration will be batch-size/batch-chunk-count. "
-        "batch-size needs to be divisible by batch-chunk-count, otherwise an "
-        "error will be raised. This parameter is used to reduce the number of "
-        "samples loaded in each iteration, in order to reduce the memory usage "
-        "in exchange for processing time (more iterations). This is especially "
-        "interesting when one is training on GPUs with limited RAM. The "
-        "default of 1 forces the whole batch to be processed at once. Otherwise "
-        "the batch is broken into batch-chunk-count pieces, and gradients are "
-        "accumulated to complete each batch.",
+        "--accumulate-grad-batches",
+        "-a",
+        help="Number of batches over which gradients are accumulated before "
+        "the optimizer steps. Combined with a smaller batch-size, this "
+        "option reduces the number of samples loaded in each iteration "
+        "while keeping the effective batch size constant, trading "
+        "processing time (more iterations) for lower memory usage. This "
+        "is especially useful when training on GPUs with a limited "
+        "amount of onboard RAM. The default of 1 steps the optimizer "
+        "after every batch. Larger values accumulate the gradients of "
+        "that many consecutive batches into a single optimizer step, "
+        "making the effective batch size batch-size times this value.",
         required=True,
         show_default=True,
         default=1,
@@ -236,7 +235,7 @@
     output_folder,
     epochs,
     batch_size,
-    batch_chunk_count,
+    accumulate_grad_batches,
    drop_incomplete_batch,
     datamodule,
     validation_period,
@@ -283,7 +282,6 @@
     seed_everything(seed)
 
     # reset datamodule with user configurable options
-    datamodule.set_chunk_size(batch_size, batch_chunk_count)
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
@@ -342,7 +340,7 @@
         split_name=datamodule.split_name,
         epochs=epochs,
         batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         drop_incomplete_batch=drop_incomplete_batch,
         validation_period=validation_period,
         cache_samples=cache_samples,
@@ -365,6 +363,6 @@
         max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         checkpoint=checkpoint_file,
     )
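
Note (not part of the patch): the new option is forwarded unchanged to Lightning's own gradient-accumulation support, as the trainer.py hunk above shows. The sketch below illustrates that underlying behaviour in isolation. TinyModel, the random TensorDataset and all hyperparameters are made up for the example; only lightning.Trainer(accumulate_grad_batches=...) reflects the API the patch actually relies on.

# Illustrative sketch: what accumulate_grad_batches does at the Trainer level.
import lightning
import torch
from torch.utils.data import DataLoader, TensorDataset


class TinyModel(lightning.LightningModule):
    """Hypothetical stand-in for a mednet model."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


# Each iteration loads 8 samples; the optimizer only steps after gradients
# from 4 such batches have been accumulated (effective batch size: 32).
data = TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
trainer = lightning.Trainer(
    max_epochs=1,
    accumulate_grad_batches=4,  # the value --accumulate-grad-batches is forwarded to
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(TinyModel(), DataLoader(data, batch_size=8))

Per the two help texts above, the semantics also shift: with the removed batch_chunk_count, --batch-size named the effective batch and only batch-size/batch-chunk-count samples were loaded per iteration; with --accumulate-grad-batches, --batch-size names what is loaded per iteration and the effective batch is batch-size times the accumulation count, so existing configurations may need their values adjusted.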