diff --git a/bob/ip/binseg/engine/predicter.py b/bob/ip/binseg/engine/predicter.py index b6e8ad06da54b43a538cf4fc7805cc63e6966cee..ebd09ac5e84d4f9a81a20f72c3919f42071fb73d 100644 --- a/bob/ip/binseg/engine/predicter.py +++ b/bob/ip/binseg/engine/predicter.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import os +import os import logging import time import datetime @@ -24,7 +24,7 @@ def do_predict( """ Run inference and calculate metrics - + Parameters --------- model : :py:class:`torch.nn.Module` @@ -37,12 +37,12 @@ def do_predict( logger = logging.getLogger("bob.ip.binseg.engine.inference") logger.info("Start evaluation") logger.info("Output folder: {}, Device: {}".format(output_folder, device)) - results_subfolder = os.path.join(output_folder,'results') + results_subfolder = os.path.join(output_folder,'results') os.makedirs(results_subfolder,exist_ok=True) - + model.eval().to(device) - # Sigmoid for probabilities - sigmoid = torch.nn.Sigmoid() + # Sigmoid for probabilities + sigmoid = torch.nn.Sigmoid() # Setup timers start_total_time = time.time() @@ -55,24 +55,24 @@ def do_predict( start_time = time.perf_counter() outputs = model(images) - - # necessary check for hed architecture that uses several outputs + + # necessary check for hed architecture that uses several outputs # for loss calculation instead of just the last concatfuse block if isinstance(outputs,list): outputs = outputs[-1] - + probabilities = sigmoid(outputs) - + batch_time = time.perf_counter() - start_time times.append(batch_time) logger.info("Batch time: {:.5f} s".format(batch_time)) - + # Create probability images save_probability_images(probabilities, names, output_folder, logger) # Save hdf5 save_hdf(probabilities, names, output_folder, logger) - + # Report times total_inference_time = str(datetime.timedelta(seconds=int(sum(times)))) average_batch_inference_time = np.mean(times) @@ -82,7 +82,7 @@ def do_predict( times_file = "Times.txt" logger.info("saving {}".format(times_file)) - + with open (os.path.join(results_subfolder,times_file), "w+") as outfile: date = datetime.datetime.now() outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S")))