diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py
index bd093251405092d407384f917971f9f2e26e0dd6..d38c38411257c95a293d491bb178470d4560da64 100644
--- a/src/mednet/libs/common/scripts/predict.py
+++ b/src/mednet/libs/common/scripts/predict.py
@@ -127,6 +127,7 @@ def setup_datamodule(
     parallel,
 ) -> None:  # numpydoc ignore=PR01
     """Configure and set up the datamodule."""
+    datamodule.batch_size = batch_size
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
 
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index b806609a84634559f6cec6a28cf6ef69b76b21dd..156d98654172bd4e85d8b487aed2efb460da843c 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -275,11 +275,13 @@ def load_checkpoint(checkpoint_file, datamodule, model):
 def setup_datamodule(
     datamodule,
     model,
+    batch_size,
     drop_incomplete_batch,
     cache_samples,
     parallel,
 ) -> None:  # numpydoc ignore=PR01
     """Configure and set up the datamodule."""
+    datamodule.batch_size = batch_size
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
index 41467b896c63136d0b6905151645d63d5346c755..fc726a14cbf5f9efbdcc66e5b300785b5a676814 100644
--- a/src/mednet/libs/segmentation/scripts/train.py
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -61,6 +61,7 @@ def train(
     setup_datamodule(
         datamodule,
         model,
+        batch_size,
         drop_incomplete_batch,
         cache_samples,
         parallel,