Commit 4413f921 authored by André Anjos

Merge branch 'DrhagisUpdates2' into 'master'

[Rebase] rebase merged request

See merge request bob/bob.ip.binseg!31
parents 59180f52 3d402727
Pipeline #53234 passed
@@ -5,7 +5,7 @@
 def _maker(protocol):
     from ....data.drhagis import dataset as raw
-    from ....data.transforms import Resize
+    from ....data.transforms import Resize, ResizeCrop
     from .. import make_dataset as mk
-    return mk(raw.subsets(protocol), [Resize((1760, 1760))])
+    return mk(raw.subsets(protocol), [ResizeCrop(), Resize((1760, 1760))])
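
As an illustrative aside (not part of this diff), a minimal sketch of what the changed _maker chain now does, on synthetic PIL inputs instead of the DR HAGIS raw dataset; the absolute import path and the tuple-calling convention of Resize are assumptions inferred from the relative imports above:

# Hedged sketch: compose the transforms the updated _maker uses, assuming
# bob.ip.binseg tuple transforms accept (image, label, mask) and return a
# sequence of the same length.  All sizes below are invented for the example.
import PIL.Image
import PIL.ImageDraw

from bob.ip.binseg.data.transforms import Resize, ResizeCrop  # assumed path

img = PIL.Image.new("RGB", (700, 650), "gray")
label = PIL.Image.new("L", (700, 650), "black")
mask = PIL.Image.new("L", (700, 650), "black")
PIL.ImageDraw.Draw(mask).ellipse((50, 50, 600, 600), "white")

sample = (img, label, mask)
for transform in [ResizeCrop(), Resize((1760, 1760))]:
    sample = transform(*sample)

print([s.size for s in sample])  # expected: three (1760, 1760) images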
@@ -15,6 +15,7 @@ import random
import numpy
import PIL.Image
import PIL.ImageOps
import torchvision.transforms
import torchvision.transforms.functional
@@ -266,3 +267,75 @@ class ColorJitter(torchvision.transforms.ColorJitter):
    def __repr__(self):
        retval = super(ColorJitter, self).__repr__()
        return retval.replace("(", f"(p={self.p},", 1)
def _expand2square(pil_img, background_color):
    """Pads the smaller of the image's two dimensions so the result is square.

    Parameters
    ----------

    pil_img : PIL.Image.Image
        The image to be padded.

    background_color : :py:class:`tuple` or int
        The color used for the padded background.  For an RGB image this
        should be a 3-tuple; for a grayscale image a single integer suffices.

    Returns
    -------

    image : PIL.Image.Image
        A new image whose height equals its width.
    """

    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        # pad the height, centering the original content vertically
        result = PIL.Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        # pad the width, centering the original content horizontally
        result = PIL.Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
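
As an illustrative aside (not part of this diff), what _expand2square does on a tiny invented input:

# Hedged example: a 4x2 grayscale image (wider than tall) is padded to 4x4,
# centered vertically on a black background.
import PIL.Image

small = PIL.Image.new("L", (4, 2), 255)
square = _expand2square(small, 0)
assert square.size == (4, 4)
assert square.getpixel((0, 0)) == 0    # padded (background) row
assert square.getpixel((0, 1)) == 255  # original content, centered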
class ResizeCrop:
    """Crops the image, label and mask to the bounding box of the non-zero
    (non-black) mask pixels, removing the all-black borders, and then pads the
    result so that height and width are equal.
    """

    def __call__(self, *args):
        img = args[0]
        label = args[1]
        mask = args[2]

        # sum mask values per column/row; a zero sum means the column/row is
        # entirely black and can be cropped away
        mask_data = numpy.asarray(mask)
        wid = numpy.sum(mask_data, axis=0)
        heig = numpy.sum(mask_data, axis=1)

        # number of all-black columns/rows on each side of the mask
        crop_left, crop_right = (wid != 0).argmax(axis=0), (
            wid[::-1] != 0
        ).argmax(axis=0)
        crop_up, crop_down = (heig != 0).argmax(axis=0), (
            heig[::-1] != 0
        ).argmax(axis=0)

        # apply the same border to all three inputs so they stay aligned
        border = (crop_left, crop_up, crop_right, crop_down)
        new_mask = PIL.ImageOps.crop(mask, border)
        new_img = PIL.ImageOps.crop(img, border)
        new_label = PIL.ImageOps.crop(label, border)

        # pad back to a square, using a background color matching each mode
        new_img = _expand2square(new_img, (0, 0, 0))
        new_label = _expand2square(new_label, 0)
        new_mask = _expand2square(new_mask, 0)

        args = (new_img, new_label, new_mask)
        return args
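
As an illustrative aside (not part of this diff), the mask alone drives the crop and the same border is applied to image and label so the three stay aligned; a short sketch with invented sizes (the new test further down exercises the same behaviour):

# Hedged sketch: the white ellipse in the mask defines the region kept by
# ResizeCrop; image and label are cropped with the identical border.
import PIL.Image
import PIL.ImageDraw

img = PIL.Image.new("RGB", (140, 128), "gray")
label = PIL.Image.new("L", (140, 128), "black")
mask = PIL.Image.new("L", (140, 128), "black")
PIL.ImageDraw.Draw(mask).ellipse((30, 30, 90, 90), "white")

img_, label_, mask_ = ResizeCrop()(img, label, mask)
assert img_.size == label_.size == mask_.size == (61, 61)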
@@ -6,6 +6,8 @@ import random
import numpy
import PIL.Image
import PIL.ImageDraw
import PIL.ImageOps
import pkg_resources
import torch
import torchvision.transforms.functional
@@ -20,6 +22,7 @@ from ..data.transforms import (
    RandomRotation,
    RandomVerticalFlip,
    Resize,
    ResizeCrop,
    SingleAutoLevel16to8,
    ToTensor,
)
@@ -372,3 +375,36 @@ def test_16bit_autolevel():
    assert timg.getextrema() == (0, 255)
    # timg.show()
    # import ipdb; ipdb.set_trace()
def test_ResizeCrop():

    # parameters
    im_size = (3, 128, 140)  # (planes, height, width)
    mask_gt_size = (1, 128, 140)  # (planes, height, width)
    crop_size = (30, 30, 91, 91)  # (left, up, right, down)
    size_after_crop = (61, 61)

    idx = (slice(crop_size[0], crop_size[2]), slice(crop_size[1], crop_size[3]))

    # create a random image and a mask with a circle inside
    img, gt = _create_img(im_size), _create_img(mask_gt_size)
    mask = PIL.Image.new("L", (140, 128), "black")
    dr = PIL.ImageDraw.Draw(mask)
    dr.ellipse((30, 30, 90, 90), "white")

    # test
    transform = ResizeCrop()
    img_, gt_, mask_ = transform(img, gt, mask)

    assert img_.size == size_after_crop
    assert gt_.size == size_after_crop
    assert mask_.size == size_after_crop

    assert img_.mode == "RGB"
    assert gt_.mode == "L"
    assert mask_.mode == "L"

    assert numpy.all(numpy.array(img_) == numpy.array(img)[idx])
    assert numpy.all(numpy.array(gt_) == numpy.array(gt)[idx])
    assert numpy.all(numpy.array(mask_) == numpy.array(mask)[idx])