diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index 85a570d28d0619c2f170d640129434ae26a4d250..8483f667eb71197a191345901e59be02ec5171cd 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -76,7 +76,6 @@ def train(
         datamodule,
         model,
         batch_size,
-        batch_chunk_count,
         drop_incomplete_batch,
         cache_samples,
         parallel,
@@ -120,6 +119,6 @@ def train(
         max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         checkpoint=checkpoint_file,
     )
diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py
index 48fbf4d1c3fd76fc281c7df320700e0e6ff421c8..bd093251405092d407384f917971f9f2e26e0dd6 100644
--- a/src/mednet/libs/common/scripts/predict.py
+++ b/src/mednet/libs/common/scripts/predict.py
@@ -127,7 +127,6 @@ def setup_datamodule(
     parallel,
 ) -> None:  # numpydoc ignore=PR01
     """Configure and set up the datamodule."""
-    datamodule.set_chunk_size(batch_size, 1)
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
 
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index db48174d449fb42879c81ec086b6c0216f3fe1f4..774d485316e068af2192cb2d846a84f5612a6e88 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -276,13 +276,11 @@ def setup_datamodule(
     datamodule,
     model,
     batch_size,
-    batch_chunk_count,
     drop_incomplete_batch,
     cache_samples,
     parallel,
 ) -> None:  # numpydoc ignore=PR01
     """Configure and set up the datamodule."""
-    datamodule.set_chunk_size(batch_size, batch_chunk_count)
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py
index ea74b2dd25d7a158b0faa3668e6229840bd47ff1..b909e26f3eb9ed28d01d2fbfc89080dc4e06f2f0 100644
--- a/src/mednet/libs/segmentation/scripts/experiment.py
+++ b/src/mednet/libs/segmentation/scripts/experiment.py
@@ -41,7 +41,7 @@ def experiment(
     output_folder,
     epochs,
     batch_size,
-    batch_chunk_count,
+    accumulate_grad_batches,
     drop_incomplete_batch,
     datamodule,
     validation_period,
@@ -79,7 +79,7 @@ def experiment(
         output_folder=train_output_folder,
         epochs=epochs,
         batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         drop_incomplete_batch=drop_incomplete_batch,
         datamodule=datamodule,
         validation_period=validation_period,
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
index 67477db3c3b047fcfcfe9df0c9996193cab5a748..fc726a14cbf5f9efbdcc66e5b300785b5a676814 100644
--- a/src/mednet/libs/segmentation/scripts/train.py
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -32,7 +32,7 @@ def train(
     output_folder,
     epochs,
     batch_size,
-    batch_chunk_count,
+    accumulate_grad_batches,
     drop_incomplete_batch,
     datamodule,
     validation_period,
@@ -62,7 +62,6 @@ def train(
         datamodule,
         model,
         batch_size,
-        batch_chunk_count,
         drop_incomplete_batch,
         cache_samples,
         parallel,
@@ -81,7 +80,7 @@ def train(
         device_manager,
         epochs,
         batch_size,
-        batch_chunk_count,
+        accumulate_grad_batches,
         drop_incomplete_batch,
         validation_period,
         cache_samples,
@@ -98,6 +97,6 @@ def train(
         max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         checkpoint=checkpoint_file,
     )