diff --git a/pyproject.toml b/pyproject.toml index 34b75b6a14e86ff59cc1f5d9c316a9b3ebac6578..525d3a9b3e99d26ab95c29c5cb9a766b864780a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -429,6 +429,13 @@ lwnet = "mednet.libs.segmentation.config.models.lwnet" m2unet = "mednet.libs.segmentation.config.models.m2unet" unet = "mednet.libs.segmentation.config.models.unet" +affine = "mednet.libs.common.config.augmentations.affine" +elastic = "mednet.libs.common.config.augmentations.elastic" +hflip = "mednet.libs.common.config.augmentations.hflip" +jitter = "mednet.libs.common.config.augmentations.jitter" +hflip-affine = "mednet.libs.common.config.augmentations.hflip_affine" +hflip-jitter-affine = "mednet.libs.common.config.augmentations.hflip_jitter_affine" + # chase-db1 - retinography chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator" chasedb1-2nd = "mednet.libs.segmentation.config.data.chasedb1.second_annotator" diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py index 834af45f087e4cc27d9089fc67efe51a6041e5a5..51599b3ef91ada9428cf9f796c5404dff2d58bda 100644 --- a/src/mednet/libs/segmentation/models/segmentation_model.py +++ b/src/mednet/libs/segmentation/models/segmentation_model.py @@ -90,11 +90,11 @@ class SegmentationModel(Model): self.normalizer = make_z_normalizer(dataloader) def training_step(self, batch, batch_idx): - images = batch[0]["image"] - ground_truths = batch[0]["target"] - masks = batch[0]["mask"] + images = self.augmentation_transforms(batch[0]["image"]) + ground_truths = self.augmentation_transforms(batch[0]["target"]) + masks = self.augmentation_transforms(batch[0]["mask"]) - outputs = self(self._augmentation_transforms(images)) + outputs = self(images) return self._train_loss(outputs, ground_truths, masks) def validation_step(self, batch, batch_idx): diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index 5960e1394e9fc819d9b44e90055d7d1dfd74f0d5..6a53edf76ac94a5246168d2c0009f6649b9c6cc9 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -50,6 +50,7 @@ def experiment( seed, parallel, monitoring_interval, + augmentations, **_, ): # numpydoc ignore=PR01 r"""Run a complete experiment, from training, to prediction and evaluation. @@ -85,6 +86,7 @@ def experiment( seed=seed, parallel=parallel, monitoring_interval=monitoring_interval, + augmentations=augmentations, ) train_stop_timestamp = datetime.now()