From 8ce33d1db03ec0e49c516e54e43200e640c680c7 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 29 May 2024 16:07:20 +0200
Subject: [PATCH] [mednet] Fixes after rebase

---
 src/mednet/libs/classification/scripts/train.py    | 3 +--
 src/mednet/libs/common/scripts/predict.py          | 1 -
 src/mednet/libs/common/scripts/train.py            | 2 --
 src/mednet/libs/segmentation/scripts/experiment.py | 4 ++--
 src/mednet/libs/segmentation/scripts/train.py      | 7 +++----
 5 files changed, 6 insertions(+), 11 deletions(-)

diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index 85a570d2..8483f667 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -76,7 +76,6 @@ def train(
         datamodule,
         model,
         batch_size,
-        batch_chunk_count,
         drop_incomplete_batch,
         cache_samples,
         parallel,
@@ -120,6 +119,6 @@ def train(
         max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         checkpoint=checkpoint_file,
     )
diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py
index 48fbf4d1..bd093251 100644
--- a/src/mednet/libs/common/scripts/predict.py
+++ b/src/mednet/libs/common/scripts/predict.py
@@ -127,7 +127,6 @@ def setup_datamodule(
     parallel,
 ) -> None:  # numpydoc ignore=PR01
     """Configure and set up the datamodule."""
-    datamodule.set_chunk_size(batch_size, 1)
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
 
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index db48174d..774d4853 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -276,13 +276,11 @@ def setup_datamodule(
     datamodule,
     model,
     batch_size,
-    batch_chunk_count,
     drop_incomplete_batch,
     cache_samples,
     parallel,
 ) -> None:  # numpydoc ignore=PR01
     """Configure and set up the datamodule."""
-    datamodule.set_chunk_size(batch_size, batch_chunk_count)
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py
index ea74b2dd..b909e26f 100644
--- a/src/mednet/libs/segmentation/scripts/experiment.py
+++ b/src/mednet/libs/segmentation/scripts/experiment.py
@@ -41,7 +41,7 @@ def experiment(
     output_folder,
     epochs,
     batch_size,
-    batch_chunk_count,
+    accumulate_grad_batches,
     drop_incomplete_batch,
     datamodule,
     validation_period,
@@ -79,7 +79,7 @@ def experiment(
         output_folder=train_output_folder,
         epochs=epochs,
         batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         drop_incomplete_batch=drop_incomplete_batch,
         datamodule=datamodule,
         validation_period=validation_period,
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
index 67477db3..fc726a14 100644
--- a/src/mednet/libs/segmentation/scripts/train.py
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -32,7 +32,7 @@ def train(
     output_folder,
     epochs,
     batch_size,
-    batch_chunk_count,
+    accumulate_grad_batches,
     drop_incomplete_batch,
     datamodule,
     validation_period,
@@ -62,7 +62,6 @@ def train(
         datamodule,
         model,
         batch_size,
-        batch_chunk_count,
         drop_incomplete_batch,
         cache_samples,
         parallel,
@@ -81,7 +80,7 @@ def train(
         device_manager,
         epochs,
         batch_size,
-        batch_chunk_count,
+        accumulate_grad_batches,
         drop_incomplete_batch,
         validation_period,
         cache_samples,
@@ -98,6 +97,6 @@ def train(
         max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
-        batch_chunk_count=batch_chunk_count,
+        accumulate_grad_batches=accumulate_grad_batches,
         checkpoint=checkpoint_file,
     )
-- 
GitLab