#!/usr/bin/env python
# encoding: utf-8

import numpy

import torchvision.transforms as transforms
from torch.utils.data import Dataset

import bob.io.base
import bob.io.image

class FaceCropper(object):
  """
    Class to crop a face, based on the positions of the eyes
  """
  def __init__(self, cropped_height, cropped_width):
    # the face cropper, placing the eyes at 1/5th of the image height,
    # and at 1/4th and 3/4th of the image width
    from bob.bio.face.preprocessor import FaceCrop
    cropped_image_size = (cropped_height, cropped_width)
    right_eye_pos = (cropped_height // 5, cropped_width // 4 - 1)
    left_eye_pos = (cropped_height // 5, cropped_width // 4 * 3)
    cropped_positions = {'leye': left_eye_pos, 'reye': right_eye_pos}
    self.face_cropper = FaceCrop(cropped_image_size=cropped_image_size,
                                 cropped_positions=cropped_positions,
                                 color_channel='rgb',
                                 dtype='uint8')

  def __call__(self, sample):
    # the sample's 'eyes' entry holds the annotated eye positions
    cropped = self.face_cropper(sample['image'], sample['eyes'])
    sample['image'] = cropped
    return sample
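
# A minimal usage sketch for FaceCropper. The eye coordinates below are
# hypothetical annotations of the input image (bob uses (y, x) ordering,
# with 'reye' the subject's right eye, i.e. on the left of the image):
#
#   cropper = FaceCropper(64, 64)
#   sample = {'image': image,  # 3 x H x W uint8 array
#             'eyes': {'reye': (30, 20), 'leye': (30, 45)}}
#   sample = cropper(sample)   # sample['image'] is now 3 x 64 x 64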


class RollChannels(object):
  """
    Class to transform a bob image into skimage.
    i.e. CxHxW to HxWxC
  """
  def __call__(self, sample):
    # move the channel axis to the back: CxHxW -> HxWxC
    sample['image'] = numpy.transpose(sample['image'], (1, 2, 0))
    return sample
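
# Shape sanity check (sketch):
#
#   x = numpy.zeros((3, 128, 128), dtype='uint8')
#   RollChannels()({'image': x})['image'].shape   # -> (128, 128, 3)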

class ToTensor(object):
  """
    Class to transform the image in a sample into a torch tensor
  """
  def __init__(self):
    self.op = transforms.ToTensor()

  def __call__(self, sample):
    sample['image'] = self.op(sample['image'])
    return sample
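
# Note: torchvision's ToTensor converts an HxWxC uint8 array in [0, 255]
# into a CxHxW float tensor in [0.0, 1.0], so RollChannels should be
# applied before this transform.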

class Normalize(object):
  """
    Class to normalize the image in a sample with the given
    per-channel mean and standard deviation
  """
  def __init__(self, mean, std):
    self.op = transforms.Normalize(mean, std)

  def __call__(self, sample):
    sample['image'] = self.op(sample['image'])
    return sample
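
# A typical composed pipeline (sketch; the mean/std values are
# placeholders, not taken from this file):
#
#   from torchvision.transforms import Compose
#   tfm = Compose([RollChannels(), ToTensor(),
#                  Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#   sample = tfm({'image': image, 'id': 0, 'pose': 0})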

class Resize(object):
  """
    Class to resize the (single-channel) image in a sample,
    going through PIL
  """
  def __init__(self, size):
    self.op = transforms.Resize(size)

  def __call__(self, sample):
    # convert to a PIL image: drop the singleton channel axis first,
    # then restore the HxWx1 shape after resizing
    from PIL.Image import fromarray
    img = fromarray(sample['image'].squeeze())
    img = self.op(img)
    sample['image'] = numpy.array(img)
    sample['image'] = sample['image'][..., numpy.newaxis]
    return sample

class ToGray(object):
  """
    Class to convert the image in a sample to grayscale (HxWx1),
    going through PIL
  """
  def __init__(self):
    self.op = transforms.Grayscale()

  def __call__(self, sample):
    # convert to a PIL image, apply the grayscale conversion, and
    # restore the HxWx1 shape expected by the other transforms
    from PIL.Image import fromarray
    img = fromarray(sample['image'].squeeze())
    img = self.op(img)
    sample['image'] = numpy.array(img)
    sample['image'] = sample['image'][..., numpy.newaxis]
    return sample
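
# Sketch of a grayscale chain (assumes `image` is a bob-style 3xHxW uint8
# array with H = W; the output shape is then 1 x 32 x 32):
#
#   from torchvision.transforms import Compose
#   tfm = Compose([RollChannels(), ToGray(), Resize(32), ToTensor()])
#   gray = tfm({'image': image})['image']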


def map_labels(raw_labels, start_index=0):
  """
    Map the raw ID labels to consecutive integers in
    [start_index, start_index + number of distinct IDs)
  """
  # sort the distinct labels so that the mapping is deterministic
  possible_labels = sorted(set(raw_labels))
  raw = numpy.array(raw_labels)
  labels = numpy.zeros_like(raw, dtype=int)

  # write into a separate array, so that already-mapped values cannot
  # collide with raw labels that have not been processed yet
  for i, l in enumerate(possible_labels):
    labels[raw == l] = i + start_index

  return labels
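
# e.g. (sketch):
#
#   map_labels(['s3', 's1', 's3', 's2'])   # -> array([2, 0, 2, 1])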


class ConcatDataset(Dataset):
  """
  Class to concatenate two or more datasets for DR-GAN training

  **Parameters**

  datasets: list
    The list of datasets (as torch.utils.data.Dataset)
  """
  def __init__(self, datasets):
    # the transform of the first dataset is applied to all of them
    self.transform = datasets[0].transform
    self.data_files = sum((d.data_files for d in datasets), [])
    self.pose_labels = sum((d.pose_labels for d in datasets), [])
    self.id_labels = sum((d.id_labels for d in datasets), [])

  def __len__(self):
    """
      Return the length of the dataset (i.e. the number of examples)
    """
    return len(self.data_files)


  def __getitem__(self, idx):
    """
      Return one sample from the dataset, as a dictionary containing
      the image and its identity and pose labels
    """
    image = bob.io.base.load(self.data_files[idx])
    identity = self.id_labels[idx]
    pose = self.pose_labels[idx]
    sample = {'image': image, 'id': identity, 'pose': pose}

    if self.transform:
      sample = self.transform(sample)

    return sample
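
# Usage sketch (`FirstDataset` / `SecondDataset` are hypothetical dataset
# classes; any torch Dataset exposing `transform`, `data_files`,
# `pose_labels` and `id_labels` would work):
#
#   from torch.utils.data import DataLoader
#   train_set = ConcatDataset([FirstDataset(...), SecondDataset(...)])
#   loader = DataLoader(train_set, batch_size=64, shuffle=True)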