Skip to content
Snippets Groups Projects
Commit 4d1d4867 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine/predicter] spaces

parent f9e14859
No related branches found
No related tags found
1 merge request!9Minor fixes
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import os
import logging
import time
import datetime
......@@ -24,7 +24,7 @@ def do_predict(
"""
Run inference and calculate metrics
Parameters
---------
model : :py:class:`torch.nn.Module`
......@@ -37,12 +37,12 @@ def do_predict(
logger = logging.getLogger("bob.ip.binseg.engine.inference")
logger.info("Start evaluation")
logger.info("Output folder: {}, Device: {}".format(output_folder, device))
results_subfolder = os.path.join(output_folder,'results')
results_subfolder = os.path.join(output_folder,'results')
os.makedirs(results_subfolder,exist_ok=True)
model.eval().to(device)
# Sigmoid for probabilities
sigmoid = torch.nn.Sigmoid()
# Sigmoid for probabilities
sigmoid = torch.nn.Sigmoid()
# Setup timers
start_total_time = time.time()
......@@ -55,24 +55,24 @@ def do_predict(
start_time = time.perf_counter()
outputs = model(images)
# necessary check for hed architecture that uses several outputs
# necessary check for hed architecture that uses several outputs
# for loss calculation instead of just the last concatfuse block
if isinstance(outputs,list):
outputs = outputs[-1]
probabilities = sigmoid(outputs)
batch_time = time.perf_counter() - start_time
times.append(batch_time)
logger.info("Batch time: {:.5f} s".format(batch_time))
# Create probability images
save_probability_images(probabilities, names, output_folder, logger)
# Save hdf5
save_hdf(probabilities, names, output_folder, logger)
# Report times
total_inference_time = str(datetime.timedelta(seconds=int(sum(times))))
average_batch_inference_time = np.mean(times)
......@@ -82,7 +82,7 @@ def do_predict(
times_file = "Times.txt"
logger.info("saving {}".format(times_file))
with open (os.path.join(results_subfolder,times_file), "w+") as outfile:
date = datetime.datetime.now()
outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S")))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment