From 8a158686576d437c5ae3abca63082c4c1333aeae Mon Sep 17 00:00:00 2001
From: ageorge <anjith.george@idiap.ch>
Date: Mon, 21 Jan 2019 18:13:36 +0100
Subject: [PATCH] WIP: add MCCNN extractor

---
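Note for reviewers (not part of the commit message): this patch adds an
MCCNNExtractor wrapping the MCCNN architecture from bob.learn.pytorch; it maps
a 4x128x128 multi-channel (C, D, I, T) image to a single PAD score. A minimal
usage sketch is given below; the model path is a placeholder, and without a
model file the extractor runs with randomly initialized weights (as in the
unit test):

    import numpy
    from bob.ip.pytorch_extractor import MCCNNExtractor

    # path to a trained PAD model (placeholder, adapt to your setup)
    extractor = MCCNNExtractor(model_file="/path/to/mccnn_pad_model.pth")

    # multi-channel image, channels ordered as C, D, I, T
    image = numpy.random.rand(4, 128, 128).astype("float32")

    # scalar PAD score: close to 0 for bona fide, close to 1 for attacks
    score = extractor(image)
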
 bob/ip/pytorch_extractor/MCCNN.py    |  99 +++++++++++++++++++++++++++
 bob/ip/pytorch_extractor/__init__.py |   2 +
 bob/ip/pytorch_extractor/test.py     |  14 ++++
 3 files changed, 115 insertions(+)
 create mode 100644 bob/ip/pytorch_extractor/MCCNN.py

diff --git a/bob/ip/pytorch_extractor/MCCNN.py b/bob/ip/pytorch_extractor/MCCNN.py
new file mode 100644
index 0000000..aa9bc5e
--- /dev/null
+++ b/bob/ip/pytorch_extractor/MCCNN.py
@@ -0,0 +1,99 @@
+import numpy as np
+
+import torch
+
+import torchvision.transforms as transforms
+
+from bob.learn.pytorch.architectures import MCCNN
+from bob.bio.base.extractor import Extractor
+
+
+class MCCNNExtractor(Extractor):
+  """ The class implementing the MC-CNN score computation.
+
+  Attributes
+  ----------
+  network: :py:class:`torch.nn.Module`
+      The network architecture
+  transforms: :py:mod:`torchvision.transforms`
+      The transform from :py:class:`numpy.ndarray` to :py:class:`torch.Tensor`
+
+  """
+
+  def __init__(self, considered_modalities=['C','D','I','T'], model_file=None):
+    """ Init method
+
+    Parameters
+    ----------
+    considered_modalities: list
+      The list of modalities used C,D,I,T represents color, depth , infrared and thermal respectively
+    pretrained_lightCNN_modelpath: str
+      Path to the Pretrained LightCNN model
+    NB: There are two model files here; one is the pretrained Light CNN model which is used as the base network; then there is another model file
+    specifically trained for PAD (model_file) which contains the new adapted layers and the fully connected layers.
+    model_file: str
+        The path of the trained PAD network to load
+    
+    """
+
+    Extractor.__init__(self, skip_extractor_training=True)
+
+    # model
+    self.network = MCCNN(considered_modalities=considered_modalities)
+
+    if model_file is None:
+      # do nothing (used mainly for unit testing)
+      print("No pretrained model file provided")
+    else:
+      # load the checkpoint produced by the new training; map the storage
+      # to the CPU so the model can also be loaded on CPU-only machines
+      cp = torch.load(model_file, map_location=lambda storage, loc: storage)
+      if 'state_dict' in cp:
+        self.network.load_state_dict(cp['state_dict'])
+      else:
+        # older checkpoints store the state dict directly
+        self.network.load_state_dict(cp)
+
+      print('Loaded the pretrained PAD model')
+
+    # set the network in evaluation mode
+    self.network.eval()
+
+    # image pre-processing
+    self.transforms = transforms.Compose([transforms.ToTensor()])
+
+  def __call__(self, image):
+    """ Extract features from an image
+
+    Parameters
+    ----------
+    image : 3D :py:class:`numpy.ndarray` (floats)
+      The multi-channel image to extract the score from. Its size must be
+      4x128x128; the channels should be ordered as C, D, I, T (color, depth,
+      infrared and thermal respectively).
+
+    Returns
+    -------
+    output : float
+      The extracted feature is a scalar value, close to 0 for bona fide
+      samples and close to 1 for presentation attacks (PAs).
+
+    """
+
+    # move the channel axis to the end (CxHxW -> HxWxC); ToTensor expects an
+    # HxWxC array and converts it back to a CxHxW tensor
+    input_image = np.rollaxis(np.rollaxis(image, 2), 2)
+    input_image = self.transforms(input_image)
+    input_image = input_image.unsqueeze(0)
+
+    # forward pass, without gradient tracking at inference time
+    with torch.no_grad():
+      output = self.network(input_image)
+    output = output.data.numpy().flatten()
+
+    # output is a scalar score
+
+    return output
diff --git a/bob/ip/pytorch_extractor/__init__.py b/bob/ip/pytorch_extractor/__init__.py
index d3b6b9c..3bbb43c 100755
--- a/bob/ip/pytorch_extractor/__init__.py
+++ b/bob/ip/pytorch_extractor/__init__.py
@@ -1,5 +1,6 @@
 from .CNN8 import CNN8Extractor
 from .CasiaNet import CasiaNetExtractor
+from .MCCNN import MCCNNExtractor
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
@@ -19,6 +20,7 @@ def __appropriate__(*args):
 __appropriate__(
     CNN8Extractor,
     CasiaNetExtractor,
+    MCCNNExtractor,
 )
 
 # gets sphinx autodoc done right - don't remove it
diff --git a/bob/ip/pytorch_extractor/test.py b/bob/ip/pytorch_extractor/test.py
index 6411890..95a6823 100644
--- a/bob/ip/pytorch_extractor/test.py
+++ b/bob/ip/pytorch_extractor/test.py
@@ -32,3 +32,17 @@ def test_casianet():
     data = numpy.random.rand(3, 128, 128).astype("float32")
     output = extractor(data)
     assert output.shape[0] == 320
+
+
+def test_mccnn():
+    """ test for the MCCNN architecture
+
+        this architecture takes 4x128x128 images as input
+        output a single score
+    """
+    from . import MCCNNExtractor
+    extractor = MCCNNExtractor()
+    # this architecture expects 4x128x128 Multi channel images
+    data = numpy.random.rand(4, 128, 128).astype("float32")
+    output = extractor(data)
+    assert output.shape[0] == 1
-- 
GitLab