diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index d1a1376964edc5fb0800b926955996d8429c7b9d..8483f667eb71197a191345901e59be02ec5171cd 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -75,6 +75,7 @@ def train(
     setup_datamodule(
         datamodule,
         model,
+        batch_size,
         drop_incomplete_batch,
         cache_samples,
         parallel,