diff --git a/pyproject.toml b/pyproject.toml
index 34b75b6a14e86ff59cc1f5d9c316a9b3ebac6578..525d3a9b3e99d26ab95c29c5cb9a766b864780a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -429,6 +429,13 @@ lwnet = "mednet.libs.segmentation.config.models.lwnet"
 m2unet = "mednet.libs.segmentation.config.models.m2unet"
 unet = "mednet.libs.segmentation.config.models.unet"
 
+affine = "mednet.libs.common.config.augmentations.affine"
+elastic = "mednet.libs.common.config.augmentations.elastic"
+hflip = "mednet.libs.common.config.augmentations.hflip"
+jitter = "mednet.libs.common.config.augmentations.jitter"
+hflip-affine = "mednet.libs.common.config.augmentations.hflip_affine"
+hflip-jitter-affine = "mednet.libs.common.config.augmentations.hflip_jitter_affine"
+
 # chase-db1 - retinography
 chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator"
 chasedb1-2nd = "mednet.libs.segmentation.config.data.chasedb1.second_annotator"
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index 834af45f087e4cc27d9089fc67efe51a6041e5a5..51599b3ef91ada9428cf9f796c5404dff2d58bda 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -90,11 +90,11 @@ class SegmentationModel(Model):
         self.normalizer = make_z_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
+        images = self.augmentation_transforms(batch[0]["image"])
+        ground_truths = self.augmentation_transforms(batch[0]["target"])
+        masks = self.augmentation_transforms(batch[0]["mask"])
 
-        outputs = self(self._augmentation_transforms(images))
+        outputs = self(images)
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py
index 5960e1394e9fc819d9b44e90055d7d1dfd74f0d5..6a53edf76ac94a5246168d2c0009f6649b9c6cc9 100644
--- a/src/mednet/libs/segmentation/scripts/experiment.py
+++ b/src/mednet/libs/segmentation/scripts/experiment.py
@@ -50,6 +50,7 @@ def experiment(
     seed,
     parallel,
     monitoring_interval,
+    augmentations,
     **_,
 ):  # numpydoc ignore=PR01
     r"""Run a complete experiment, from training, to prediction and evaluation.
@@ -85,6 +86,7 @@ def experiment(
         seed=seed,
         parallel=parallel,
         monitoring_interval=monitoring_interval,
+        augmentations=augmentations,
     )
     train_stop_timestamp = datetime.now()