diff --git a/bob/paper/nir_patch_pooling/config/patch_pooling_lr.py b/bob/paper/nir_patch_pooling/config/patch_pooling_lr.py
index 17f63564bb7146e5955681c4f6e7f0edee53dcd8..192e9d29c69732d2a8c0add2b13504e662010f70 100644
--- a/bob/paper/nir_patch_pooling/config/patch_pooling_lr.py
+++ b/bob/paper/nir_patch_pooling/config/patch_pooling_lr.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
 """
@@ -6,7 +5,7 @@ Configuration file to run PatchPooling + LR classifier for Face PAD
 toward detection of mask attacks in NIR.
 
 """
-#----------------------------------------------------------
+#------------------------------------------------------------------------------
 
 sub_directory = "pooling_lr"
 
@@ -17,6 +16,7 @@ from bob.pad.face.preprocessor import FaceCropAlign
 from bob.bio.video.preprocessor import Wrapper
 from bob.bio.video.utils import FrameSelector
 
+
 # parameters and constants
 FACE_SIZE = 128
 RGB_OUTPUT_FLAG = False
@@ -37,9 +37,10 @@ _image_preprocessor = FaceCropAlign(face_size=FACE_SIZE,
 
 _frame_selector = FrameSelector(selection_style = "all")
 
-preprocessor = Wrapper(preprocessor = _image_preprocessor, frame_selector = _frame_selector)
+preprocessor = Wrapper(preprocessor = _image_preprocessor,
+    frame_selector = _frame_selector)
 
-#----------------------------------------------------------
+#------------------------------------------------------------------------------
 
 # define extractor:
 
@@ -48,16 +49,19 @@ from bob.bio.video.extractor import Wrapper
 from bob.extension import rc
 import os
 
-_model_dir = rc.get("LIGHTCNN9_MODEL_DIRECTORY")
+_model_directory = rc["lightcnn9.model.directory"]
 _model_name = "LightCNN_9Layers_checkpoint.pth.tar"
-_model_file = os.path.join(_model_dir, _model_name)
+_model_file = os.path.join(_model_directory, _model_name)
+
 if not os.path.exists(_model_file):
-    print("Error: Could not find the LightCNN-9 model at [{}].\nPlease follow the download instructions from README".format(_model_dir))
+    print("Error: Could not find the LightCNN-9 model [{}].\nPlease follow \
+    the download instructions from README".format(_model_directory))
     exit(0)
 
-extractor = Wrapper(PatchPoolingCNN(model_file=_model_file), frame_selector = _frame_selector)
+extractor = Wrapper(PatchPoolingCNN(model_file=_model_file),
+    frame_selector = _frame_selector)
 
-#----------------------------------------------------------
+#------------------------------------------------------------------------------
 
 # define algorithm
 
@@ -67,12 +71,6 @@ C = 1.0
 
 algorithm = LogRegr(C=C, frame_level_scores_flag=True)
 
-#----------------------------------------------------------
-
-
-
-
-
-
+#------------------------------------------------------------------------------
 
 
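The config above now resolves the LightCNN-9 checkpoint through bob's global resource configuration (rc) instead of the earlier environment-style key. A minimal sketch of how the key could be set and verified, assuming the `bob config` CLI from bob.extension; the model path is a placeholder:

    # One-time setup in a shell (placeholder path):
    #   bob config set lightcnn9.model.directory /path/to/models
    import os
    from bob.extension import rc

    model_directory = rc["lightcnn9.model.directory"]
    if model_directory is None or not os.path.isdir(model_directory):
        raise RuntimeError("lightcnn9.model.directory is not set; see the "
                           "download instructions in README")
    model_file = os.path.join(model_directory,
                              "LightCNN_9Layers_checkpoint.pth.tar")
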
diff --git a/bob/paper/nir_patch_pooling/extractor/patch_pooling_cnn.py b/bob/paper/nir_patch_pooling/extractor/patch_pooling_cnn.py
index 5941c877c73d2fba8b661d8e241f5bfefb925b72..e98cb5a17177cd79f3c61f769a06c146f1d01e3d 100644
--- a/bob/paper/nir_patch_pooling/extractor/patch_pooling_cnn.py
+++ b/bob/paper/nir_patch_pooling/extractor/patch_pooling_cnn.py
@@ -1,10 +1,8 @@
-#!/usr/bin/env python2
 # -*- coding: utf-8 -*-
 
 """
-Implementation of PCNN feature extractor for LightCNN-9. 
+Implementation of the Patch Pooling CNN feature extractor with a LightCNN-9 backbone.
 @author: Ketan Kotwal
-
 """
 
 # Imports
@@ -21,14 +19,13 @@ import logging
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-#----------------------------------------------------------
+#------------------------------------------------------------------------------
 
 class PatchPoolingCNN(Extractor):
 
     """
-    The class implements the feature extraction of LightCNN9 embeddings.
-    It has some implementation differences from a similar extractor from
-    bob.learn.pytorch.
+    The class implements extraction of patch-pooled features from the final
+    convolutional layer (MFM5) of LightCNN-9.
     """
   
     def __init__(self, model_file=None, num_classes=79077):
@@ -39,30 +36,36 @@ class PatchPoolingCNN(Extractor):
         # load the model into network. 
         cp = torch.load(model_file, map_location="cpu")
       
-        # checked if pre-trained model was saved using nn.DataParallel ...
-        saved_with_nnDataParallel = False
+        # check whether the pre-trained model was saved using nn.DataParallel
+        saved_with_data_parallel = False
         for k, v in cp["state_dict"].items():
             if("module" in k):
-                saved_with_nnDataParallel = True
+                saved_with_data_parallel = True
                 break
  
         # if DataParallel format, remove module term
-        if(saved_with_nnDataParallel):
+        if(saved_with_data_parallel):
             if("state_dict" in cp):
+
                 from collections import OrderedDict
                 new_state_dict = OrderedDict()
+
                 for k, v in cp["state_dict"].items():
                     name = k[7:]
                     new_state_dict[name] = v
+
                 self.network.load_state_dict(new_state_dict)
         else:
+
             self.network.load_state_dict(cp["state_dict"])
+
         self.network.eval()
 
         # image pre-processing
-        self.data_transform = transforms.Compose([transforms.Resize(size=128), transforms.ToTensor()])
+        self.data_transform = transforms.Compose([transforms.Resize(size=128),
+            transforms.ToTensor()])
 
-#----------------------------------------------------------
+#------------------------------------------------------------------------------
 
     def __call__(self, image):
 
@@ -76,7 +79,7 @@ class PatchPoolingCNN(Extractor):
         Returns
         -------
         feature : :py:class:`numpy.ndarray` (floats)
-        The extracted features as a 1d array of size 320 
+        The extracted features as a 1d array of size 256
     
         """
   
@@ -84,45 +87,49 @@ class PatchPoolingCNN(Extractor):
         pil_image = Image.fromarray(image.astype(np.uint8))
         input_image = self.data_transform(pil_image)
         input_image = input_image.unsqueeze(0)
-    
-        # to be compliant with the loaded model, where weight and biases are torch.FloatTensor
         input_image = input_image.float()
 
+        # obtain the features (to be pooled) from a forward pass of the network
         _ , features = self.network.forward(Variable(input_image))
+
+        # pool features through patch-level processing
         features = self.conv_to_patch(features)
         features = features.data.numpy().flatten()
         return features.astype(np.float64)
 
-#----------------------------------------------------------
+#------------------------------------------------------------------------------
 
     def conv_to_patch(self, features):
 
-        logger.debug("Shape of input features: {} {}".format(features.shape, features.squeeze().shape))
-
         # parameters for the patch conversion
-        stride = 4 # orig:4 #feat.shape[2]/4  
-        idx = 0 # purely for debugging
-        num_patch = features.shape[2]/stride
-        feat_patch = torch.zeros(1, stride*stride*features.shape[1])
+        stride = 4 # features.shape[2]/4
+
+        # for debugging
+        # idx = 0
+        # num_patch = features.shape[2]/stride
+
+        pooled_features = torch.zeros(1, stride*stride*features.shape[1])
 
         # obtain patches by tesselation of feature maps
+        # pool the linearized version of each patch
         for i in range(0, features.shape[2], stride):
             for j in range(0, features.shape[3], stride):
-                feat_tmp = features[:, :, i:i+stride, j:j+stride]
-                feat_tmp = feat_tmp.contiguous().view(feat_tmp.size(0), -1)
-                feat_patch += feat_tmp
-                idx += 1
+                feat_temp = features[:, :, i:i+stride, j:j+stride]
+                feat_temp = feat_temp.contiguous().view(feat_temp.size(0), -1)
+                pooled_features += feat_temp
+                # idx += 1
 
-        # normalize the patch vector
-        feat_patch = feat_patch/stride/stride
-        logger.debug("Feat patch shape: {}".format(feat_patch.shape))
+        # normalize the vector of pooled features
+        pooled_features = pooled_features/stride/stride
+
+        return pooled_features
 
-        return feat_patch
+#------------------------------------------------------------------------------
 
-#----------------------------------------------------------
+#------------------------------------------------------------------------------
 
-# class LightCNN9Patch: it inherits the LightCNN-9 class, and returns the last
-# conv layer features; instead of embeddings.
+# class LightCNN9Patch: inherits the LightCNN-9 class from bob, and returns
+# the last conv layer features along with the embeddings.
 
 class LightCNN9Patch(LightCNN9):
 
@@ -131,7 +138,7 @@ class LightCNN9Patch(LightCNN9):
         # do not change the init
         super(LightCNN9Patch, self).__init__()    
 
-#----------------------------------------------------------
+#------------------------------------------------------------------------------
 
     def forward(self, x):
 
@@ -147,14 +154,7 @@ class LightCNN9Patch(LightCNN9):
         out = self.fc2(x)
         return out, conv_out
 
-#----------------------------------------------------------
-
-
-
-
-
-
-
+#------------------------------------------------------------------------------
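
Two pieces of the extractor above may benefit from standalone illustration. First, the checkpoint handling in __init__ strips the "module." prefix (7 characters, hence k[7:]) that nn.DataParallel prepends to parameter names. A minimal sketch, with a placeholder checkpoint path:

    from collections import OrderedDict
    import torch

    cp = torch.load("LightCNN_9Layers_checkpoint.pth.tar", map_location="cpu")
    state_dict = cp["state_dict"]
    if any(k.startswith("module.") for k in state_dict):
        # drop the 7-character "module." prefix added by nn.DataParallel
        state_dict = OrderedDict((k[7:], v) for k, v in state_dict.items())
    # network.load_state_dict(state_dict)

Second, conv_to_patch tessellates the final conv feature map into stride x stride tiles, flattens each tile, sums the flattened vectors, and divides by stride*stride, exactly as in the diff. A self-contained sketch on a dummy feature map (the channel count and spatial size here are illustrative, not the LightCNN-9 ones):

    import torch

    features = torch.randn(1, 8, 8, 8)  # (batch, channels, height, width)
    stride = 4                          # each patch is a stride x stride tile

    pooled = torch.zeros(1, stride * stride * features.shape[1])
    for i in range(0, features.shape[2], stride):
        for j in range(0, features.shape[3], stride):
            patch = features[:, :, i:i + stride, j:j + stride]
            pooled += patch.contiguous().view(patch.size(0), -1)

    pooled = pooled / stride / stride   # normalize the accumulated patches
    print(pooled.shape)                 # torch.Size([1, 128]) for these shapes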