diff --git a/src/ptbench/engine/visualizer.py b/src/ptbench/engine/visualizer.py index 593b19dc6b85e2e28e9285fb076b43b04b7157e8..87ceca5998b5bba39a8e38c7cd2cdf42e2c7763f 100644 --- a/src/ptbench/engine/visualizer.py +++ b/src/ptbench/engine/visualizer.py @@ -12,6 +12,8 @@ import pandas as pd from PIL import Image from tqdm import tqdm +from torchvision.transforms.functional import to_pil_image, to_tensor + from ..utils.cam_utils import ( draw_boxes_on_image, draw_largest_component_bbox_on_image, @@ -33,7 +35,6 @@ def _get_target_classes_from_directory(input_folder): def run( data_loader, - img_dir, input_folder, output_folder, dataset_split_name, @@ -73,19 +74,6 @@ def run( output_folder : str Directory containing the images overlaid with heatmaps. """ - assert ( - input_folder != output_folder - ), "Output folder must not be the same as the input folder." - - # Check that the output folder is not a subdirectory of the input_folder - assert not str(output_folder).startswith( - str(input_folder) - ), "Output folder must not be a subdirectory of the input folder." - - output_folder = os.path.abspath(output_folder) - - logger.info(f"Output folder: {output_folder}") - os.makedirs(output_folder, exist_ok=True) target_classes = _get_target_classes_from_directory(input_folder) @@ -127,6 +115,8 @@ def run( names = samples[1]["name"] + images = samples[0] + if ( visualize_groundtruth and not includes_bboxes @@ -143,7 +133,6 @@ def run( vis_directory_name = os.path.basename(input_folder) output_base_path = f"{output_folder}/{vis_directory_name}/{target_class}/{dataset_split_name}" - img_path = os.path.join(img_dir, names[0]) saliency_map_name = names[0].rsplit(".", 1)[0] + ".npy" saliency_map_path = os.path.join(input_base_path, saliency_map_name) @@ -153,8 +142,12 @@ def run( saliency_map = np.load(saliency_map_path) + # Make sure the image is RGB + pil_img = to_pil_image(images[0]) + pil_img = pil_img.convert("RGB") + # Image is expected to be of type float32 - original_img = np.array(Image.open(img_path), dtype=np.float32) + original_img = np.array(pil_img, dtype=np.float32) img_with_boxes = original_img.copy() # Draw bounding boxes on the image diff --git a/src/ptbench/scripts/visualize.py b/src/ptbench/scripts/visualize.py index e26a2d464fed80534e0b9efb1c993ecc4d7ad481..79097c3ac2f4811edee77c8cc85a5f5b1999cd40 100644 --- a/src/ptbench/scripts/visualize.py +++ b/src/ptbench/scripts/visualize.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import os import pathlib import click @@ -120,7 +121,18 @@ def visualize( from ..engine.visualizer import run from .utils import save_sh_command - save_sh_command(input_folder / "command.sh") + assert ( + input_folder != output_folder + ), "Output folder must not be the same as the input folder." + + assert not str(output_folder).startswith( + str(input_folder) + ), "Output folder must not be a subdirectory of the input folder." + + logger.info(f"Output folder: {output_folder}") + os.makedirs(output_folder, exist_ok=True) + + save_sh_command(output_folder / "command.sh") datamodule.set_chunk_size(1, 1) datamodule.drop_incomplete_batch = False @@ -134,22 +146,13 @@ def visualize( dataloaders = datamodule.predict_dataloader() for k, v in dataloaders.items(): - # Hacky way to get img_dir - img_dir = datamodule.splits[k][0][1].datadir - - if img_dir: - run( - data_loader=v, - img_dir=img_dir, - input_folder=input_folder, - output_folder=output_folder, - dataset_split_name=k, - visualize_groundtruth=visualize_groundtruth, - road_path=road_path, - visualize_detected_bbox=visualize_detected_bbox, - threshold=threshold, - ) - else: - logger.warning( - 'No "datadir" found in dataset. No visualizations can be generated.' - ) + run( + data_loader=v, + input_folder=input_folder, + output_folder=output_folder, + dataset_split_name=k, + visualize_groundtruth=visualize_groundtruth, + road_path=road_path, + visualize_detected_bbox=visualize_detected_bbox, + threshold=threshold, + )