From a15d2d27e9a21cd5f9885a6c509099fdab33e559 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Mon, 29 Jun 2020 14:53:44 +0200
Subject: [PATCH] [engine.evaluator] Fix function to calculate patch
 performance

---
 bob/ip/binseg/engine/evaluator.py | 35 ++++++++++++++++++++++++-------
 1 file changed, 28 insertions(+), 7 deletions(-)

diff --git a/bob/ip/binseg/engine/evaluator.py b/bob/ip/binseg/engine/evaluator.py
index 1b5be9f7..37fb4e2b 100644
--- a/bob/ip/binseg/engine/evaluator.py
+++ b/bob/ip/binseg/engine/evaluator.py
@@ -4,6 +4,7 @@
 """Defines functionality for the evaluation of predictions"""
 
 import os
+import itertools
 
 import PIL
 import numpy
@@ -11,6 +12,7 @@ import pandas
 from tqdm import tqdm
 
 import torch
+import torch.nn.functional
 import torchvision.transforms.functional as VF
 
 import h5py
@@ -184,13 +186,32 @@ def _patch_measures(pred, gt, steps, size):
 
     """
 
-    height, width, stride = window_size
-    pred_patches = pred.unfold(0, height, stride).unfold(1, width, stride)
-    gt_patches = unfold(0, height, stride).unfold(1, width, stride)
-
-    # add patch number for each set of measures
-    dfs = [_sample_measures(p, g, step) for p,g in zip(pred_patches, gt_patches)]
-    for i, k in enumerate(dfs): k['patch'] = i
+    height, width, stride = size
+
+    # we calculate the required padding so that the last windows on the left
+    # and bottom size of predictions/ground-truth data are zero padded, and
+    # torch unfolding works exactly.
+    padding = (0, 0)
+    rem = (pred.shape[1] - width) % stride
+    if rem != 0:
+        padding = (0, (stride-rem))
+    rem = (pred.shape[0] - height) % stride
+    if rem != 0:
+        padding += (0, (stride-rem))
+
+    pred_padded = torch.nn.functional.pad(pred, padding)
+    gt_padded = torch.nn.functional.pad(gt.squeeze(0), padding)
+
+    # this will create as many views as required
+    pred_patches = pred_padded.unfold(0, height, stride).unfold(1, width, stride)
+    gt_patches = gt_padded.unfold(0, height, stride).unfold(1, width, stride)
+    assert pred_patches.shape == gt_patches.shape
+    ylen, xlen, _, _ = pred_patches.shape
+
+    dfs = []
+    for j, i in itertools.product(range(ylen), range(xlen)):
+        dfs.append(_sample_measures(pred_patches[j,i,:,:], gt_patches[j,i,:,:], steps))
+        dfs[-1]['patch'] = i+(j*xlen)
 
     return pandas.concat(dfs, ignore_index=True)
 
-- 
GitLab