diff --git a/src/ptbench/engine/visualizer.py b/src/ptbench/engine/visualizer.py
index 593b19dc6b85e2e28e9285fb076b43b04b7157e8..87ceca5998b5bba39a8e38c7cd2cdf42e2c7763f 100644
--- a/src/ptbench/engine/visualizer.py
+++ b/src/ptbench/engine/visualizer.py
@@ -12,6 +12,8 @@ import pandas as pd
 from PIL import Image
 from tqdm import tqdm
 
+from torchvision.transforms.functional import to_pil_image, to_tensor
+
 from ..utils.cam_utils import (
     draw_boxes_on_image,
     draw_largest_component_bbox_on_image,
@@ -33,7 +35,6 @@ def _get_target_classes_from_directory(input_folder):
 
 def run(
     data_loader,
-    img_dir,
     input_folder,
     output_folder,
     dataset_split_name,
@@ -73,19 +74,6 @@ def run(
     output_folder : str
         Directory containing the images overlaid with heatmaps.
     """
-    assert (
-        input_folder != output_folder
-    ), "Output folder must not be the same as the input folder."
-
-    # Check that the output folder is not a subdirectory of the input_folder
-    assert not str(output_folder).startswith(
-        str(input_folder)
-    ), "Output folder must not be a subdirectory of the input folder."
-
-    output_folder = os.path.abspath(output_folder)
-
-    logger.info(f"Output folder: {output_folder}")
-    os.makedirs(output_folder, exist_ok=True)
 
     target_classes = _get_target_classes_from_directory(input_folder)
 
@@ -127,6 +115,8 @@ def run(
 
         names = samples[1]["name"]
 
+        images = samples[0]
+
         if (
             visualize_groundtruth
             and not includes_bboxes
@@ -143,7 +133,6 @@ def run(
             vis_directory_name = os.path.basename(input_folder)
             output_base_path = f"{output_folder}/{vis_directory_name}/{target_class}/{dataset_split_name}"
 
-            img_path = os.path.join(img_dir, names[0])
             saliency_map_name = names[0].rsplit(".", 1)[0] + ".npy"
             saliency_map_path = os.path.join(input_base_path, saliency_map_name)
 
@@ -153,8 +142,12 @@ def run(
 
             saliency_map = np.load(saliency_map_path)
 
+            # Make sure the image is RGB
+            pil_img = to_pil_image(images[0])
+            pil_img = pil_img.convert("RGB")
+
             # Image is expected to be of type float32
-            original_img = np.array(Image.open(img_path), dtype=np.float32)
+            original_img = np.array(pil_img, dtype=np.float32)
             img_with_boxes = original_img.copy()
 
             # Draw bounding boxes on the image
diff --git a/src/ptbench/scripts/visualize.py b/src/ptbench/scripts/visualize.py
index e26a2d464fed80534e0b9efb1c993ecc4d7ad481..79097c3ac2f4811edee77c8cc85a5f5b1999cd40 100644
--- a/src/ptbench/scripts/visualize.py
+++ b/src/ptbench/scripts/visualize.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import os
 import pathlib
 
 import click
@@ -120,7 +121,18 @@ def visualize(
     from ..engine.visualizer import run
     from .utils import save_sh_command
 
-    save_sh_command(input_folder / "command.sh")
+    assert (
+        input_folder != output_folder
+    ), "Output folder must not be the same as the input folder."
+
+    assert not str(output_folder).startswith(
+        str(input_folder)
+    ), "Output folder must not be a subdirectory of the input folder."
+
+    logger.info(f"Output folder: {output_folder}")
+    os.makedirs(output_folder, exist_ok=True)
+
+    save_sh_command(output_folder / "command.sh")
 
     datamodule.set_chunk_size(1, 1)
     datamodule.drop_incomplete_batch = False
@@ -134,22 +146,13 @@ def visualize(
     dataloaders = datamodule.predict_dataloader()
 
     for k, v in dataloaders.items():
-        # Hacky way to get img_dir
-        img_dir = datamodule.splits[k][0][1].datadir
-
-        if img_dir:
-            run(
-                data_loader=v,
-                img_dir=img_dir,
-                input_folder=input_folder,
-                output_folder=output_folder,
-                dataset_split_name=k,
-                visualize_groundtruth=visualize_groundtruth,
-                road_path=road_path,
-                visualize_detected_bbox=visualize_detected_bbox,
-                threshold=threshold,
-            )
-        else:
-            logger.warning(
-                'No "datadir" found in dataset. No visualizations can be generated.'
-            )
+        run(
+            data_loader=v,
+            input_folder=input_folder,
+            output_folder=output_folder,
+            dataset_split_name=k,
+            visualize_groundtruth=visualize_groundtruth,
+            road_path=road_path,
+            visualize_detected_bbox=visualize_detected_bbox,
+            threshold=threshold,
+        )