From 07a96ada0656b7faf4d0dac615b6d7ee738b40ae Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 16 Mar 2020 12:08:25 +0100
Subject: [PATCH] [engine] Cast vector types to avoid issues with pytorch > 1.0

---
 bob/ip/binseg/engine/inferencer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py
index c1ed5faf..e76153c0 100644
--- a/bob/ip/binseg/engine/inferencer.py
+++ b/bob/ip/binseg/engine/inferencer.py
@@ -60,8 +60,8 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
                 binary_pred = torch.gt(predictions[j], threshold).byte()
 
                 # equals and not-equals
-                equals = torch.eq(binary_pred, gts) # tensor
-                notequals = torch.ne(binary_pred, gts) # tensor
+                equals = torch.eq(binary_pred, gts).type(torch.uint8) # tensor
+                notequals = torch.ne(binary_pred, gts).type(torch.uint8) # tensor
 
                 # true positives
                 tp_tensor = (gts * binary_pred ) # tensor
@@ -76,7 +76,7 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
                 tn_count = torch.sum(tn_tensor).item()
 
                 # false negatives
-                fn_tensor = notequals - fp_tensor
+                fn_tensor = notequals - fp_tensor.type(torch.uint8)
                 fn_count = torch.sum(fn_tensor).item()
 
                 # calc metrics
-- 
GitLab