From d75f924dcf6bf11dc109572e002cb51ca6cf09a4 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 1 Jul 2024 10:34:08 +0200 Subject: [PATCH] [segmentation] Apply augmentation transforms --- pyproject.toml | 7 +++++++ src/mednet/libs/segmentation/models/segmentation_model.py | 8 ++++---- src/mednet/libs/segmentation/scripts/experiment.py | 2 ++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 34b75b6a..525d3a9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -429,6 +429,13 @@ lwnet = "mednet.libs.segmentation.config.models.lwnet" m2unet = "mednet.libs.segmentation.config.models.m2unet" unet = "mednet.libs.segmentation.config.models.unet" +affine = "mednet.libs.common.config.augmentations.affine" +elastic = "mednet.libs.common.config.augmentations.elastic" +hflip = "mednet.libs.common.config.augmentations.hflip" +jitter = "mednet.libs.common.config.augmentations.jitter" +hflip-affine = "mednet.libs.common.config.augmentations.hflip_affine" +hflip-jitter-affine = "mednet.libs.common.config.augmentations.hflip_jitter_affine" + # chase-db1 - retinography chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator" chasedb1-2nd = "mednet.libs.segmentation.config.data.chasedb1.second_annotator" diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py index 834af45f..51599b3e 100644 --- a/src/mednet/libs/segmentation/models/segmentation_model.py +++ b/src/mednet/libs/segmentation/models/segmentation_model.py @@ -90,11 +90,11 @@ class SegmentationModel(Model): self.normalizer = make_z_normalizer(dataloader) def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] + images = self.augmentation_transforms(batch[0]["image"]) + ground_truths = self.augmentation_transforms(batch[0]["target"]) + masks = self.augmentation_transforms(batch[0]["mask"]) - outputs = self(self._augmentation_transforms(images)) + outputs = self(images) return self._train_loss(outputs, ground_truths, masks) def validation_step(self, batch, batch_idx): diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index 5960e139..6a53edf7 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -50,6 +50,7 @@ def experiment( seed, parallel, monitoring_interval, + augmentations, **_, ): # numpydoc ignore=PR01 r"""Run a complete experiment, from training, to prediction and evaluation. @@ -85,6 +86,7 @@ def experiment( seed=seed, parallel=parallel, monitoring_interval=monitoring_interval, + augmentations=augmentations, ) train_stop_timestamp = datetime.now() -- GitLab