Skip to content
Snippets Groups Projects
Commit f2f202e7 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Merge branch 'inferencer-fix' into 'master'

Cast vector types to avoid issues with pytorch > 1.0

See merge request bob/bob.ip.binseg!10
parents 6ba55d93 07a96ada
No related branches found
No related tags found
1 merge request!10Cast vector types to avoid issues with pytorch > 1.0
Pipeline #38114 passed
...@@ -60,8 +60,8 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger): ...@@ -60,8 +60,8 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
binary_pred = torch.gt(predictions[j], threshold).byte() binary_pred = torch.gt(predictions[j], threshold).byte()
# equals and not-equals # equals and not-equals
equals = torch.eq(binary_pred, gts) # tensor equals = torch.eq(binary_pred, gts).type(torch.uint8) # tensor
notequals = torch.ne(binary_pred, gts) # tensor notequals = torch.ne(binary_pred, gts).type(torch.uint8) # tensor
# true positives # true positives
tp_tensor = (gts * binary_pred ) # tensor tp_tensor = (gts * binary_pred ) # tensor
...@@ -76,7 +76,7 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger): ...@@ -76,7 +76,7 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
tn_count = torch.sum(tn_tensor).item() tn_count = torch.sum(tn_tensor).item()
# false negatives # false negatives
fn_tensor = notequals - fp_tensor fn_tensor = notequals - fp_tensor.type(torch.uint8)
fn_count = torch.sum(fn_tensor).item() fn_count = torch.sum(fn_tensor).item()
# calc metrics # calc metrics
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment