From 4d1d4867f1fe829929105f789e90ba1ba10138f8 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.anjos@idiap.ch> Date: Thu, 14 Nov 2019 15:36:14 +0100 Subject: [PATCH] [engine/predicter] spaces --- bob/ip/binseg/engine/predicter.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/bob/ip/binseg/engine/predicter.py b/bob/ip/binseg/engine/predicter.py index b6e8ad06..ebd09ac5 100644 --- a/bob/ip/binseg/engine/predicter.py +++ b/bob/ip/binseg/engine/predicter.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import os +import os import logging import time import datetime @@ -24,7 +24,7 @@ def do_predict( """ Run inference and calculate metrics - + Parameters --------- model : :py:class:`torch.nn.Module` @@ -37,12 +37,12 @@ def do_predict( logger = logging.getLogger("bob.ip.binseg.engine.inference") logger.info("Start evaluation") logger.info("Output folder: {}, Device: {}".format(output_folder, device)) - results_subfolder = os.path.join(output_folder,'results') + results_subfolder = os.path.join(output_folder,'results') os.makedirs(results_subfolder,exist_ok=True) - + model.eval().to(device) - # Sigmoid for probabilities - sigmoid = torch.nn.Sigmoid() + # Sigmoid for probabilities + sigmoid = torch.nn.Sigmoid() # Setup timers start_total_time = time.time() @@ -55,24 +55,24 @@ def do_predict( start_time = time.perf_counter() outputs = model(images) - - # necessary check for hed architecture that uses several outputs + + # necessary check for hed architecture that uses several outputs # for loss calculation instead of just the last concatfuse block if isinstance(outputs,list): outputs = outputs[-1] - + probabilities = sigmoid(outputs) - + batch_time = time.perf_counter() - start_time times.append(batch_time) logger.info("Batch time: {:.5f} s".format(batch_time)) - + # Create probability images save_probability_images(probabilities, names, output_folder, logger) # Save hdf5 save_hdf(probabilities, names, output_folder, logger) - + # Report times total_inference_time = str(datetime.timedelta(seconds=int(sum(times)))) average_batch_inference_time = np.mean(times) @@ -82,7 +82,7 @@ def do_predict( times_file = "Times.txt" logger.info("saving {}".format(times_file)) - + with open (os.path.join(results_subfolder,times_file), "w+") as outfile: date = datetime.datetime.now() outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S"))) -- GitLab