diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py index 85a570d28d0619c2f170d640129434ae26a4d250..8483f667eb71197a191345901e59be02ec5171cd 100644 --- a/src/mednet/libs/classification/scripts/train.py +++ b/src/mednet/libs/classification/scripts/train.py @@ -76,7 +76,6 @@ def train( datamodule, model, batch_size, - batch_chunk_count, drop_incomplete_batch, cache_samples, parallel, @@ -120,6 +119,6 @@ def train( max_epochs=epochs, output_folder=output_folder, monitoring_interval=monitoring_interval, - batch_chunk_count=batch_chunk_count, + accumulate_grad_batches=accumulate_grad_batches, checkpoint=checkpoint_file, ) diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py index 48fbf4d1c3fd76fc281c7df320700e0e6ff421c8..bd093251405092d407384f917971f9f2e26e0dd6 100644 --- a/src/mednet/libs/common/scripts/predict.py +++ b/src/mednet/libs/common/scripts/predict.py @@ -127,7 +127,6 @@ def setup_datamodule( parallel, ) -> None: # numpydoc ignore=PR01 """Configure and set up the datamodule.""" - datamodule.set_chunk_size(batch_size, 1) datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py index db48174d449fb42879c81ec086b6c0216f3fe1f4..774d485316e068af2192cb2d846a84f5612a6e88 100644 --- a/src/mednet/libs/common/scripts/train.py +++ b/src/mednet/libs/common/scripts/train.py @@ -276,13 +276,11 @@ def setup_datamodule( datamodule, model, batch_size, - batch_chunk_count, drop_incomplete_batch, cache_samples, parallel, ) -> None: # numpydoc ignore=PR01 """Configure and set up the datamodule.""" - datamodule.set_chunk_size(batch_size, batch_chunk_count) datamodule.drop_incomplete_batch = drop_incomplete_batch datamodule.cache_samples = cache_samples datamodule.parallel = parallel diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index ea74b2dd25d7a158b0faa3668e6229840bd47ff1..b909e26f3eb9ed28d01d2fbfc89080dc4e06f2f0 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -41,7 +41,7 @@ def experiment( output_folder, epochs, batch_size, - batch_chunk_count, + accumulate_grad_batches, drop_incomplete_batch, datamodule, validation_period, @@ -79,7 +79,7 @@ def experiment( output_folder=train_output_folder, epochs=epochs, batch_size=batch_size, - batch_chunk_count=batch_chunk_count, + accumulate_grad_batches=accumulate_grad_batches, drop_incomplete_batch=drop_incomplete_batch, datamodule=datamodule, validation_period=validation_period, diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py index 67477db3c3b047fcfcfe9df0c9996193cab5a748..fc726a14cbf5f9efbdcc66e5b300785b5a676814 100644 --- a/src/mednet/libs/segmentation/scripts/train.py +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -32,7 +32,7 @@ def train( output_folder, epochs, batch_size, - batch_chunk_count, + accumulate_grad_batches, drop_incomplete_batch, datamodule, validation_period, @@ -62,7 +62,6 @@ def train( datamodule, model, batch_size, - batch_chunk_count, drop_incomplete_batch, cache_samples, parallel, @@ -81,7 +80,7 @@ def train( device_manager, epochs, batch_size, - batch_chunk_count, + accumulate_grad_batches, drop_incomplete_batch, validation_period, cache_samples, @@ -98,6 +97,6 @@ def train( max_epochs=epochs, output_folder=output_folder, monitoring_interval=monitoring_interval, - batch_chunk_count=batch_chunk_count, + accumulate_grad_batches=accumulate_grad_batches, checkpoint=checkpoint_file, )