Commit e698cd3b authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[utils] add a transform to convert to grayscale

parent c076b423
......@@ -8,6 +8,7 @@ from .utils import RollChannels
from .utils import ToTensor
from .utils import Normalize
from .utils import Resize
from .utils import ToGray
from .utils import map_labels
from .utils import ConcatDataset
......
......@@ -67,6 +67,19 @@ class Resize(object):
sample['image'] = sample['image'][..., numpy.newaxis]
return sample
class ToGray(object):
def __init__(self):
self.op = transforms.Grayscale()
def __call__(self, sample):
# convert to PIL image
from PIL.Image import fromarray
img = fromarray(sample['image'].squeeze())
img = self.op(img)
sample['image'] = numpy.array(img)
sample['image'] = sample['image'][..., numpy.newaxis]
return sample
def map_labels(raw_labels, start_index=0):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment