From 4909ede13ebac1819947b7de7d94eec15f294e1f Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 13 May 2020 15:56:31 +0200 Subject: [PATCH] [engine.predictor] Use non-blocking operation to predictor speed-up --- bob/ip/binseg/engine/predictor.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py index de78e69f..89d24d08 100644 --- a/bob/ip/binseg/engine/predictor.py +++ b/bob/ip/binseg/engine/predictor.py @@ -17,6 +17,7 @@ import h5py from ..data.utils import overlayed_image import logging + logger = logging.getLogger(__name__) @@ -41,10 +42,12 @@ def _save_hdf5(stem, prob, output_folder): fullpath = os.path.join(output_folder, f"{stem}.hdf5") tqdm.write(f"Saving {fullpath}...") os.makedirs(os.path.dirname(fullpath), exist_ok=True) - with h5py.File(fullpath, 'w') as f: + with h5py.File(fullpath, "w") as f: data = prob.cpu().squeeze(0).numpy() - f.create_dataset("array", data=data, compression="gzip", - compression_opts=9) + f.create_dataset( + "array", data=data, compression="gzip", compression_opts=9 + ) + def _save_image(stem, extension, data, output_folder): """Saves a PIL image into a file @@ -95,7 +98,7 @@ def _save_overlayed_png(stem, image, prob, output_folder): image = VF.to_pil_image(image) prob = VF.to_pil_image(prob.cpu()) - _save_image(stem, '.png', overlayed_image(image, prob), output_folder) + _save_image(stem, ".png", overlayed_image(image, prob), output_folder) def run(model, data_loader, device, output_folder, overlayed_folder): @@ -136,11 +139,13 @@ def run(model, data_loader, device, output_folder, overlayed_folder): len_samples = [] for samples in tqdm( - data_loader, desc="batches", leave=False, disable=None, - ): + data_loader, desc="batches", leave=False, disable=None, + ): names = samples[0] - images = samples[1].to(device) + images = samples[1].to( + device=device, non_blocking=torch.cuda.is_available() + ) with torch.no_grad(): @@ -170,5 +175,7 @@ def run(model, data_loader, device, output_folder, overlayed_folder): average_batch_time = numpy.mean(times) logger.info(f"Average batch time: {average_batch_time:g}s") - average_image_time = numpy.sum(numpy.array(times) * len_samples) / float(sum(len_samples)) + average_image_time = numpy.sum(numpy.array(times) * len_samples) / float( + sum(len_samples) + ) logger.info(f"Average image time: {average_image_time:g}s") -- GitLab