From 7e204e950c4889597dbed1751e7bf9f7425045e2 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 5 Jun 2024 10:44:32 +0200
Subject: [PATCH] [scripts] Set datamodule batch size

---
 src/mednet/libs/common/scripts/predict.py     | 1 +
 src/mednet/libs/common/scripts/train.py       | 2 ++
 src/mednet/libs/segmentation/scripts/train.py | 1 +
 3 files changed, 4 insertions(+)

diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py
index bd093251..d38c3841 100644
--- a/src/mednet/libs/common/scripts/predict.py
+++ b/src/mednet/libs/common/scripts/predict.py
@@ -127,6 +127,7 @@ def setup_datamodule(
     parallel,
 ) -> None:  # numpydoc ignore=PR01
     """Configure and set up the datamodule."""
+    datamodule.batch_size = batch_size
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
 
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index b806609a..156d9865 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -275,11 +275,13 @@ def load_checkpoint(checkpoint_file, datamodule, model):
 def setup_datamodule(
     datamodule,
     model,
+    batch_size,
     drop_incomplete_batch,
     cache_samples,
     parallel,
 ) -> None:  # numpydoc ignore=PR01
     """Configure and set up the datamodule."""
+    datamodule.batch_size = batch_size
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
index 41467b89..fc726a14 100644
--- a/src/mednet/libs/segmentation/scripts/train.py
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -61,6 +61,7 @@ def train(
     setup_datamodule(
         datamodule,
         model,
+        batch_size,
         drop_incomplete_batch,
         cache_samples,
         parallel,
-- 
GitLab