# LightCNN9.py - feature extraction with LightCNN9 embeddings
import numpy

import torch
from torch.autograd import Variable

import torchvision.transforms as transforms

from bob.learn.pytorch.architectures import LightCNN9
from bob.bio.base.extractor import Extractor

class LightCNN9Extractor(Extractor):
  """ The class implementing the feature extraction of LightCNN9 embeddings.

  Attributes
  ----------
  network: :py:class:`torch.nn.Module`
      The network architecture
  to_tensor: :py:mod:`torchvision.transforms`
      The transform from numpy.array to torch.Tensor
  norm: :py:mod:`torchvision.transforms`
      The transform to normalize the input

  """

  def __init__(self, model_file=None, num_classes=79077):
    """ Init method

    Parameters
    ----------
    model_file: str
        The path of the trained network to load. If None, the network is
        left with its random initialization (used mainly for unit testing).
    num_classes: int
        The number of classes (size of the classification head the
        checkpoint was trained with).

    """
    Extractor.__init__(self, skip_extractor_training=True)

    # model
    self.network = LightCNN9(num_classes)

    if model_file is not None:
      # map_location='cpu' so a GPU-trained checkpoint loads on any machine
      cp = torch.load(model_file, map_location='cpu')
      state_dict = cp['state_dict']

      # if the pre-trained model was saved using nn.DataParallel, every key
      # of its state_dict is prefixed with 'module.' -> strip that prefix
      # so the keys match our (non-parallel) network
      if any(k.startswith('module.') for k in state_dict):
        from collections import OrderedDict
        state_dict = OrderedDict(
            (k[len('module.'):], v) for k, v in state_dict.items())

      self.network.load_state_dict(state_dict)

    # inference only: disable dropout / use running batch-norm statistics
    self.network.eval()

    # image pre-processing
    self.to_tensor = transforms.ToTensor()
    self.norm = transforms.Normalize((0.5,), (0.5,))

  def __call__(self, image):
    """ Extract features from an image

    Parameters
    ----------
    image : 2D :py:class:`numpy.ndarray` (floats)
      The grayscale image to extract the features from. Its size must be 128x128

    Returns
    -------
    feature : :py:class:`numpy.ndarray` (floats)
      The extracted features as a 1d array of size 320

    """
    # torchvision.transforms expect a numpy array of size HxWxC
    input_image = numpy.expand_dims(image, axis=2)
    input_image = self.to_tensor(input_image)
    input_image = self.norm(input_image)
    # add the batch dimension expected by the network
    input_image = input_image.unsqueeze(0)

    # to be compliant with the loaded model, where weights and biases are
    # torch.FloatTensor
    input_image = input_image.float()

    # no gradients needed at inference time
    # (torch.autograd.Variable is deprecated; plain tensors suffice)
    with torch.no_grad():
      _, features = self.network(input_image)
    return features.numpy().flatten()