Skip to content
Snippets Groups Projects
Commit 32ab5c65 authored by Driss KHALIL's avatar Driss KHALIL
Browse files

[Rebase] rebase merged request

parent 59180f52
No related branches found
No related tags found
1 merge request!31[Rebase] rebase merged request
Pipeline #53059 passed
......@@ -5,7 +5,7 @@
def _maker(protocol):
from ....data.drhagis import dataset as raw
from ....data.transforms import Resize
from ....data.transforms import Resize, ResizeCrop
from .. import make_dataset as mk
return mk(raw.subsets(protocol), [Resize((1760, 1760))])
return mk(raw.subsets(protocol), [ResizeCrop(), Resize((1760, 1760))])
......@@ -18,6 +18,8 @@ import PIL.Image
import torchvision.transforms
import torchvision.transforms.functional
from PIL import ImageOps
class TupleMixin:
"""Adds support to work with tuples of objects to torchvision transforms"""
......@@ -266,3 +268,57 @@ class ColorJitter(torchvision.transforms.ColorJitter):
def __repr__(self):
retval = super(ColorJitter, self).__repr__()
return retval.replace("(", f"(p={self.p},", 1)
def _expand2square(pil_img, background_color):
"""
Function that pad the minimum between the height and the width to fit a square
"""
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = PIL.Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = PIL.Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
class ResizeCrop:
"""
Crop all the images by removing the black pixels in the width and height until it finds a non-black pixel.
"""
def __call__(self, *args):
img = args[0]
label = args[1]
mask = args[2]
mask_data = numpy.asarray(mask)
wid = numpy.sum(mask_data, axis=0)
heig = numpy.sum(mask_data, axis=1)
crop_left, crop_right = (wid != 0).argmax(axis=0), (
wid[::-1] != 0
).argmax(axis=0)
crop_up, crop_down = (heig != 0).argmax(axis=0), (
heig[::-1] != 0
).argmax(axis=0)
border = (crop_left, crop_up, crop_right, crop_down)
new_mask = ImageOps.crop(mask, border)
new_img = ImageOps.crop(img, border)
new_label = ImageOps.crop(label, border)
new_img = _expand2square(new_img, (0, 0, 0))
new_label = _expand2square(new_label, 0)
new_mask = _expand2square(new_mask, 0)
args = (new_img, new_label, new_mask)
return args
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment