Commit dad4c652 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[dataset] reorganized the code

parent f02fa509
from .multipie import MultiPIEDataset
from .utils import RollChannels
from .utils import map_labels
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
......@@ -12,13 +12,31 @@ import bob.db.multipie
import bob.io.base
import bob.io.image
from .utils import map_labels
class MultiPIEDataset(Dataset):
"""MultiPIE dataset."""
"""MultiPIE dataset.
Class represeting the Multi-PIE dataset
**Parameters**
root-dir: path
The path to the data
world: boolean
If you want to only use data corresponding to the world model
frontal_only: boolean
If you want to only use frontal faces
transform: torchvision.transforms
The transform(s) to apply to the face images
"""
# TODO: Start from original data and annotations - Guillaume HEUSCH, 06-11-2017
def __init__(self, root_dir, world=False, frontal_only=False, transform=None):
"""
"""
self.root_dir = root_dir
self.transform = transform
self.world = world
......@@ -91,45 +109,26 @@ class MultiPIEDataset(Dataset):
self.data_files.append(cropped_filename)
self.pose_labels.append(camera_to_label[camera])
self.id_labels = self.map_labels(id_labels)
self.id_labels = map_labels(id_labels)
def __len__(self):
    """
    Return the size of the dataset.

    **Returns**

    length: int
        The number of face images indexed by this dataset.
    """
    dataset_size = len(self.data_files)
    return dataset_size
def __getitem__(self, idx):
    """
    Return one sample from the dataset.

    **Parameters**

    idx: int
        The index of the example to retrieve.

    **Returns**

    sample: dict
        A dictionary with keys 'image', 'id' and 'pose'.
    """
    image = bob.io.base.load(self.data_files[idx])
    identity = self.id_labels[idx]
    pose = self.pose_labels[idx]
    sample = {'image': image, 'id': identity, 'pose': pose}
    if self.transform:
        # Transform only the image and keep the labels: the previous code
        # replaced the whole sample dict with the transformed image,
        # silently dropping the 'id' and 'pose' entries.
        sample['image'] = self.transform(sample['image'])
    return sample
def map_labels(self, raw_labels):
    """
    Map arbitrary client (ID) labels to the contiguous range [0, number of IDs).

    **Parameters**

    raw_labels: list
        The original ID labels, one per example.

    **Returns**

    labels: numpy.ndarray
        Array of the same length as ``raw_labels`` with each ID replaced
        by its index in the sorted set of unique IDs.
    """
    # sorted() makes the mapping deterministic: plain set iteration order
    # is an implementation detail and could vary between runs/interpreters.
    unique_labels = sorted(set(raw_labels))
    label_to_index = {label: index for index, label in enumerate(unique_labels)}
    # Build the result in one pass. The previous in-place relabeling
    # (labels[labels == l] = i) could collide when an already-assigned
    # index equals a not-yet-processed original label value.
    return numpy.array([label_to_index[label] for label in raw_labels], dtype=int)
class RollChannels(object):
image = bob.io.base.load(self.data_files[idx])
identity = self.id_labels[idx]
pose = self.pose_labels[idx]
sample = {'image': image, 'id': identity, 'pose': pose}
def __init__(self):
pass
if self.transform:
sample = self.transform(sample['image'])
def __call__(self, sample):
temp = numpy.rollaxis(numpy.rollaxis(sample, 2),2)
sample = temp
return sample
#!/usr/bin/env python
# encoding: utf-8
import numpy
class RollChannels(object):
    """
    Transform converting a bob image (channels first) into an
    skimage-style image (channels last), i.e. CxHxW -> HxWxC.
    """
    def __init__(self):
        pass

    def __call__(self, sample):
        # Move the channel axis (0) to the last position:
        # (C, H, W) -> (H, W, C).
        return numpy.transpose(sample, (1, 2, 0))
def map_labels(raw_labels):
    """
    Map the ID label to [0 - # of IDs].

    **Parameters**

    raw_labels: list
        The original ID labels, one per example.

    **Returns**

    labels: numpy.ndarray
        Array of the same length as ``raw_labels`` with each ID replaced
        by its index in the sorted set of unique IDs.
    """
    # sorted() makes the mapping deterministic: plain set iteration order
    # is an implementation detail and could vary between runs/interpreters.
    unique_labels = sorted(set(raw_labels))
    label_to_index = {label: index for index, label in enumerate(unique_labels)}
    # Build the result in one pass. The previous in-place relabeling
    # (labels[labels == l] = i) could collide when an already-assigned
    # index equals a not-yet-processed original label value.
    return numpy.array([label_to_index[label] for label in raw_labels], dtype=int)
......@@ -54,8 +54,8 @@ import torchvision.utils as vutils
from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets.multipie import MultiPIEDataset
from bob.learn.pytorch.datasets.multipie import RollChannels
from bob.learn.pytorch.datasets import MultiPIEDataset
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.architectures.DCGAN import _netG
from bob.learn.pytorch.architectures.DCGAN import _netD
......
......@@ -74,7 +74,7 @@ setup(
# scripts should be declared using this entry:
'console_scripts': [
'train.py = bob.learn.pytorch.scripts.train:main',
'train_dcgan_multipie.py = bob.learn.pytorch.scripts.train_dcgan_multipie:main',
],
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.