Skip to content
Snippets Groups Projects
Commit 37a4e70f authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Merge branch 'ground_truth_crop_transform' into 'master'

Ground Truth Box Crop Transform

See merge request biosignal/software/deepdraw!66
parents f4c4a140 5c13431e
Branches
Tags
1 merge request!66Ground Truth Box Crop Transform
Pipeline #71313 passed
......@@ -171,16 +171,20 @@ combined-disc = "deepdraw.binseg.configs.datasets.combined.od"
combined-cup = "deepdraw.binseg.configs.datasets.combined.oc"
# montgomery county - cxr
montgomery = "deepdraw.binseg.configs.datasets.montgomery.default"
montgomery-gt = "deepdraw.binseg.configs.datasets.montgomery.default_gtcrop"
montgomery-xtest = "deepdraw.binseg.configs.datasets.montgomery.xtest"
# shenzhen - cxr
shenzhen = "deepdraw.binseg.configs.datasets.shenzhen.default"
shenzhen-small = "deepdraw.binseg.configs.datasets.shenzhen.default_256"
shenzhen-gt = "deepdraw.binseg.configs.datasets.shenzhen.default_gtcrop"
shenzhen-xtest = "deepdraw.binseg.configs.datasets.shenzhen.xtest"
# jsrt - cxr
jsrt = "deepdraw.binseg.configs.datasets.jsrt.default"
jsrt-gt = "deepdraw.binseg.configs.datasets.jsrt.default_gtcrop"
jsrt-xtest = "deepdraw.binseg.configs.datasets.jsrt.xtest"
# cxr8 - cxr
cxr8 = "deepdraw.binseg.configs.datasets.cxr8.default"
cxr8-gt = "deepdraw.binseg.configs.datasets.cxr8.default_gtcrop"
cxr8-xtest = "deepdraw.binseg.configs.datasets.cxr8.xtest"
[project.entry-points."detect.config"]
......
......@@ -55,3 +55,51 @@ def _maker_augmented(protocol, n):
)
],
)
def _maker_augmented_gt_box(protocol, n):
from .....common.data.transforms import ColorJitter as _jitter
from .....common.data.transforms import Compose as _compose
from .....common.data.transforms import GaussianBlur as _blur
from .....common.data.transforms import GroundTruthCrop as _gtcrop
from .....common.data.transforms import RandomHorizontalFlip as _hflip
from .....common.data.transforms import RandomRotation as _rotation
from .....common.data.transforms import Resize as _resize
from ....data.cxr8 import dataset as raw
from .. import make_subset
def mk_aug_subset(subsets, train_transforms, all_transforms):
retval = {}
for key in subsets.keys():
retval[key] = make_subset(subsets[key], transforms=all_transforms)
if key == "train":
retval["__train__"] = make_subset(
subsets[key],
transforms=train_transforms,
)
else:
if key == "validation":
retval["__valid__"] = retval[key]
if ("__train__" in retval) and ("__valid__" not in retval):
retval["__valid__"] = retval["__train__"]
return retval
return mk_aug_subset(
subsets=raw.subsets(protocol),
all_transforms=[_gtcrop(extra_area=0.2), _resize((n, n))],
train_transforms=[
_compose(
[
_gtcrop(extra_area=0.2),
_resize((n, n)),
_rotation(degrees=15, p=0.5),
_hflip(p=0.5),
_jitter(p=0.5),
_blur(p=0.5),
]
)
],
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""CXR8 Dataset (default protocol)
* Split reference: [GAAL-2020]_
* Configuration resolution: 256 x 256
* See :py:mod:`deepdraw.binseg.data.cxr8` for dataset details
"""
from . import _maker_augmented_gt_box
dataset = _maker_augmented_gt_box("default", 256)
......@@ -55,3 +55,51 @@ def _maker_augmented(protocol):
)
],
)
def _maker_augmented_gt_box(protocol):
from .....common.data.transforms import ColorJitter as _jitter
from .....common.data.transforms import Compose as _compose
from .....common.data.transforms import GaussianBlur as _blur
from .....common.data.transforms import GroundTruthCrop as _gtcrop
from .....common.data.transforms import RandomHorizontalFlip as _hflip
from .....common.data.transforms import RandomRotation as _rotation
from .....common.data.transforms import Resize as _resize
from ....data.jsrt import dataset as raw
from .. import make_subset
def mk_aug_subset(subsets, train_transforms, all_transforms):
retval = {}
for key in subsets.keys():
retval[key] = make_subset(subsets[key], transforms=all_transforms)
if key == "train":
retval["__train__"] = make_subset(
subsets[key],
transforms=train_transforms,
)
else:
if key == "validation":
retval["__valid__"] = retval[key]
if ("__train__" in retval) and ("__valid__" not in retval):
retval["__valid__"] = retval["__train__"]
return retval
return mk_aug_subset(
subsets=raw.subsets(protocol),
all_transforms=[_gtcrop(extra_area=0.2), _resize((256, 256))],
train_transforms=[
_compose(
[
_gtcrop(extra_area=0.2),
_resize((256, 256)),
_rotation(degrees=15, p=0.5),
_hflip(p=0.5),
_jitter(p=0.5),
_blur(p=0.5),
]
)
],
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Japanese Society of Radiological Technology dataset for Lung Segmentation
(default protocol)
* Split reference: [GAAL-2020]_
* Configuration resolution: 256 x 256
* See :py:mod:`deepdraw.binseg.data.jsrt` for dataset details
"""
from . import _maker_augmented_gt_box
dataset = _maker_augmented_gt_box("default")
......@@ -57,3 +57,57 @@ def _maker_augmented(protocol):
)
],
)
def _maker_augmented_gt_box(protocol):
from .....common.data.transforms import ColorJitter as _jitter
from .....common.data.transforms import Compose as _compose
from .....common.data.transforms import GaussianBlur as _blur
from .....common.data.transforms import GroundTruthCrop as _gtcrop
from .....common.data.transforms import RandomHorizontalFlip as _hflip
from .....common.data.transforms import RandomRotation as _rotation
from .....common.data.transforms import Resize as _resize
from .....common.data.transforms import ShrinkIntoSquare as _shrinkintosq
from ....data.montgomery import dataset as raw
from .. import make_subset
def mk_aug_subset(subsets, train_transforms, all_transforms):
retval = {}
for key in subsets.keys():
retval[key] = make_subset(subsets[key], transforms=all_transforms)
if key == "train":
retval["__train__"] = make_subset(
subsets[key],
transforms=train_transforms,
)
else:
if key == "validation":
retval["__valid__"] = retval[key]
if ("__train__" in retval) and ("__valid__" not in retval):
retval["__valid__"] = retval["__train__"]
return retval
return mk_aug_subset(
subsets=raw.subsets(protocol),
all_transforms=[
_shrinkintosq(),
_gtcrop(extra_area=0.2),
_resize((256, 256)),
],
train_transforms=[
_compose(
[
_shrinkintosq(),
_gtcrop(extra_area=0.2),
_resize((256, 256)),
_rotation(degrees=15, p=0.5),
_hflip(p=0.5),
_jitter(p=0.5),
_blur(p=0.5),
]
)
],
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Montgomery County dataset for Lung Segmentation (default protocol)
* Split reference: [GAAL-2020]_
* Configuration resolution: 256 x 256
* See :py:mod:`deepdraw.binseg.data.montgomery` for dataset details
"""
from . import _maker_augmented_gt_box
dataset = _maker_augmented_gt_box("default")
......@@ -57,3 +57,57 @@ def _maker_augmented(protocol, n):
)
],
)
def _maker_augmented_gt_box(protocol, n):
from .....common.data.transforms import ColorJitter as _jitter
from .....common.data.transforms import Compose as _compose
from .....common.data.transforms import GaussianBlur as _blur
from .....common.data.transforms import GroundTruthCrop as _gtcrop
from .....common.data.transforms import RandomHorizontalFlip as _hflip
from .....common.data.transforms import RandomRotation as _rotation
from .....common.data.transforms import Resize as _resize
from .....common.data.transforms import ShrinkIntoSquare as _shrinkintosq
from ....data.shenzhen import dataset as raw
from .. import make_subset
def mk_aug_subset(subsets, train_transforms, all_transforms):
retval = {}
for key in subsets.keys():
retval[key] = make_subset(subsets[key], transforms=all_transforms)
if key == "train":
retval["__train__"] = make_subset(
subsets[key],
transforms=train_transforms,
)
else:
if key == "validation":
retval["__valid__"] = retval[key]
if ("__train__" in retval) and ("__valid__" not in retval):
retval["__valid__"] = retval["__train__"]
return retval
return mk_aug_subset(
subsets=raw.subsets(protocol),
all_transforms=[
_shrinkintosq(),
_gtcrop(extra_area=0.2),
_resize((n, n)),
],
train_transforms=[
_compose(
[
_shrinkintosq(),
_gtcrop(extra_area=0.2),
_resize((n, n)),
_rotation(degrees=15, p=0.5),
_hflip(p=0.5),
_jitter(p=0.5),
_blur(p=0.5),
]
)
],
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for Lung Segmentation (default protocol)
* Split reference: [GAAL-2020]_
* Configuration resolution: 256 x 256
* See :py:mod:`deepdraw.binseg.data.shenzhen` for dataset details
"""
from . import _maker_augmented_gt_box
dataset = _maker_augmented_gt_box("default", 256)
......@@ -430,3 +430,110 @@ class GetBoundingBox:
target["labels"] = labels
return [args[self.image], target]
class GroundTruthCrop:
"""Crop image in a square keeping only the area with the ground truth.
This transform can crop all images given a ground-truth mask as reference.
Notice that the crop will result in a square image at the end, which means
that it will keep the bigger dimension and adjust the smaller one to fit
into a square. There's an option to add extra area around the gt bounding
box. If resulting dimensions are larger than the boundaries of the image,
minimal padding will be done to keep the image in a square shape.
Parameters
----------
reference : :py:class:`int`, Optional
Which reference part of the sample to use for getting coordinates.
If not set, use the second object on the sample (typically, the mask).
extra_area : :py:class:`float`, Optional
Multiplier that will add the extra area around the ground-truth
bounding box. Example: 0.1 will result in a crop with dimensions of
the largest side increased by 10%. If not set, the default will be 0
(only the ground-truth box).
"""
def __init__(self, reference=1, extra_area=0.0):
self.reference = reference
self.extra_area = extra_area
def __call__(self, *args):
ref = args[self.reference]
max_w, max_h = ref.size
where = numpy.where(ref)
y0 = numpy.min(where[0])
y1 = numpy.max(where[0])
x0 = numpy.min(where[1])
x1 = numpy.max(where[1])
w = x1 - x0
h = y1 - y0
extra_x = self.extra_area * w / 2
extra_y = self.extra_area * h / 2
new_w = (1 + self.extra_area) * w
new_h = (1 + self.extra_area) * h
diff = abs(new_w - new_h) / 2
if new_w == new_h:
x0_new = x0.copy() - extra_x
x1_new = x1.copy() + extra_x
y0_new = y0.copy() - extra_y
y1_new = y1.copy() + extra_y
elif new_w > new_h:
x0_new = x0.copy() - extra_x
x1_new = x1.copy() + extra_x
y0_new = y0.copy() - extra_y - diff
y1_new = y1.copy() + extra_y + diff
else:
x0_new = x0.copy() - extra_x - diff
x1_new = x1.copy() + extra_x + diff
y0_new = y0.copy() - extra_y
y1_new = y1.copy() + extra_y
border = (x0_new, y0_new, max_w - x1_new, max_h - y1_new)
def _expand_img(
pil_img, background_color, x0_pad=0, x1_pad=0, y0_pad=0, y1_pad=0
):
width = pil_img.size[0] + x0_pad + x1_pad
height = pil_img.size[1] + y0_pad + y1_pad
result = PIL.Image.new(
pil_img.mode, (width, height), background_color
)
result.paste(pil_img, (x0_pad, y0_pad))
return result
def _black_background(i):
return (0, 0, 0) if i.mode == "RGB" else 0
d_x0 = numpy.rint(max([0 - x0_new, 0])).astype(int)
d_y0 = numpy.rint(max([0 - y0_new, 0])).astype(int)
d_x1 = numpy.rint(max([x1_new - max_w, 0])).astype(int)
d_y1 = numpy.rint(max([y1_new - max_h, 0])).astype(int)
new_args = [
_expand_img(
k,
_black_background(k),
x0_pad=d_x0,
x1_pad=d_x1,
y0_pad=d_y0,
y1_pad=d_y1,
)
for k in args
]
new_args = [PIL.ImageOps.crop(k, border) for k in new_args]
return new_args
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment