Skip to content
Snippets Groups Projects
Commit 362f15e7 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[datasets] added the code to build the Multi-PIE PyTorch dataset

parent 7d67b1b7
No related branches found
No related tags found
No related merge requests found
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
#!/usr/bin/env python
# encoding: utf-8
import os
import torch
import numpy
from torch.utils.data import Dataset, DataLoader
import bob.db.multipie
import bob.io.base
import bob.io.image
class MultiPIEDataset(Dataset):
"""MultiPIE dataset."""
def __init__(self, root_dir, world=False, transform=None):
"""
"""
self.root_dir = root_dir
self.transform = transform
self.world = world
camera_to_pose = {'11_0': 'l90',
'12_0': 'l75',
'09_0': 'l60',
'08_0': 'l45',
'13_0': 'l30',
'14_0': 'l15',
'05_1': '0',
'05_0': 'r15',
'04_1': 'r30',
'19_0': 'r45',
'20_0': 'r60',
'01_0': 'r75',
'24_0': 'r90'}
camera_to_label = {'11_0': '0',
'12_0': '1',
'09_0': '2',
'08_0': '3',
'13_0': '4',
'14_0': '5',
'05_1': '6',
'05_0': '7',
'04_1': '8',
'19_0': '9',
'20_0': '10',
'01_0': '11',
'24_0': '12'}
# get all the needed file, the pose labels, and the id labels
self.data_files = []
self.pose_labels = []
id_labels = []
db = bob.db.multipie.Database()
if world:
c_set = db.clients(groups='world')
else:
c_set = db.clients()
ids = [client.id for client in c_set]
ids = sorted(ids, key=int)
# filename and pose label are dependent on the camera
for camera in sorted(db.camera_names()):
if world:
objs = db.objects(cameras=camera, groups='world')
else:
objs = db.objects(cameras=camera)
# skip "high" cameras
if (camera == '19_1') or (camera == '08_1'):
continue
for obj in objs:
temp = os.path.split(obj.path)
identity = int(temp[0].split('/')[2])
id_labels.append(identity)
cropped_filename = os.path.join(root_dir, camera_to_pose[camera], temp[1])
cropped_filename += '.png'
self.data_files.append(cropped_filename)
self.pose_labels.append(camera_to_label[camera])
self.id_labels = self.map_labels(id_labels)
print max(self.id_labels)
def __len__(self):
return len(self.data_files)
def __getitem__(self, idx):
image = bob.io.base.load(self.data_files[idx])
identity = self.id_labels[idx]
pose = self.pose_labels[idx]
sample = {'image': image, 'id': identity, 'pose': pose}
if self.transform:
sample = self.transform(sample)
return sample
def map_labels(self, raw_labels):
"""
Map the clients to 0 to 1
"""
possible_labels = list(set(raw_labels))
labels = numpy.array(raw_labels)
for i in range(len(possible_labels)):
l = possible_labels[i]
labels[numpy.where(labels==l)[0]] = i
return labels
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment