diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py
index c1ed5fafa2aa20b46a749fcdca3f29a46a003455..e76153c045fc69cabffa1e5ea7d70efacbd510a6 100644
--- a/bob/ip/binseg/engine/inferencer.py
+++ b/bob/ip/binseg/engine/inferencer.py
@@ -60,8 +60,8 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
                 binary_pred = torch.gt(predictions[j], threshold).byte()
 
                 # equals and not-equals
-                equals = torch.eq(binary_pred, gts) # tensor
-                notequals = torch.ne(binary_pred, gts) # tensor
+                equals = torch.eq(binary_pred, gts).type(torch.uint8) # tensor
+                notequals = torch.ne(binary_pred, gts).type(torch.uint8) # tensor
 
                 # true positives
                 tp_tensor = (gts * binary_pred ) # tensor
@@ -76,7 +76,7 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
                 tn_count = torch.sum(tn_tensor).item()
 
                 # false negatives
-                fn_tensor = notequals - fp_tensor
+                fn_tensor = notequals - fp_tensor.type(torch.uint8)
                 fn_count = torch.sum(fn_tensor).item()
 
                 # calc metrics