Skip to content
Snippets Groups Projects
Commit c96f2af4 authored by ogueler@idiap.ch's avatar ogueler@idiap.ch
Browse files

removed reliance on image directory

parent 01d43650
No related branches found
No related tags found
No related merge requests found
Pipeline #78113 failed
......@@ -12,6 +12,8 @@ import pandas as pd
from PIL import Image
from tqdm import tqdm
from torchvision.transforms.functional import to_pil_image, to_tensor
from ..utils.cam_utils import (
draw_boxes_on_image,
draw_largest_component_bbox_on_image,
......@@ -33,7 +35,6 @@ def _get_target_classes_from_directory(input_folder):
def run(
data_loader,
img_dir,
input_folder,
output_folder,
dataset_split_name,
......@@ -73,19 +74,6 @@ def run(
output_folder : str
Directory containing the images overlaid with heatmaps.
"""
assert (
input_folder != output_folder
), "Output folder must not be the same as the input folder."
# Check that the output folder is not a subdirectory of the input_folder
assert not str(output_folder).startswith(
str(input_folder)
), "Output folder must not be a subdirectory of the input folder."
output_folder = os.path.abspath(output_folder)
logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True)
target_classes = _get_target_classes_from_directory(input_folder)
......@@ -127,6 +115,8 @@ def run(
names = samples[1]["name"]
images = samples[0]
if (
visualize_groundtruth
and not includes_bboxes
......@@ -143,7 +133,6 @@ def run(
vis_directory_name = os.path.basename(input_folder)
output_base_path = f"{output_folder}/{vis_directory_name}/{target_class}/{dataset_split_name}"
img_path = os.path.join(img_dir, names[0])
saliency_map_name = names[0].rsplit(".", 1)[0] + ".npy"
saliency_map_path = os.path.join(input_base_path, saliency_map_name)
......@@ -153,8 +142,12 @@ def run(
saliency_map = np.load(saliency_map_path)
# Make sure the image is RGB
pil_img = to_pil_image(images[0])
pil_img = pil_img.convert("RGB")
# Image is expected to be of type float32
original_img = np.array(Image.open(img_path), dtype=np.float32)
original_img = np.array(pil_img, dtype=np.float32)
img_with_boxes = original_img.copy()
# Draw bounding boxes on the image
......
......@@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
import pathlib
import click
......@@ -120,7 +121,18 @@ def visualize(
from ..engine.visualizer import run
from .utils import save_sh_command
save_sh_command(input_folder / "command.sh")
assert (
input_folder != output_folder
), "Output folder must not be the same as the input folder."
assert not str(output_folder).startswith(
str(input_folder)
), "Output folder must not be a subdirectory of the input folder."
logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True)
save_sh_command(output_folder / "command.sh")
datamodule.set_chunk_size(1, 1)
datamodule.drop_incomplete_batch = False
......@@ -134,22 +146,13 @@ def visualize(
dataloaders = datamodule.predict_dataloader()
for k, v in dataloaders.items():
# Hacky way to get img_dir
img_dir = datamodule.splits[k][0][1].datadir
if img_dir:
run(
data_loader=v,
img_dir=img_dir,
input_folder=input_folder,
output_folder=output_folder,
dataset_split_name=k,
visualize_groundtruth=visualize_groundtruth,
road_path=road_path,
visualize_detected_bbox=visualize_detected_bbox,
threshold=threshold,
)
else:
logger.warning(
'No "datadir" found in dataset. No visualizations can be generated.'
)
run(
data_loader=v,
input_folder=input_folder,
output_folder=output_folder,
dataset_split_name=k,
visualize_groundtruth=visualize_groundtruth,
road_path=road_path,
visualize_detected_bbox=visualize_detected_bbox,
threshold=threshold,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment