Commit 9b70f326 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[dataset, transforms] wrote wrapper around transforms to not discard label information

parent d8b6a02a
from .multipie import MultiPIEDataset
# transforms
from .utils import RollChannels
from .utils import ToTensor
from .utils import Normalize
from .utils import map_labels
......
......@@ -129,6 +129,6 @@ class MultiPIEDataset(Dataset):
sample = {'image': image, 'id': identity, 'pose': pose}
if self.transform:
sample = self.transform(sample['image'])
sample = self.transform(sample)
return sample
......@@ -3,20 +3,34 @@
import numpy
import torchvision.transforms as transforms
class RollChannels(object):
"""
Class to transform a bob image into skimage.
i.e. CxHxW to HxWxC
"""
def __call__(self, sample):
temp = numpy.rollaxis(numpy.rollaxis(sample['image'], 2),2)
sample['image'] = temp
return sample
class ToTensor(object):
def __init__(self):
pass
self.op = transforms.ToTensor()
def __call__(self, sample):
temp = numpy.rollaxis(numpy.rollaxis(sample, 2),2)
sample = temp
sample['image'] = self.op(sample['image'])
return sample
class Normalize(object):
def __init__(self, mean, std):
self.op = transforms.Normalize(mean, std)
def __call__(self, sample):
sample['image'] = self.op(sample['image'])
return sample
def map_labels(raw_labels):
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment