Skip to content
Snippets Groups Projects
Commit 1606385f authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[datasets] added frontal only option (to test with DC-GAN) and changed the RollChannels transform

parent 21a3784f
No related branches found
No related tags found
No related merge requests found
...@@ -16,7 +16,7 @@ import bob.io.image ...@@ -16,7 +16,7 @@ import bob.io.image
class MultiPIEDataset(Dataset): class MultiPIEDataset(Dataset):
"""MultiPIE dataset.""" """MultiPIE dataset."""
def __init__(self, root_dir, world=False, transform=None): def __init__(self, root_dir, world=False, frontal_only=False, transform=None):
""" """
""" """
self.root_dir = root_dir self.root_dir = root_dir
...@@ -78,6 +78,10 @@ class MultiPIEDataset(Dataset): ...@@ -78,6 +78,10 @@ class MultiPIEDataset(Dataset):
if (camera == '19_1') or (camera == '08_1'): if (camera == '19_1') or (camera == '08_1'):
continue continue
# skip cameras that are not frontal if we want frontal only
if (camera != '05_1') and frontal_only:
continue
for obj in objs: for obj in objs:
temp = os.path.split(obj.path) temp = os.path.split(obj.path)
identity = int(temp[0].split('/')[2]) identity = int(temp[0].split('/')[2])
...@@ -88,7 +92,6 @@ class MultiPIEDataset(Dataset): ...@@ -88,7 +92,6 @@ class MultiPIEDataset(Dataset):
self.pose_labels.append(camera_to_label[camera]) self.pose_labels.append(camera_to_label[camera])
self.id_labels = self.map_labels(id_labels) self.id_labels = self.map_labels(id_labels)
print max(self.id_labels)
def __len__(self): def __len__(self):
...@@ -102,7 +105,7 @@ class MultiPIEDataset(Dataset): ...@@ -102,7 +105,7 @@ class MultiPIEDataset(Dataset):
sample = {'image': image, 'id': identity, 'pose': pose} sample = {'image': image, 'id': identity, 'pose': pose}
if self.transform: if self.transform:
sample = self.transform(sample) sample = self.transform(sample['image'])
return sample return sample
...@@ -121,3 +124,12 @@ class MultiPIEDataset(Dataset): ...@@ -121,3 +124,12 @@ class MultiPIEDataset(Dataset):
return labels return labels
class RollChannels(object):
def __init__(self):
pass
def __call__(self, sample):
temp = numpy.rollaxis(numpy.rollaxis(sample, 2),2)
sample = temp
return sample
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment