Skip to content
Snippets Groups Projects
Commit 4909ede1 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.predictor] Use non-blocking operation to predictor speed-up

parent 0f413ac4
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,7 @@ import h5py
from ..data.utils import overlayed_image
import logging
logger = logging.getLogger(__name__)
......@@ -41,10 +42,12 @@ def _save_hdf5(stem, prob, output_folder):
fullpath = os.path.join(output_folder, f"{stem}.hdf5")
tqdm.write(f"Saving {fullpath}...")
os.makedirs(os.path.dirname(fullpath), exist_ok=True)
with h5py.File(fullpath, 'w') as f:
with h5py.File(fullpath, "w") as f:
data = prob.cpu().squeeze(0).numpy()
f.create_dataset("array", data=data, compression="gzip",
compression_opts=9)
f.create_dataset(
"array", data=data, compression="gzip", compression_opts=9
)
def _save_image(stem, extension, data, output_folder):
"""Saves a PIL image into a file
......@@ -95,7 +98,7 @@ def _save_overlayed_png(stem, image, prob, output_folder):
image = VF.to_pil_image(image)
prob = VF.to_pil_image(prob.cpu())
_save_image(stem, '.png', overlayed_image(image, prob), output_folder)
_save_image(stem, ".png", overlayed_image(image, prob), output_folder)
def run(model, data_loader, device, output_folder, overlayed_folder):
......@@ -136,11 +139,13 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
len_samples = []
for samples in tqdm(
data_loader, desc="batches", leave=False, disable=None,
):
data_loader, desc="batches", leave=False, disable=None,
):
names = samples[0]
images = samples[1].to(device)
images = samples[1].to(
device=device, non_blocking=torch.cuda.is_available()
)
with torch.no_grad():
......@@ -170,5 +175,7 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
average_batch_time = numpy.mean(times)
logger.info(f"Average batch time: {average_batch_time:g}s")
average_image_time = numpy.sum(numpy.array(times) * len_samples) / float(sum(len_samples))
average_image_time = numpy.sum(numpy.array(times) * len_samples) / float(
sum(len_samples)
)
logger.info(f"Average image time: {average_image_time:g}s")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment