diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py index de78e69f5ba071d017babd5f7b3654244e64bb80..89d24d089d7f4eeece5d3974fe5e11ab90270eb0 100644 --- a/bob/ip/binseg/engine/predictor.py +++ b/bob/ip/binseg/engine/predictor.py @@ -17,6 +17,7 @@ import h5py from ..data.utils import overlayed_image import logging + logger = logging.getLogger(__name__) @@ -41,10 +42,12 @@ def _save_hdf5(stem, prob, output_folder): fullpath = os.path.join(output_folder, f"{stem}.hdf5") tqdm.write(f"Saving {fullpath}...") os.makedirs(os.path.dirname(fullpath), exist_ok=True) - with h5py.File(fullpath, 'w') as f: + with h5py.File(fullpath, "w") as f: data = prob.cpu().squeeze(0).numpy() - f.create_dataset("array", data=data, compression="gzip", - compression_opts=9) + f.create_dataset( + "array", data=data, compression="gzip", compression_opts=9 + ) + def _save_image(stem, extension, data, output_folder): """Saves a PIL image into a file @@ -95,7 +98,7 @@ def _save_overlayed_png(stem, image, prob, output_folder): image = VF.to_pil_image(image) prob = VF.to_pil_image(prob.cpu()) - _save_image(stem, '.png', overlayed_image(image, prob), output_folder) + _save_image(stem, ".png", overlayed_image(image, prob), output_folder) def run(model, data_loader, device, output_folder, overlayed_folder): @@ -136,11 +139,13 @@ def run(model, data_loader, device, output_folder, overlayed_folder): len_samples = [] for samples in tqdm( - data_loader, desc="batches", leave=False, disable=None, - ): + data_loader, desc="batches", leave=False, disable=None, + ): names = samples[0] - images = samples[1].to(device) + images = samples[1].to( + device=device, non_blocking=torch.cuda.is_available() + ) with torch.no_grad(): @@ -170,5 +175,7 @@ def run(model, data_loader, device, output_folder, overlayed_folder): average_batch_time = numpy.mean(times) logger.info(f"Average batch time: {average_batch_time:g}s") - average_image_time = numpy.sum(numpy.array(times) * len_samples) / float(sum(len_samples)) + average_image_time = numpy.sum(numpy.array(times) * len_samples) / float( + sum(len_samples) + ) logger.info(f"Average image time: {average_image_time:g}s")