#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# Fri Feb 13 13:18:10 2015 +0200
#
# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
import numpy
from ._library import *
import logging

logger = logging.getLogger(__name__)


def _set_average(trainer, trainers, data):
  """_set_average(trainer, data) -> None

  This function computes the average of the given data and sets it to the given machine.

  This function works for different types of trainers, and can be used to parallelize the training.
  For some trainers, the data is returned instead of set in the trainer.

  **Parameters:**

  trainer : one of :py:class:`KMeansTrainer`, :py:class:`MAP_GMMTrainer`, :py:class:`ML_GMMTrainer`, :py:class:`ISVTrainer`, :py:class:`IVectorTrainer`, :py:class:`PLDATrainer`, :py:class:`EMPCATrainer`
    The trainer to set the data to.

  trainers : [ trainer ]
    The list of trainer objects that were used in the parallel training process.
    All trainers must be of the same class as the ``trainer``.

  data : [ object ]
    The list of data objects that should be set to the trainer.
    Usually this list is generated by parallelizing the e-step of the ``trainer``.
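
  **Example:**

  A minimal sketch of the merging contract, assuming ``machine``, ``data`` and
  the per-process trainer copies are set up as in :py:func:`train` below:

  .. code-block:: python

     # run the e-step on disjoint chunks with independent trainer copies
     chunks = numpy.array_split(data, len(trainers))
     for t, chunk in zip(trainers, chunks):
       t.e_step(machine, chunk)
     # merge the accumulated statistics back into the main trainer
     _set_average(trainer, trainers, chunks)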
  """

  if isinstance(trainer, KMeansTrainer):
    # K-Means statistics
    trainer.zeroeth_order_statistics = numpy.zeros(trainer.zeroeth_order_statistics.shape)
    trainer.first_order_statistics = numpy.zeros(trainer.first_order_statistics.shape)
    trainer.average_min_distance = 0.

    for t in trainers:
      trainer.zeroeth_order_statistics += t.zeroeth_order_statistics
      trainer.first_order_statistics += t.first_order_statistics
      trainer.average_min_distance += t.average_min_distance
    trainer.average_min_distance /= sum(d.shape[0] for d in data)

  elif isinstance(trainer, (ML_GMMTrainer, MAP_GMMTrainer)):
    # GMM statistics
    trainer.gmm_statistics = trainers[0].gmm_statistics
    for t in trainers[1:]:
      trainer.gmm_statistics += t.gmm_statistics

  else:
    raise NotImplementedError("_set_average is not implemented for trainer type '%s'" % type(trainer).__name__)


def _parallel_e_step(args):
  """This function applies the e_step of the given trainer (first argument) on the given data (second argument).
  It is called by each parallel process.
  """
  trainer, machine, data = args
  trainer.e_step(machine, data)


def train(trainer, machine, data, max_iterations=50, convergence_threshold=None, initialize=True, rng=None, check_inputs=True, pool=None):
  """
  Trains a machine given a trainer and the proper data

  **Parameters:**
    trainer : one of :py:class:`KMeansTrainer`, :py:class:`MAP_GMMTrainer`, :py:class:`ML_GMMTrainer`, :py:class:`ISVTrainer`, :py:class:`IVectorTrainer`, :py:class:`PLDATrainer`, :py:class:`EMPCATrainer`
      A trainer mechanism
    machine : one of :py:class:`KMeansMachine`, :py:class:`GMMMachine`, :py:class:`ISVBase`, :py:class:`IVectorMachine`, :py:class:`PLDAMachine`, :py:class:`bob.learn.linear.Machine`
      A container machine
    data : array_like <float, 2D>
      The data to be trained
    max_iterations : int
      The maximum number of iterations to train a machine
    convergence_threshold : float
      The convergence threshold used to train a machine. If ``None``, the training procedure stops after ``max_iterations`` iterations
    initialize : bool
      If True, runs the initialization procedure
    rng :  :py:class:`bob.core.random.mt19937`
      The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop
    check_inputs : bool
      If set, performs shallow checks on the input ``data``: raises a :py:exc:`ValueError` if ``numpy.inf`` or ``numpy.nan`` values are detected
    pool : :py:class:`multiprocessing.Pool` or ``None``
      If given, the provided process pool will be used to parallelize the e-step of the EM algorithm
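
  **Example:**

  A minimal usage sketch, assuming the :py:class:`KMeansMachine` and
  :py:class:`KMeansTrainer` bindings exported by this package:

  .. code-block:: python

     import multiprocessing
     import numpy
     import bob.learn.em

     data = numpy.random.randn(200, 3)           # 200 samples of dimension 3
     machine = bob.learn.em.KMeansMachine(2, 3)  # 2 means of dimension 3
     trainer = bob.learn.em.KMeansTrainer()

     # single-process training
     bob.learn.em.train(trainer, machine, data, max_iterations=20,
                        convergence_threshold=1e-5)

     # the same training, with the e-step split over 4 processes
     with multiprocessing.Pool(4) as pool:
       bob.learn.em.train(trainer, machine, data, max_iterations=20, pool=pool)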
  """
  if check_inputs and isinstance(data, numpy.ndarray):
    sum_data = numpy.sum(data)

    if numpy.isinf(sum_data):
      raise ValueError("Please check your inputs: numpy.inf detected in `data`")

    if numpy.isnan(sum_data):
      raise ValueError("Please check your inputs: numpy.nan detected in `data`")

  def _e_step(trainer, machine, data):
    # performs the e-step, possibly in parallel
    if pool is None:
      # use only one core
      trainer.e_step(machine, data)
    else:
      # use the given process pool
      n_processes = pool._processes

      # split the data into equally-sized chunks, one per process; the last
      # chunk takes whatever remains
      split_data = []
      offset = 0
      step = data.shape[0] // n_processes
      for p in range(n_processes):
        if p == n_processes - 1:
          # take all the remaining data in the last chunk
          split_data.append(data[offset:])
        else:
          split_data.append(data[offset: offset + step])
        offset += step

      # create a copy of the trainer and of the machine for each process, so
      # that the parallel e-steps do not interfere with each other
      trainers = [trainer.__class__(trainer) for p in range(n_processes)]
      machines = [machine.__class__(machine) for p in range(n_processes)]

      # run the e-step of each chunk in a separate process
      pool.map(_parallel_e_step, zip(trainers, machines, split_data))

      # merge the statistics of all trainers back into the main trainer
      _set_average(trainer, trainers, split_data)

  # Initialization
  if initialize:
    if rng is not None:
      trainer.initialize(machine, data, rng)
    else:
      trainer.initialize(machine, data)

  _e_step(trainer, machine, data)
  average_output          = 0
  average_output_previous = 0

  if hasattr(trainer,"compute_likelihood"):
    average_output          = trainer.compute_likelihood(machine)

  for i in range(max_iterations):
    logger.info("Iteration = %d/%d", i, max_iterations)
    average_output_previous = average_output
    trainer.m_step(machine, data)
    _e_step(trainer, machine, data)

    if hasattr(trainer,"compute_likelihood"):
      average_output = trainer.compute_likelihood(machine)

      if isinstance(machine, KMeansMachine):
        logger.info("average euclidean distance = %f", average_output)
      else:
        logger.info("log likelihood = %f", average_output)

      convergence_value = abs((average_output_previous - average_output) / average_output_previous)
      logger.info("convergence value = %f", convergence_value)

      # terminate if converged (and likelihood computation is available)
      if convergence_threshold is not None and convergence_value <= convergence_threshold:
        break
  if hasattr(trainer,"finalize"):
    trainer.finalize(machine, data)


def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=None):
  """
  Trains a :py:class:`bob.learn.em.JFABase` given a :py:class:`bob.learn.em.JFATrainer` and the proper data

  **Parameters:**
    trainer : :py:class:`bob.learn.em.JFATrainer`
      A JFA trainer mechanism
    jfa_base : :py:class:`bob.learn.em.JFABase`
      A container machine
    data : [[:py:class:`bob.learn.em.GMMStats`]]
      The data to be trained
    max_iterations : int
      The maximum number of iterations to train a machine
    initialize : bool
      If True, runs the initialization procedure
    rng :  :py:class:`bob.core.random.mt19937`
      The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loops
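
  **Example:**

  A hedged sketch of the expected call pattern; the nested ``gmm_stats`` list
  (one inner list of :py:class:`bob.learn.em.GMMStats` per identity) and the
  trained ``ubm`` (a :py:class:`bob.learn.em.GMMMachine`) are assumed to be
  prepared elsewhere:

  .. code-block:: python

     import bob.learn.em

     jfa_base = bob.learn.em.JFABase(ubm, 2, 2)   # rU = rV = 2
     trainer = bob.learn.em.JFATrainer()
     bob.learn.em.train_jfa(trainer, jfa_base, gmm_stats, max_iterations=10)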
  """

  if initialize:
    if rng is not None:
      trainer.initialize(jfa_base, data, rng)
    else:
      trainer.initialize(jfa_base, data)

  # V subspace
  logger.info("V subspace estimation...")
  for i in range(max_iterations):
    logger.info("Iteration = %d/%d", i, max_iterations)
    trainer.e_step_v(jfa_base, data)
    trainer.m_step_v(jfa_base, data)
  trainer.finalize_v(jfa_base, data)

  # U subspace
  logger.info("U subspace estimation...")
  for i in range(max_iterations):
    logger.info("Iteration = %d/%d", i, max_iterations)
    trainer.e_step_u(jfa_base, data)
    trainer.m_step_u(jfa_base, data)
  trainer.finalize_u(jfa_base, data)

  # D subspace
  logger.info("D subspace estimation...")
  for i in range(max_iterations):
    logger.info("Iteration = %d/%d", i, max_iterations)
    trainer.e_step_d(jfa_base, data)
    trainer.m_step_d(jfa_base, data)
  trainer.finalize_d(jfa_base, data)