diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py index c9cacaa95c3108dae172e85247f698bb49bbec5f..d57a974d5528fe7aec81926c4d6d9e8e88822630 100644 --- a/src/ptbench/scripts/cli.py +++ b/src/ptbench/scripts/cli.py @@ -8,7 +8,7 @@ from clapper.click import AliasedGroup from . import ( calculate_road, - comparevis, + compare_vis, config, database, evaluate, @@ -33,7 +33,7 @@ def cli(): cli.add_command(calculate_road.calculate_road) -cli.add_command(comparevis.comparevis) +cli.add_command(compare_vis.compare_vis) cli.add_command(config.config) cli.add_command(database.database) cli.add_command(evaluate.evaluate) diff --git a/src/ptbench/scripts/comparevis.py b/src/ptbench/scripts/compare_vis.py similarity index 89% rename from src/ptbench/scripts/comparevis.py rename to src/ptbench/scripts/compare_vis.py index cdec4ad7f05b7a12f5650151d202ad69a364eb65..6cee4b9366debc7ee46348b615bbcd484c76315e 100644 --- a/src/ptbench/scripts/comparevis.py +++ b/src/ptbench/scripts/compare_vis.py @@ -26,6 +26,13 @@ def _sorting_rule(folder_name): else: # Everything else will be sorted alphabetically after fullgrad and before randomcam return (3, folder_name) +def _get_images_from_directory(dir_path): + image_files = [] + for root, _, files in os.walk(dir_path): + for file in files: + if file.lower().endswith((".png", ".jpg", ".jpeg")): + image_files.append(os.path.join(root, file)) + return image_files @click.command( epilog="""Examples: @@ -37,7 +44,7 @@ def _sorting_rule(folder_name): .. code:: sh - ptbench comparevis -i path/to/input_folder -o path/to/output_folder + ptbench compare-vis -i path/to/input_folder -o path/to/output_folder """, ) @click.option( @@ -56,7 +63,7 @@ def _sorting_rule(folder_name): type=click.Path(), ) @verbosity_option(logger=logger, expose_value=False) -def comparevis(input_folder, output_folder) -> None: +def compare_vis(input_folder, output_folder) -> None: """Compares multiple visualization techniques by showing their results in one image.""" @@ -126,13 +133,13 @@ def comparevis(input_folder, output_folder) -> None: os.makedirs(output_directory, exist_ok=True) # Use a set (unordered collection of unique elements) for efficient membership tests - image_names = set(os.listdir(comparison_folders[0])) + image_names = set([os.path.basename(img) for img in _get_images_from_directory(comparison_folders[0])]) # Only keep image names that exist in all folders for folder in comparison_folders[1:]: # This is basically an intersection-check of contents of different folders # Images that don't exist in all folders are removed from the set - image_names &= set(os.listdir(folder)) + image_names &= set([os.path.basename(img) for img in _get_images_from_directory(folder)]) if not image_names: raise ValueError("No common images found in the folders.") @@ -158,7 +165,7 @@ def comparevis(input_folder, output_folder) -> None: ) axs = axs.ravel() for i, folder in enumerate(comparison_folders): - image_path = os.path.join(folder, image_name) + image_path = [img for img in _get_images_from_directory(folder) if os.path.basename(img) == image_name][0] try: img = cv2.imread(image_path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)