diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py index c1ed5fafa2aa20b46a749fcdca3f29a46a003455..e76153c045fc69cabffa1e5ea7d70efacbd510a6 100644 --- a/bob/ip/binseg/engine/inferencer.py +++ b/bob/ip/binseg/engine/inferencer.py @@ -60,8 +60,8 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger): binary_pred = torch.gt(predictions[j], threshold).byte() # equals and not-equals - equals = torch.eq(binary_pred, gts) # tensor - notequals = torch.ne(binary_pred, gts) # tensor + equals = torch.eq(binary_pred, gts).type(torch.uint8) # tensor + notequals = torch.ne(binary_pred, gts).type(torch.uint8) # tensor # true positives tp_tensor = (gts * binary_pred ) # tensor @@ -76,7 +76,7 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger): tn_count = torch.sum(tn_tensor).item() # false negatives - fn_tensor = notequals - fp_tensor + fn_tensor = notequals - fp_tensor.type(torch.uint8) fn_count = torch.sum(fn_tensor).item() # calc metrics