stacks.py 4.68 KB
Newer Older
1
2
from ..utils.processors import SequentialProcessor, ParallelProcessor
from .Extractor import Extractor
3
from bob.io.base import HDF5File
4
5
6
7
8
9
10


class MultipleExtractor(Extractor):
  """Base class for SequentialExtractor and ParallelExtractor. This class is
  not meant to be used directly."""

  def get_attributes(self, processors):
    """Aggregate the extractor attributes of a list of extractors.

    Parameters
    ----------
    processors : list
        A list of extractors.

    Returns
    -------
    tuple
        ``(requires_training, split_training_data_by_client,
        min_extractor_file_size, min_feature_file_size)``. The two boolean
        flags are True if *any* processor requires them; the two sizes are
        the minimum over all processors.
    """
    requires_training = any(p.requires_training for p in processors)
    split_training_data_by_client = any(p.split_training_data_by_client for p
                                        in processors)
    min_extractor_file_size = min(p.min_extractor_file_size for p in
                                  processors)
    min_feature_file_size = min(p.min_feature_file_size for p in processors)
    return (requires_training, split_training_data_by_client,
            min_extractor_file_size, min_feature_file_size)

  def get_extractor_groups(self):
    """Return the HDF5 group name used for each processor: E_1, E_2, ..."""
    groups = ['E_{}'.format(i + 1) for i in range(len(self.processors))]
    return groups

  def train_one(self, e, training_data, extractor_file, apply=False):
    """Train one extractor and optionally apply it to the training data.

    Parameters
    ----------
    e : object
        The extractor to train.
    training_data : list
        The training data; either a flat list of samples or, when
        ``self.split_training_data_by_client`` is True, a list of
        per-client lists of samples.
    extractor_file : object
        The (open) HDF5 file/group where ``e`` saves its model.
    apply : bool
        If True, apply ``e`` to ``training_data`` and return the
        transformed data (for sequential stacks); otherwise return None.

    Returns
    -------
    list or None
        The transformed training data when ``apply`` is True, else None.
    """
    if not e.requires_training:
      # BUGFIX: previously this returned None unconditionally, which made
      # SequentialExtractor.train lose the training data (and skip this
      # extractor's transform) as soon as one extractor needed no training.
      if not apply:
        return
      if self.split_training_data_by_client:
        return [[e(d) for d in datalist] for datalist in training_data]
      return [e(d) for d in training_data]

    # if any of the extractors require splitting the data, the
    # split_training_data_by_client is True.
    if e.split_training_data_by_client:
      e.train(training_data, extractor_file)
      if not apply:
        return
      training_data = [[e(d) for d in datalist] for datalist in training_data]

    # when no extractor needs splitting
    elif not self.split_training_data_by_client:
      e.train(training_data, extractor_file)
      if not apply:
        return
      training_data = [e(d) for d in training_data]

    # when e here wants it flat but the data is split
    else:
      # make training_data flat, remembering each client's chunk length so
      # the per-client structure can be restored afterwards
      training_data_len = [len(datalist) for datalist in training_data]
      training_data = [d for datalist in training_data for d in datalist]
      e.train(training_data, extractor_file)
      if not apply:
        return
      # split training data back into per-client lists
      new_training_data, i = [], 0
      for length in training_data_len:
        class_data = []
        for _ in range(length):
          class_data.append(e(training_data[i]))
          i += 1
        new_training_data.append(class_data)
      training_data = new_training_data
    return training_data

  def load(self, extractor_file):
    """Load all processors' models from their groups in ``extractor_file``."""
    with HDF5File(extractor_file) as f:
      groups = self.get_extractor_groups()
      for e, group in zip(self.processors, groups):
        f.cd(group)
        e.load(f)
        f.cd('..')


class SequentialExtractor(SequentialProcessor, MultipleExtractor):
  __doc__ = SequentialProcessor.__doc__

  def __init__(self, processors):
    """Build a sequential stack of extractors.

    Parameters
    ----------
    processors : list
        The extractors, applied one after another in order.
    """
    (requires_training,
     split_training_data_by_client,
     min_extractor_file_size,
     min_feature_file_size) = self.get_attributes(processors)

    super(SequentialExtractor, self).__init__(
        processors=processors,
        requires_training=requires_training,
        split_training_data_by_client=split_training_data_by_client,
        min_extractor_file_size=min_extractor_file_size,
        min_feature_file_size=min_feature_file_size)

  def train(self, training_data, extractor_file):
    """Train each extractor in turn, saving every model under its own HDF5
    group and feeding the data transformed by one extractor to the next."""
    with HDF5File(extractor_file, 'w') as f:
      for extractor, group in zip(self.processors,
                                  self.get_extractor_groups()):
        f.create_group(group)
        f.cd(group)
        # apply=True so the next extractor trains on transformed data
        training_data = self.train_one(extractor, training_data, f,
                                       apply=True)
        f.cd('..')


class ParallelExtractor(ParallelProcessor, MultipleExtractor):
  __doc__ = ParallelProcessor.__doc__

  def __init__(self, processors):
    """Build a parallel stack of extractors.

    Parameters
    ----------
    processors : list
        The extractors, each applied independently to the same input.
    """
    attributes = self.get_attributes(processors)
    (requires_training,
     split_training_data_by_client,
     min_extractor_file_size,
     min_feature_file_size) = attributes

    super(ParallelExtractor, self).__init__(
        processors=processors,
        requires_training=requires_training,
        split_training_data_by_client=split_training_data_by_client,
        min_extractor_file_size=min_extractor_file_size,
        min_feature_file_size=min_feature_file_size)

  def train(self, training_data, extractor_file):
    """Train every extractor on the same (original) training data, saving
    each model under its own HDF5 group."""
    with HDF5File(extractor_file, 'w') as f:
      for extractor, group in zip(self.processors,
                                  self.get_extractor_groups()):
        f.create_group(group)
        f.cd(group)
        # apply=False: parallel extractors all see the original data
        self.train_one(extractor, training_data, f, apply=False)
        f.cd('..')


class CallableExtractor(Extractor):
  """A simple extractor that takes a callable and applies that callable to the
  input.

  Attributes
  ----------
  callable : object
      Anything that is callable. It will be used as an extractor in
      bob.bio.base.
  """

  def __init__(self, callable, **kwargs):
    super(CallableExtractor, self).__init__(**kwargs)
    # keep the wrapped callable; __call__ delegates to it unchanged
    self.callable = callable

  def __call__(self, data):
    extract = self.callable
    return extract(data)