Skip to content
Snippets Groups Projects

Cast vector types to avoid issues with pytorch > 1.0

Merged André Anjos requested to merge inferencer-fix into master
1 file
+ 3
3
Compare changes
  • Side-by-side
  • Inline
@@ -60,8 +60,8 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
binary_pred = torch.gt(predictions[j], threshold).byte()
# equals and not-equals
equals = torch.eq(binary_pred, gts) # tensor
notequals = torch.ne(binary_pred, gts) # tensor
equals = torch.eq(binary_pred, gts).type(torch.uint8) # tensor
notequals = torch.ne(binary_pred, gts).type(torch.uint8) # tensor
# true positives
tp_tensor = (gts * binary_pred ) # tensor
@@ -76,7 +76,7 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
tn_count = torch.sum(tn_tensor).item()
# false negatives
fn_tensor = notequals - fp_tensor
fn_tensor = notequals - fp_tensor.type(torch.uint8)
fn_count = torch.sum(fn_tensor).item()
# calc metrics
Loading