From 8ce33d1db03ec0e49c516e54e43200e640c680c7 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 29 May 2024 16:07:20 +0200 Subject: [PATCH] [mednet] Fixes after rebase --- src/mednet/libs/classification/scripts/train.py | 3 +-- src/mednet/libs/common/scripts/predict.py | 1 - src/mednet/libs/common/scripts/train.py | 2 -- src/mednet/libs/segmentation/scripts/experiment.py | 4 ++-- src/mednet/libs/segmentation/scripts/train.py | 7 +++---- 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py index 85a570d2..8483f667 100644 --- a/src/mednet/libs/classification/scripts/train.py +++ b/src/mednet/libs/classification/scripts/train.py @@ -76,7 +76,6 @@ def train( datamodule, model, batch_size, - batch_chunk_count, drop_incomplete_batch, cache_samples, parallel, @@ -120,6 +119,6 @@ def train( max_epochs=epochs, output_folder=output_folder, monitoring_interval=monitoring_interval, - batch_chunk_count=batch_chunk_count, + accumulate_grad_batches=accumulate_grad_batches, checkpoint=checkpoint_file, ) diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py index 48fbf4d1..bd093251 100644 --- a/src/mednet/libs/common/scripts/predict.py +++ b/src/mednet/libs/common/scripts/predict.py @@ -127,7 +127,6 @@ def setup_datamodule( parallel, ) -> None: # numpydoc ignore=PR01 """Configure and set up the datamodule.""" - datamodule.set_chunk_size(batch_size, 1) datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py index db48174d..774d4853 100644 --- a/src/mednet/libs/common/scripts/train.py +++ b/src/mednet/libs/common/scripts/train.py @@ -276,13 +276,11 @@ def setup_datamodule( datamodule, model, batch_size, - batch_chunk_count, drop_incomplete_batch, cache_samples, parallel, ) -> None: # numpydoc ignore=PR01 """Configure and set up the datamodule.""" - datamodule.set_chunk_size(batch_size, batch_chunk_count) datamodule.drop_incomplete_batch = drop_incomplete_batch datamodule.cache_samples = cache_samples datamodule.parallel = parallel diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index ea74b2dd..b909e26f 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -41,7 +41,7 @@ def experiment( output_folder, epochs, batch_size, - batch_chunk_count, + accumulate_grad_batches, drop_incomplete_batch, datamodule, validation_period, @@ -79,7 +79,7 @@ def experiment( output_folder=train_output_folder, epochs=epochs, batch_size=batch_size, - batch_chunk_count=batch_chunk_count, + accumulate_grad_batches=accumulate_grad_batches, drop_incomplete_batch=drop_incomplete_batch, datamodule=datamodule, validation_period=validation_period, diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py index 67477db3..fc726a14 100644 --- a/src/mednet/libs/segmentation/scripts/train.py +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -32,7 +32,7 @@ def train( output_folder, epochs, batch_size, - batch_chunk_count, + accumulate_grad_batches, drop_incomplete_batch, datamodule, validation_period, @@ -62,7 +62,6 @@ def train( datamodule, model, batch_size, - batch_chunk_count, drop_incomplete_batch, cache_samples, parallel, @@ -81,7 +80,7 @@ def train( device_manager, epochs, batch_size, - batch_chunk_count, + accumulate_grad_batches, drop_incomplete_batch, validation_period, cache_samples, @@ -98,6 +97,6 @@ def train( max_epochs=epochs, output_folder=output_folder, monitoring_interval=monitoring_interval, - batch_chunk_count=batch_chunk_count, + accumulate_grad_batches=accumulate_grad_batches, checkpoint=checkpoint_file, ) -- GitLab