diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py index bd093251405092d407384f917971f9f2e26e0dd6..d38c38411257c95a293d491bb178470d4560da64 100644 --- a/src/mednet/libs/common/scripts/predict.py +++ b/src/mednet/libs/common/scripts/predict.py @@ -127,6 +127,7 @@ def setup_datamodule( parallel, ) -> None: # numpydoc ignore=PR01 """Configure and set up the datamodule.""" + datamodule.batch_size = batch_size datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py index b806609a84634559f6cec6a28cf6ef69b76b21dd..156d98654172bd4e85d8b487aed2efb460da843c 100644 --- a/src/mednet/libs/common/scripts/train.py +++ b/src/mednet/libs/common/scripts/train.py @@ -275,11 +275,13 @@ def load_checkpoint(checkpoint_file, datamodule, model): def setup_datamodule( datamodule, model, + batch_size, drop_incomplete_batch, cache_samples, parallel, ) -> None: # numpydoc ignore=PR01 """Configure and set up the datamodule.""" + datamodule.batch_size = batch_size datamodule.drop_incomplete_batch = drop_incomplete_batch datamodule.cache_samples = cache_samples datamodule.parallel = parallel diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py index 41467b896c63136d0b6905151645d63d5346c755..fc726a14cbf5f9efbdcc66e5b300785b5a676814 100644 --- a/src/mednet/libs/segmentation/scripts/train.py +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -61,6 +61,7 @@ def train( setup_datamodule( datamodule, model, + batch_size, drop_incomplete_batch, cache_samples, parallel,