From d75f924dcf6bf11dc109572e002cb51ca6cf09a4 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 1 Jul 2024 10:34:08 +0200
Subject: [PATCH] [segmentation] Apply augmentation transforms

---
 pyproject.toml                                            | 7 +++++++
 src/mednet/libs/segmentation/models/segmentation_model.py | 8 ++++----
 src/mednet/libs/segmentation/scripts/experiment.py        | 2 ++
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 34b75b6a..525d3a9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -429,6 +429,13 @@ lwnet = "mednet.libs.segmentation.config.models.lwnet"
 m2unet = "mednet.libs.segmentation.config.models.m2unet"
 unet = "mednet.libs.segmentation.config.models.unet"
 
+affine = "mednet.libs.common.config.augmentations.affine"
+elastic = "mednet.libs.common.config.augmentations.elastic"
+hflip = "mednet.libs.common.config.augmentations.hflip"
+jitter = "mednet.libs.common.config.augmentations.jitter"
+hflip-affine = "mednet.libs.common.config.augmentations.hflip_affine"
+hflip-jitter-affine = "mednet.libs.common.config.augmentations.hflip_jitter_affine"
+
 # chase-db1 - retinography
 chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator"
 chasedb1-2nd = "mednet.libs.segmentation.config.data.chasedb1.second_annotator"
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index 834af45f..51599b3e 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -90,11 +90,11 @@ class SegmentationModel(Model):
         self.normalizer = make_z_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
+        images = self.augmentation_transforms(batch[0]["image"])
+        ground_truths = self.augmentation_transforms(batch[0]["target"])
+        masks = self.augmentation_transforms(batch[0]["mask"])
 
-        outputs = self(self._augmentation_transforms(images))
+        outputs = self(images)
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py
index 5960e139..6a53edf7 100644
--- a/src/mednet/libs/segmentation/scripts/experiment.py
+++ b/src/mednet/libs/segmentation/scripts/experiment.py
@@ -50,6 +50,7 @@ def experiment(
     seed,
     parallel,
     monitoring_interval,
+    augmentations,
     **_,
 ):  # numpydoc ignore=PR01
     r"""Run a complete experiment, from training, to prediction and evaluation.
@@ -85,6 +86,7 @@ def experiment(
         seed=seed,
         parallel=parallel,
         monitoring_interval=monitoring_interval,
+        augmentations=augmentations,
     )
     train_stop_timestamp = datetime.now()
 
-- 
GitLab