From 07a96ada0656b7faf4d0dac615b6d7ee738b40ae Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Mon, 16 Mar 2020 12:08:25 +0100 Subject: [PATCH] [engine] Cast vector types to avoid issues with pytorch > 1.0 --- bob/ip/binseg/engine/inferencer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py index c1ed5faf..e76153c0 100644 --- a/bob/ip/binseg/engine/inferencer.py +++ b/bob/ip/binseg/engine/inferencer.py @@ -60,8 +60,8 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger): binary_pred = torch.gt(predictions[j], threshold).byte() # equals and not-equals - equals = torch.eq(binary_pred, gts) # tensor - notequals = torch.ne(binary_pred, gts) # tensor + equals = torch.eq(binary_pred, gts).type(torch.uint8) # tensor + notequals = torch.ne(binary_pred, gts).type(torch.uint8) # tensor # true positives tp_tensor = (gts * binary_pred ) # tensor @@ -76,7 +76,7 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger): tn_count = torch.sum(tn_tensor).item() # false negatives - fn_tensor = notequals - fp_tensor + fn_tensor = notequals - fp_tensor.type(torch.uint8) fn_count = torch.sum(fn_tensor).item() # calc metrics -- GitLab