Skip to content
Snippets Groups Projects
Commit 07a96ada authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine] Cast vector types to avoid issues with pytorch > 1.0

parent 6ba55d93
No related branches found
No related tags found
1 merge request!10Cast vector types to avoid issues with pytorch > 1.0
Pipeline #38113 passed
......@@ -60,8 +60,8 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
binary_pred = torch.gt(predictions[j], threshold).byte()
# equals and not-equals
equals = torch.eq(binary_pred, gts) # tensor
notequals = torch.ne(binary_pred, gts) # tensor
equals = torch.eq(binary_pred, gts).type(torch.uint8) # tensor
notequals = torch.ne(binary_pred, gts).type(torch.uint8) # tensor
# true positives
tp_tensor = (gts * binary_pred ) # tensor
......@@ -76,7 +76,7 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
tn_count = torch.sum(tn_tensor).item()
# false negatives
fn_tensor = notequals - fp_tensor
fn_tensor = notequals - fp_tensor.type(torch.uint8)
fn_count = torch.sum(fn_tensor).item()
# calc metrics
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment