Skip to content
Snippets Groups Projects
Commit a4bc1166 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation] Ensure crop to mask is applied correclty on databases

parent 1e726a8c
No related branches found
No related tags found
1 merge request!46Create common library
Showing
with 17 additions and 16 deletions
......@@ -429,7 +429,7 @@ driu-pix = "mednet.libs.segmentation.config.models.driu_pix"
hed = "mednet.libs.segmentation.config.models.hed"
lwnet = "mednet.libs.segmentation.config.models.lwnet"
m2unet = "mednet.libs.segmentation.config.models.m2unet"
resunet = "mednet.libs.segmentation.config.models.resunet"
#resunet = "mednet.libs.segmentation.config.models.resunet"
unet = "mednet.libs.segmentation.config.models.unet"
# chase-db1 - retinography
......
......@@ -11,6 +11,7 @@ import pkg_resources
from mednet.libs.common.data.datamodule import CachingDataModule
from mednet.libs.common.data.split import make_split
from mednet.libs.common.data.typing import Sample
from mednet.libs.common.models.transforms import crop_image_to_mask
from mednet.libs.segmentation.data.typing import (
SegmentationRawDataLoader as _SegmentationRawDataLoader,
)
......@@ -54,21 +55,20 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
The sample representation.
"""
image = PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB")
tensor = tv_tensors.Image(to_tensor(image))
target = tv_tensors.Image(
to_tensor(
PIL.Image.open(self.datadir / sample[1]).convert(mode="1", dither=None)
)
image = to_tensor(PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB"))
target = to_tensor(
PIL.Image.open(self.datadir / sample[1]).convert(mode="1", dither=None)
)
mask = tv_tensors.Mask(
to_tensor(
PIL.Image.open(self._pkg_path / sample[2]).convert(
mode="1", dither=None
)
)
mask = to_tensor(
PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
)
tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
target = tv_tensors.Image(crop_image_to_mask(target, mask))
mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type]
......
......@@ -11,6 +11,7 @@ import pkg_resources
from mednet.libs.common.data.datamodule import CachingDataModule
from mednet.libs.common.data.split import make_split
from mednet.libs.common.data.typing import Sample
from mednet.libs.common.models.transforms import crop_image_to_mask
from mednet.libs.segmentation.data.typing import (
SegmentationRawDataLoader as _SegmentationRawDataLoader,
)
......@@ -87,9 +88,9 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
)
tensor = tv_tensors.Image(image)
target = tv_tensors.Image(target)
mask = tv_tensors.Mask(mask)
tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
target = tv_tensors.Image(crop_image_to_mask(target, mask))
mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment