diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 6c7d759f4a8890f576af327450ac89e144f6fbfa..d7b6e93b6e79f51d4c545668ecb7d789e0f5b352 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -492,18 +492,6 @@ class ConcatDataModule(lightning.LightningDataModule): the last batch will be smaller than the first, unless ``drop_incomplete_batch`` is set to ``true``, in which case this batch is not used. - batch_chunk_count - Number of chunks in every batch (this parameter affects memory - requirements for the network). The number of samples loaded for every - iteration will be ``batch_size/batch_chunk_count``. ``batch_size`` - needs to be divisible by ``batch_chunk_count``, otherwise an error will - be raised. This parameter is used to reduce the number of samples loaded in - each iteration, in order to reduce the memory usage in exchange for - processing time (more iterations). This is especially interesting when - one is running on GPUs with limited RAM. The default of 1 forces the - whole batch to be processed at once. Otherwise the batch is broken into - batch-chunk-count pieces, and gradients are accumulated to complete - each batch. drop_incomplete_batch If set, then may drop the last batch in an epoch in case it is incomplete. 
If you set this option, you should also consider @@ -526,14 +514,11 @@ class ConcatDataModule(lightning.LightningDataModule): split_name: str = "", cache_samples: bool = False, batch_size: int = 1, - batch_chunk_count: int = 1, drop_incomplete_batch: bool = False, parallel: int = -1, ): super().__init__() - self.set_chunk_size(batch_size, batch_chunk_count) - self.splits = splits self.database_name = database_name self.split_name = split_name @@ -550,6 +535,8 @@ class ConcatDataModule(lightning.LightningDataModule): self._model_transforms: list[Transform] | None = None + self.batch_size = batch_size + self.drop_incomplete_batch = drop_incomplete_batch self.parallel = parallel # immutable, otherwise would need to call @@ -661,46 +648,6 @@ class ConcatDataModule(lightning.LightningDataModule): ) self._datasets = {} - def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None: - """Coherently set the batch-chunk-size after validation. - - Parameters - ---------- - batch_size - Number of samples in every **training** batch (this parameter affects - memory requirements for the network). If the number of samples in the - batch is larger than the total number of samples available for - training, this value is truncated. If this number is smaller, then - batches of the specified size are created and fed to the network until - there are no more new samples to feed (epoch is finished). If the - total number of training samples is not a multiple of the batch-size, - the last batch will be smaller than the first, unless - ``drop_incomplete_batch`` is set to ``true``, in which case this batch - is not used. - batch_chunk_count - Number of chunks in every batch (this parameter affects memory - requirements for the network). The number of samples loaded for every - iteration will be ``batch_size/batch_chunk_count``. ``batch_size`` - needs to be divisible by ``batch_chunk_count``, otherwise an error will - be raised. 
This parameter is used to reduce number of samples loaded in - each iteration, in order to reduce the memory usage in exchange for - processing time (more iterations). This is especially interesting when - one is running on GPUs with limited RAM. The default of 1 forces the - whole batch to be processed at once. Otherwise the batch is broken into - batch-chunk-count pieces, and gradients are accumulated to complete - each batch. - """ - # validation - if batch_size % batch_chunk_count != 0: - raise RuntimeError( - f"batch_size ({batch_size}) must be divisible by " - f"batch_chunk_size ({batch_chunk_count}).", - ) - - self._batch_size = batch_size - self._batch_chunk_count = batch_chunk_count - self._chunk_size = self._batch_size // self._batch_chunk_count - def _setup_dataset(self, name: str) -> None: """Set up a single dataset from the input data split. @@ -845,7 +792,7 @@ class ConcatDataModule(lightning.LightningDataModule): return torch.utils.data.DataLoader( self._datasets["train"], shuffle=(self._train_sampler is None), - batch_size=self._chunk_size, + batch_size=self.batch_size, drop_last=self.drop_incomplete_batch, pin_memory=self.pin_memory, sampler=self._train_sampler, @@ -863,7 +810,7 @@ class ConcatDataModule(lightning.LightningDataModule): return torch.utils.data.DataLoader( self._datasets["train"], shuffle=False, - batch_size=self._chunk_size, + batch_size=self.batch_size, drop_last=False, **self._dataloader_multiproc, ) @@ -877,7 +824,7 @@ class ConcatDataModule(lightning.LightningDataModule): """ validation_loader_opts = { - "batch_size": self._chunk_size, + "batch_size": self.batch_size, "shuffle": False, "drop_last": self.drop_incomplete_batch, "pin_memory": self.pin_memory, @@ -903,7 +850,7 @@ class ConcatDataModule(lightning.LightningDataModule): return dict( test=torch.utils.data.DataLoader( self._datasets["test"], - batch_size=self._chunk_size, + batch_size=self.batch_size, shuffle=False, drop_last=self.drop_incomplete_batch, 
pin_memory=self.pin_memory, @@ -922,7 +869,7 @@ class ConcatDataModule(lightning.LightningDataModule): return { k: torch.utils.data.DataLoader( self._datasets[k], - batch_size=self._chunk_size, + batch_size=self.batch_size, shuffle=False, drop_last=self.drop_incomplete_batch, pin_memory=self.pin_memory, diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py index b8ecf91c86d5c554337a40ccf7264162692e2d64..a24a98c6d28eb75502fc00565985401c564ec3f1 100644 --- a/src/mednet/scripts/predict.py +++ b/src/mednet/scripts/predict.py @@ -142,7 +142,6 @@ def predict( save_json_with_backup, ) - datamodule.set_chunk_size(batch_size, 1) datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms diff --git a/src/mednet/scripts/saliency/completeness.py b/src/mednet/scripts/saliency/completeness.py index 453618332cd06a4a54570c70b65f7559b0900390..5c93aded47a87728bb72df927dd93e959bdadf0b 100644 --- a/src/mednet/scripts/saliency/completeness.py +++ b/src/mednet/scripts/saliency/completeness.py @@ -216,9 +216,6 @@ def completeness( device_manager = DeviceManager(device) - # batch_size must be == 1 for now (underlying code is NOT prepared to - # treat multiple samples at once). - datamodule.set_chunk_size(1, 1) datamodule.cache_samples = cache_samples datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py index 5a9ca8b6955158355e333e8e178a7a5fdd3b814c..71f248fa807fdcfb90bc23f9a8e9efb0fd635cf7 100644 --- a/src/mednet/scripts/saliency/generate.py +++ b/src/mednet/scripts/saliency/generate.py @@ -177,9 +177,6 @@ def generate( device_manager = DeviceManager(device) - # batch_size must be == 1 for now (underlying code is NOT prepared to - # treat multiple samples at once). 
- datamodule.set_chunk_size(1, 1) datamodule.cache_samples = cache_samples datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms diff --git a/src/mednet/scripts/saliency/view.py b/src/mednet/scripts/saliency/view.py index bb1a5c9d91f27d5d46d4001d0a3b35cc5fb3dc38..270ccbfebf86418bf458add62e71f0cda6baad67 100644 --- a/src/mednet/scripts/saliency/view.py +++ b/src/mednet/scripts/saliency/view.py @@ -110,7 +110,6 @@ def view( logger.info(f"Output folder: {output_folder}") output_folder.mkdir(parents=True, exist_ok=True) - datamodule.set_chunk_size(1, 1) datamodule.drop_incomplete_batch = False # datamodule.cache_samples = cache_samples # datamodule.parallel = parallel diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py index e83ae359b0be523565f669c57bf1de41520ccac7..78720362864c18616a2df342e130180d82e37a58 100644 --- a/src/mednet/scripts/train.py +++ b/src/mednet/scripts/train.py @@ -283,7 +283,6 @@ def train( seed_everything(seed) # reset datamodule with user configurable options - datamodule.set_chunk_size(batch_size, batch_chunk_count) datamodule.drop_incomplete_batch = drop_incomplete_batch datamodule.cache_samples = cache_samples datamodule.parallel = parallel