diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py
index 6c7d759f4a8890f576af327450ac89e144f6fbfa..d7b6e93b6e79f51d4c545668ecb7d789e0f5b352 100644
--- a/src/mednet/data/datamodule.py
+++ b/src/mednet/data/datamodule.py
@@ -492,18 +492,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         the last batch will be smaller than the first, unless
         ``drop_incomplete_batch`` is set to ``true``, in which case this batch
         is not used.
-    batch_chunk_count
-        Number of chunks in every batch (this parameter affects memory
-        requirements for the network). The number of samples loaded for every
-        iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
-        needs to be divisible by ``batch_chunk_count``, otherwise an error will
-        be raised. This parameter is used to reduce the number of samples loaded in
-        each iteration, in order to reduce the memory usage in exchange for
-        processing time (more iterations). This is especially interesting when
-        one is running on GPUs with limited RAM. The default of 1 forces the
-        whole batch to be processed at once. Otherwise the batch is broken into
-        batch-chunk-count pieces, and gradients are accumulated to complete
-        each batch.
     drop_incomplete_batch
         If set, then may drop the last batch in an epoch in case it is
         incomplete.  If you set this option, you should also consider
@@ -526,14 +514,11 @@ class ConcatDataModule(lightning.LightningDataModule):
         split_name: str = "",
         cache_samples: bool = False,
         batch_size: int = 1,
-        batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
         parallel: int = -1,
     ):
         super().__init__()
 
-        self.set_chunk_size(batch_size, batch_chunk_count)
-
         self.splits = splits
         self.database_name = database_name
         self.split_name = split_name
@@ -550,6 +535,8 @@ class ConcatDataModule(lightning.LightningDataModule):
 
         self._model_transforms: list[Transform] | None = None
 
+        self.batch_size = batch_size
+
         self.drop_incomplete_batch = drop_incomplete_batch
         self.parallel = parallel  # immutable, otherwise would need to call
 
@@ -661,46 +648,6 @@ class ConcatDataModule(lightning.LightningDataModule):
             )
             self._datasets = {}
 
-    def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
-        """Coherently set the batch-chunk-size after validation.
-
-        Parameters
-        ----------
-        batch_size
-            Number of samples in every **training** batch (this parameter affects
-            memory requirements for the network).  If the number of samples in the
-            batch is larger than the total number of samples available for
-            training, this value is truncated.  If this number is smaller, then
-            batches of the specified size are created and fed to the network  until
-            there are no more new samples to feed (epoch is finished).  If the
-            total number of training samples is not a multiple of the batch-size,
-            the last batch will be smaller than the first, unless
-            ``drop_incomplete_batch`` is set to ``true``, in which case this batch
-            is not used.
-        batch_chunk_count
-            Number of chunks in every batch (this parameter affects memory
-            requirements for the network). The number of samples loaded for every
-            iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
-            needs to be divisible by ``batch_chunk_count``, otherwise an error will
-            be raised. This parameter is used to reduce number of samples loaded in
-            each iteration, in order to reduce the memory usage in exchange for
-            processing time (more iterations). This is especially interesting when
-            one is running on GPUs with limited RAM. The default of 1 forces the
-            whole batch to be processed at once. Otherwise the batch is broken into
-            batch-chunk-count pieces, and gradients are accumulated to complete
-            each batch.
-        """
-        # validation
-        if batch_size % batch_chunk_count != 0:
-            raise RuntimeError(
-                f"batch_size ({batch_size}) must be divisible by "
-                f"batch_chunk_size ({batch_chunk_count}).",
-            )
-
-        self._batch_size = batch_size
-        self._batch_chunk_count = batch_chunk_count
-        self._chunk_size = self._batch_size // self._batch_chunk_count
-
     def _setup_dataset(self, name: str) -> None:
         """Set up a single dataset from the input data split.
 
@@ -845,7 +792,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         return torch.utils.data.DataLoader(
             self._datasets["train"],
             shuffle=(self._train_sampler is None),
-            batch_size=self._chunk_size,
+            batch_size=self.batch_size,
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=self._train_sampler,
@@ -863,7 +810,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         return torch.utils.data.DataLoader(
             self._datasets["train"],
             shuffle=False,
-            batch_size=self._chunk_size,
+            batch_size=self.batch_size,
             drop_last=False,
             **self._dataloader_multiproc,
         )
@@ -877,7 +824,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         """
 
         validation_loader_opts = {
-            "batch_size": self._chunk_size,
+            "batch_size": self.batch_size,
             "shuffle": False,
             "drop_last": self.drop_incomplete_batch,
             "pin_memory": self.pin_memory,
@@ -903,7 +850,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         return dict(
             test=torch.utils.data.DataLoader(
                 self._datasets["test"],
-                batch_size=self._chunk_size,
+                batch_size=self.batch_size,
                 shuffle=False,
                 drop_last=self.drop_incomplete_batch,
                 pin_memory=self.pin_memory,
@@ -922,7 +869,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         return {
             k: torch.utils.data.DataLoader(
                 self._datasets[k],
-                batch_size=self._chunk_size,
+                batch_size=self.batch_size,
                 shuffle=False,
                 drop_last=self.drop_incomplete_batch,
                 pin_memory=self.pin_memory,
diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py
index b8ecf91c86d5c554337a40ccf7264162692e2d64..a24a98c6d28eb75502fc00565985401c564ec3f1 100644
--- a/src/mednet/scripts/predict.py
+++ b/src/mednet/scripts/predict.py
@@ -142,7 +142,7 @@ def predict(
         save_json_with_backup,
     )
 
-    datamodule.set_chunk_size(batch_size, 1)
+    datamodule.batch_size = batch_size
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
 
diff --git a/src/mednet/scripts/saliency/completeness.py b/src/mednet/scripts/saliency/completeness.py
index 453618332cd06a4a54570c70b65f7559b0900390..5c93aded47a87728bb72df927dd93e959bdadf0b 100644
--- a/src/mednet/scripts/saliency/completeness.py
+++ b/src/mednet/scripts/saliency/completeness.py
@@ -216,9 +216,9 @@ def completeness(
 
     device_manager = DeviceManager(device)
 
     # batch_size must be == 1 for now (underlying code is NOT prepared to
     # treat multiple samples at once).
-    datamodule.set_chunk_size(1, 1)
+    datamodule.batch_size = 1
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py
index 5a9ca8b6955158355e333e8e178a7a5fdd3b814c..71f248fa807fdcfb90bc23f9a8e9efb0fd635cf7 100644
--- a/src/mednet/scripts/saliency/generate.py
+++ b/src/mednet/scripts/saliency/generate.py
@@ -177,9 +177,9 @@ def generate(
 
     device_manager = DeviceManager(device)
 
     # batch_size must be == 1 for now (underlying code is NOT prepared to
     # treat multiple samples at once).
-    datamodule.set_chunk_size(1, 1)
+    datamodule.batch_size = 1
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
diff --git a/src/mednet/scripts/saliency/view.py b/src/mednet/scripts/saliency/view.py
index bb1a5c9d91f27d5d46d4001d0a3b35cc5fb3dc38..270ccbfebf86418bf458add62e71f0cda6baad67 100644
--- a/src/mednet/scripts/saliency/view.py
+++ b/src/mednet/scripts/saliency/view.py
@@ -110,7 +110,7 @@ def view(
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
 
-    datamodule.set_chunk_size(1, 1)
+    datamodule.batch_size = 1
     datamodule.drop_incomplete_batch = False
     # datamodule.cache_samples = cache_samples
     # datamodule.parallel = parallel
diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py
index e83ae359b0be523565f669c57bf1de41520ccac7..78720362864c18616a2df342e130180d82e37a58 100644
--- a/src/mednet/scripts/train.py
+++ b/src/mednet/scripts/train.py
@@ -283,7 +283,7 @@ def train(
     seed_everything(seed)
 
     # reset datamodule with user configurable options
-    datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.batch_size = batch_size
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
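
Note on gradient accumulation: dropping ``set_chunk_size()`` removes the datamodule's built-in batch chunking. Lightning's ``accumulate_grad_batches`` Trainer option offers the same memory-for-time trade-off, so callers that relied on ``batch_chunk_count`` can move the setting to the trainer. A minimal sketch, not part of this change (the concrete values and the ``model``/``datamodule`` objects are illustrative):

```python
import lightning

# The previous batch_size=32 with batch_chunk_count=4 loaded 8 samples per
# iteration and accumulated gradients over 4 iterations.  The equivalent
# setup now uses a datamodule batch_size of 8 plus trainer-side accumulation:
trainer = lightning.Trainer(
    accumulate_grad_batches=4,  # plays the role of the old batch_chunk_count
    max_epochs=10,              # illustrative value
)
# trainer.fit(model, datamodule=datamodule)
```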