From 7e204e950c4889597dbed1751e7bf9f7425045e2 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 5 Jun 2024 10:44:32 +0200 Subject: [PATCH] [scripts] Set datamodule batch size --- src/mednet/libs/common/scripts/predict.py | 1 + src/mednet/libs/common/scripts/train.py | 2 ++ src/mednet/libs/segmentation/scripts/train.py | 1 + 3 files changed, 4 insertions(+) diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py index bd093251..d38c3841 100644 --- a/src/mednet/libs/common/scripts/predict.py +++ b/src/mednet/libs/common/scripts/predict.py @@ -127,6 +127,7 @@ def setup_datamodule( parallel, ) -> None: # numpydoc ignore=PR01 """Configure and set up the datamodule.""" + datamodule.batch_size = batch_size datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py index b806609a..156d9865 100644 --- a/src/mednet/libs/common/scripts/train.py +++ b/src/mednet/libs/common/scripts/train.py @@ -275,11 +275,13 @@ def load_checkpoint(checkpoint_file, datamodule, model): def setup_datamodule( datamodule, model, + batch_size, drop_incomplete_batch, cache_samples, parallel, ) -> None: # numpydoc ignore=PR01 """Configure and set up the datamodule.""" + datamodule.batch_size = batch_size datamodule.drop_incomplete_batch = drop_incomplete_batch datamodule.cache_samples = cache_samples datamodule.parallel = parallel diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py index 41467b89..fc726a14 100644 --- a/src/mednet/libs/segmentation/scripts/train.py +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -61,6 +61,7 @@ def train( setup_datamodule( datamodule, model, + batch_size, drop_incomplete_batch, cache_samples, parallel, -- GitLab