Skip to content
Snippets Groups Projects
Commit cd85bb03 authored by ogueler@idiap.ch's avatar ogueler@idiap.ch Committed by Daniel CARRON
Browse files

removed reliance on image directory

parent 95b658e8
No related branches found
No related tags found
1 merge request!12Adds grad-cam support on classifiers
...@@ -12,6 +12,8 @@ import pandas as pd ...@@ -12,6 +12,8 @@ import pandas as pd
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from torchvision.transforms.functional import to_pil_image, to_tensor
from ..utils.cam_utils import ( from ..utils.cam_utils import (
draw_boxes_on_image, draw_boxes_on_image,
draw_largest_component_bbox_on_image, draw_largest_component_bbox_on_image,
...@@ -33,7 +35,6 @@ def _get_target_classes_from_directory(input_folder): ...@@ -33,7 +35,6 @@ def _get_target_classes_from_directory(input_folder):
def run( def run(
data_loader, data_loader,
img_dir,
input_folder, input_folder,
output_folder, output_folder,
dataset_split_name, dataset_split_name,
...@@ -73,19 +74,6 @@ def run( ...@@ -73,19 +74,6 @@ def run(
output_folder : str output_folder : str
Directory containing the images overlaid with heatmaps. Directory containing the images overlaid with heatmaps.
""" """
assert (
input_folder != output_folder
), "Output folder must not be the same as the input folder."
# Check that the output folder is not a subdirectory of the input_folder
assert not str(output_folder).startswith(
str(input_folder)
), "Output folder must not be a subdirectory of the input folder."
output_folder = os.path.abspath(output_folder)
logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True)
target_classes = _get_target_classes_from_directory(input_folder) target_classes = _get_target_classes_from_directory(input_folder)
...@@ -127,6 +115,8 @@ def run( ...@@ -127,6 +115,8 @@ def run(
names = samples[1]["name"] names = samples[1]["name"]
images = samples[0]
if ( if (
visualize_groundtruth visualize_groundtruth
and not includes_bboxes and not includes_bboxes
...@@ -143,7 +133,6 @@ def run( ...@@ -143,7 +133,6 @@ def run(
vis_directory_name = os.path.basename(input_folder) vis_directory_name = os.path.basename(input_folder)
output_base_path = f"{output_folder}/{vis_directory_name}/{target_class}/{dataset_split_name}" output_base_path = f"{output_folder}/{vis_directory_name}/{target_class}/{dataset_split_name}"
img_path = os.path.join(img_dir, names[0])
saliency_map_name = names[0].rsplit(".", 1)[0] + ".npy" saliency_map_name = names[0].rsplit(".", 1)[0] + ".npy"
saliency_map_path = os.path.join(input_base_path, saliency_map_name) saliency_map_path = os.path.join(input_base_path, saliency_map_name)
...@@ -153,8 +142,12 @@ def run( ...@@ -153,8 +142,12 @@ def run(
saliency_map = np.load(saliency_map_path) saliency_map = np.load(saliency_map_path)
# Make sure the image is RGB
pil_img = to_pil_image(images[0])
pil_img = pil_img.convert("RGB")
# Image is expected to be of type float32 # Image is expected to be of type float32
original_img = np.array(Image.open(img_path), dtype=np.float32) original_img = np.array(pil_img, dtype=np.float32)
img_with_boxes = original_img.copy() img_with_boxes = original_img.copy()
# Draw bounding boxes on the image # Draw bounding boxes on the image
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
import os
import pathlib import pathlib
import click import click
...@@ -120,7 +121,18 @@ def visualize( ...@@ -120,7 +121,18 @@ def visualize(
from ..engine.visualizer import run from ..engine.visualizer import run
from .utils import save_sh_command from .utils import save_sh_command
save_sh_command(input_folder / "command.sh") assert (
input_folder != output_folder
), "Output folder must not be the same as the input folder."
assert not str(output_folder).startswith(
str(input_folder)
), "Output folder must not be a subdirectory of the input folder."
logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True)
save_sh_command(output_folder / "command.sh")
datamodule.set_chunk_size(1, 1) datamodule.set_chunk_size(1, 1)
datamodule.drop_incomplete_batch = False datamodule.drop_incomplete_batch = False
...@@ -134,22 +146,13 @@ def visualize( ...@@ -134,22 +146,13 @@ def visualize(
dataloaders = datamodule.predict_dataloader() dataloaders = datamodule.predict_dataloader()
for k, v in dataloaders.items(): for k, v in dataloaders.items():
# Hacky way to get img_dir run(
img_dir = datamodule.splits[k][0][1].datadir data_loader=v,
input_folder=input_folder,
if img_dir: output_folder=output_folder,
run( dataset_split_name=k,
data_loader=v, visualize_groundtruth=visualize_groundtruth,
img_dir=img_dir, road_path=road_path,
input_folder=input_folder, visualize_detected_bbox=visualize_detected_bbox,
output_folder=output_folder, threshold=threshold,
dataset_split_name=k, )
visualize_groundtruth=visualize_groundtruth,
road_path=road_path,
visualize_detected_bbox=visualize_detected_bbox,
threshold=threshold,
)
else:
logger.warning(
'No "datadir" found in dataset. No visualizations can be generated.'
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment