#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import logging
import time
import datetime

import numpy as np
import torch
import pandas as pd
import torchvision.transforms.functional as VF
from tqdm import tqdm

from bob.ip.binseg.utils.metric import SmoothedValue, base_metrics
from bob.ip.binseg.utils.plot import precision_recall_f1iso
from bob.ip.binseg.utils.summary import summary


def batch_metrics(predictions, ground_truths, names, output_folder, logger):
    """
    Calculates metrics on the batch and saves them to disk

    Parameters
    ----------
    predictions : :py:class:`torch.Tensor`
        tensor with pixel-wise probabilities
    ground_truths : :py:class:`torch.Tensor`
        tensor with binary ground-truth
    names : list
        list of file names
    output_folder : str
        output path
    logger : :py:class:`logging.Logger`
        python logger

    Returns
    -------
    list
        list containing batch metrics:
        ``[name, threshold, precision, recall, specificity, accuracy, jaccard, f1_score]``
    """
    step_size = 0.01
    metrics = []

    for j in range(predictions.size(0)):
        # ground truth as byte tensor
        gts = ground_truths[j].byte()

        file_name = "{}.csv".format(names[j])
        logger.info("saving {}".format(file_name))

        with open(os.path.join(output_folder, file_name), "w+") as outfile:
            outfile.write("threshold, precision, recall, specificity, accuracy, jaccard, f1_score\n")

            for threshold in np.arange(0.0, 1.0, step_size):
                # threshold the probabilities to get a binary prediction
                binary_pred = torch.gt(predictions[j], threshold).byte()

                # equals and not-equals, cast to byte so the subtractions
                # below also work on recent PyTorch versions where the
                # comparisons return bool tensors
                equals = torch.eq(binary_pred, gts).byte()  # TP + TN
                notequals = torch.ne(binary_pred, gts).byte()  # FP + FN

                # true positives
                tp_tensor = gts * binary_pred  # tensor
                tp_count = torch.sum(tp_tensor).item()  # scalar

                # false positives
                fp_tensor = torch.eq(binary_pred + tp_tensor, 1).byte()
                fp_count = torch.sum(fp_tensor).item()

                # true negatives
                tn_tensor = equals - tp_tensor
                tn_count = torch.sum(tn_tensor).item()

                # false negatives
                fn_tensor = notequals - fp_tensor
                fn_count = torch.sum(fn_tensor).item()

                # calculate metrics from the confusion-matrix counts
                metrics_values = base_metrics(tp_count, fp_count, tn_count, fn_count)

                # write to disk
                outfile.write("{:.2f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f}\n".format(threshold, *metrics_values))

                metrics.append([names[j], threshold, *metrics_values])

    return metrics


def save_probability_images(predictions, names, output_folder, logger):
    """
    Saves probability maps as images in the same format as the test images

    Parameters
    ----------
    predictions : :py:class:`torch.Tensor`
        tensor with pixel-wise probabilities
    names : list
        list of file names
    output_folder : str
        output path
    logger : :py:class:`logging.Logger`
        python logger
    """
    images_subfolder = os.path.join(output_folder, 'images')
    if not os.path.exists(images_subfolder):
        os.makedirs(images_subfolder)

    for j in range(predictions.size(0)):
        img = VF.to_pil_image(predictions.cpu().data[j])
        filename = '{}'.format(names[j])
        logger.info("saving {}".format(filename))
        img.save(os.path.join(images_subfolder, filename))


def do_inference(
    model,
    data_loader,
    device,
    output_folder=None
):
    """
    Run inference and calculate metrics

    Parameters
    ----------
    model : :py:class:`torch.nn.Module`
        neural network model (e.g. DRIU, HED, UNet)
    data_loader : :py:class:`torch.utils.data.DataLoader`
        data loader for the evaluation split
    device : str
        device to use, ``'cpu'`` or ``'cuda'``
    output_folder : str
        output path
    """
    logger = logging.getLogger("bob.ip.binseg.engine.inference")
    logger.info("Start evaluation")
    logger.info("Split: {}, Output folder: {}, Device: {}".format(data_loader.dataset.split, output_folder, device))

    results_subfolder = os.path.join(output_folder, 'results')
    if not os.path.exists(results_subfolder):
        os.makedirs(results_subfolder)

    model.eval().to(device)
    # Sigmoid for probabilities
    sigmoid = torch.nn.Sigmoid()

    # Setup timers
    start_total_time = time.time()
    times = []

    # Collect overall metrics
    metrics = []

    for samples in tqdm(data_loader):
        names = samples[0]
        images = samples[1].to(device)
        ground_truths = samples[2].to(device)

        with torch.no_grad():
            start_time = time.perf_counter()
            outputs = model(images)

            # necessary check for the HED architecture, which returns several
            # outputs for loss calculation instead of just the last
            # concat-fuse block
            if isinstance(outputs, list):
                outputs = outputs[-1]

            probabilities = sigmoid(outputs)

            batch_time = time.perf_counter() - start_time
            times.append(batch_time)
            logger.info("Batch time: {:.5f} s".format(batch_time))

            b_metrics = batch_metrics(probabilities, ground_truths, names, results_subfolder, logger)
            metrics.extend(b_metrics)

            # Create probability images
            save_probability_images(probabilities, names, output_folder, logger)

    # DataFrame with one row per image and threshold
    df_metrics = pd.DataFrame(
        metrics,
        columns=["name", "threshold", "precision", "recall",
                 "specificity", "accuracy", "jaccard", "f1_score"])

    # Report and averages
    metrics_file = "Metrics.csv"
    metrics_path = os.path.join(results_subfolder, metrics_file)
    logger.info("Saving average over all input images: {}".format(metrics_file))

    # average over all images per threshold; the non-numeric "name" column
    # is excluded before averaging
    avg_metrics = df_metrics.drop(columns="name").groupby("threshold").mean()

    # F1-score recomputed from the averaged precision and recall
    avg_metrics["f1_score"] = (2 * avg_metrics["precision"] * avg_metrics["recall"]) / (
        avg_metrics["precision"] + avg_metrics["recall"])

    avg_metrics.to_csv(metrics_path)

    maxf1 = avg_metrics["f1_score"].max()
    optimal_f1_threshold = avg_metrics["f1_score"].idxmax()
    logger.info("Highest F1-score of {:.5f}, achieved at threshold {}".format(maxf1, optimal_f1_threshold))

    # Plotting
    np_avg_metrics = avg_metrics.to_numpy().T
    fig_name = "precision_recall.pdf"
    logger.info("saving {}".format(fig_name))
    fig = precision_recall_f1iso([np_avg_metrics[0]], [np_avg_metrics[1]], [model.name, None], title=output_folder)
    fig_filename = os.path.join(results_subfolder, fig_name)
    fig.savefig(fig_filename)

    # Report times
    total_inference_time = str(datetime.timedelta(seconds=int(sum(times))))
    average_batch_inference_time = np.mean(times)
    total_evaluation_time = str(datetime.timedelta(seconds=int(time.time() - start_total_time)))

    logger.info("Average batch inference time: {:.5f}s".format(average_batch_inference_time))

    times_file = "Times.txt"
    logger.info("saving {}".format(times_file))

    with open(os.path.join(results_subfolder, times_file), "w+") as outfile:
        date = datetime.datetime.now()
        outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S")))
        outfile.write("Total evaluation run-time: {} \n".format(total_evaluation_time))
        outfile.write("Average batch inference time: {} \n".format(average_batch_inference_time))
        outfile.write("Total inference time: {} \n".format(total_inference_time))

    # Save model summary
    summary_file = 'ModelSummary.txt'
    logger.info("saving {}".format(summary_file))

    with open(os.path.join(results_subfolder, summary_file), "w+") as outfile:
        summary(model, outfile)
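

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal, hypothetical sketch; not part of the
# original engine).  ``do_inference`` only assumes a model exposing a
# ``name`` attribute that maps image batches to logits, and a DataLoader
# whose dataset has a ``split`` attribute and yields
# ``(name, image, ground_truth)`` samples.  The toy model and dataset below
# are stand-ins used purely for demonstration.
if __name__ == "__main__":
    import torch.utils.data

    logging.basicConfig(level=logging.INFO)

    class _ToyModel(torch.nn.Module):
        """Hypothetical single-convolution 'segmentation' model."""

        name = "toy"

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 1, kernel_size=1)

        def forward(self, x):
            # returns logits; do_inference applies the sigmoid itself
            return self.conv(x)

    class _ToyDataset(torch.utils.data.Dataset):
        """Hypothetical dataset yielding (name, image, ground_truth)."""

        split = "test"

        def __len__(self):
            return 2

        def __getitem__(self, index):
            image = torch.rand(3, 64, 64)
            ground_truth = (torch.rand(1, 64, 64) > 0.5).float()
            return "sample_{}.png".format(index), image, ground_truth

    loader = torch.utils.data.DataLoader(_ToyDataset(), batch_size=2)
    do_inference(_ToyModel(), loader, device="cpu", output_folder="./toy_output")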