Skip to content
Snippets Groups Projects
Commit d75f924d authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation] Apply augmentation transforms

parent e4a0b7e2
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -429,6 +429,13 @@ lwnet = "mednet.libs.segmentation.config.models.lwnet" ...@@ -429,6 +429,13 @@ lwnet = "mednet.libs.segmentation.config.models.lwnet"
m2unet = "mednet.libs.segmentation.config.models.m2unet" m2unet = "mednet.libs.segmentation.config.models.m2unet"
unet = "mednet.libs.segmentation.config.models.unet" unet = "mednet.libs.segmentation.config.models.unet"
affine = "mednet.libs.common.config.augmentations.affine"
elastic = "mednet.libs.common.config.augmentations.elastic"
hflip = "mednet.libs.common.config.augmentations.hflip"
jitter = "mednet.libs.common.config.augmentations.jitter"
hflip-affine = "mednet.libs.common.config.augmentations.hflip_affine"
hflip-jitter-affine = "mednet.libs.common.config.augmentations.hflip_jitter_affine"
# chase-db1 - retinography # chase-db1 - retinography
chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator" chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator"
chasedb1-2nd = "mednet.libs.segmentation.config.data.chasedb1.second_annotator" chasedb1-2nd = "mednet.libs.segmentation.config.data.chasedb1.second_annotator"
......
...@@ -90,11 +90,11 @@ class SegmentationModel(Model): ...@@ -90,11 +90,11 @@ class SegmentationModel(Model):
self.normalizer = make_z_normalizer(dataloader) self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0]["image"] images = self.augmentation_transforms(batch[0]["image"])
ground_truths = batch[0]["target"] ground_truths = self.augmentation_transforms(batch[0]["target"])
masks = batch[0]["mask"] masks = self.augmentation_transforms(batch[0]["mask"])
outputs = self(self._augmentation_transforms(images)) outputs = self(images)
return self._train_loss(outputs, ground_truths, masks) return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
......
...@@ -50,6 +50,7 @@ def experiment( ...@@ -50,6 +50,7 @@ def experiment(
seed, seed,
parallel, parallel,
monitoring_interval, monitoring_interval,
augmentations,
**_, **_,
): # numpydoc ignore=PR01 ): # numpydoc ignore=PR01
r"""Run a complete experiment, from training, to prediction and evaluation. r"""Run a complete experiment, from training, to prediction and evaluation.
...@@ -85,6 +86,7 @@ def experiment( ...@@ -85,6 +86,7 @@ def experiment(
seed=seed, seed=seed,
parallel=parallel, parallel=parallel,
monitoring_interval=monitoring_interval, monitoring_interval=monitoring_interval,
augmentations=augmentations,
) )
train_stop_timestamp = datetime.now() train_stop_timestamp = datetime.now()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment