diff --git a/helpers/visceral_preprocess.py b/helpers/visceral_preprocess.py index 6e1bb9dd964b9ea662a5010eb33a4aec037635de..3391f5128db47970f3948cfd558388e95566bbe2 100644 --- a/helpers/visceral_preprocess.py +++ b/helpers/visceral_preprocess.py @@ -1,15 +1,28 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -""" Preprocess visceral dataset to prepare volume cubes for organ classification using 3D cnn. - -Example of use: python visceral_preprocess.py /idiap/temp/ojimenez/previous-projects/VISCERAL /idiap/home/ypannatier/test_helpers 1302 16 +"""Preprocess visceral dataset to prepare volume cubes for organ classification using 3D cnn. +Example of use: 'python visceral_preprocess.py /idiap/temp/ojimenez/previous-projects/VISCERAL /idiap/home/ypannatier/visceral/preprocessed 237 16 '. + +Arguments of the scripts are as follow: + root-of-visceral-dataset + Full path to the root of the visceral dataset. + output_folder_path + Full path to the folder where the prepared cubes will be saved. + Note that the script will create a subfolder corresponding to the desired size of cube to avoid mixing volumes of different sizes. + organ_id + Integer representing the ID of the organ to process in the visceral dataset. For example, 237 corresponds to bladder. + size + Integer representing the size of the volume cube that will be output. Each volume will be of dimension SIZExSIZExSIZE. """ import os +import pathlib import sys -import torchio as tio + import torch +import torchio as tio + def get_mask_cube(mask_path: str, size: int) -> torch.Tensor: """Create a mask of dimension SIZExSIZExSIZE from the input mask at the center of the non-zero area. @@ -19,7 +32,7 @@ def get_mask_cube(mask_path: str, size: int) -> torch.Tensor: mask_path A string representing the full path to the mask file. size - An integer representing the size of the cube + An integer representing the size of the cube. 
def get_masks(mask_paths: list[str], filters: list[str], volume_ids: list[str]):
    """Find the list of usable masks corresponding to the desired organ.

    Parameters
    ----------
    mask_paths
        The directories to scan for mask files. Each entry is handed to ``os.listdir`` and is
        returned unchanged as the first element of every match found in it.
    filters
        A list of strings corresponding to a substring of a mask's file name. A valid mask must
        match at least one of the filters. Each filter should contain the organ id.
        Example: To get the white/black CT scans corresponding to an organ: "_1_CT_wb_{organ_id}".
    volume_ids
        The patient ids retrieved from the volume dataset. The part of a valid mask's file name
        before the first underscore must be one of these ids.

    Returns
    -------
        The list of valid masks. Each entry is a list of 2 items: first the mask folder (as given
        in ``mask_paths``), second the mask file name. A file is reported at most once per folder,
        even when it matches several filters.
    """
    # Set membership is O(1); the id list is consulted once per file.
    known_ids = set(volume_ids)
    seen = set()
    masks = []
    for mask_path in mask_paths:
        # List each directory once: the listing does not depend on the filter.
        candidates = [
            name
            for name in os.listdir(mask_path)
            if name.endswith(".nii.gz") and name.split("_")[0] in known_ids
        ]
        # Filters stay in the outer loop so matches remain grouped by filter,
        # as callers of the previous implementation observed.
        for mask_filter in filters:
            for name in candidates:
                key = (mask_path, name)
                if mask_filter in name and key not in seen:
                    seen.add(key)
                    masks.append([mask_path, name])
    return masks
def main():
    """Crop one organ cube per usable mask and save it under ``<output_folder_path>/<size>``.

    Reads four command-line arguments: the root of the visceral dataset, the output folder, the
    organ id and the cube size. It collects the masks of that organ with ``get_masks`` from the
    three annotation folders, and for each mask extracts from the matching volume the voxels
    selected by ``get_mask_cube``. The result is saved under the mask's file name. Masks whose
    output file already exists are skipped, so an interrupted run can be resumed.

    Exits with status 1 when the arguments are missing or ``<size>`` is not a positive integer.
    A failure on one mask is printed and does not stop the run.
    """
    if len(sys.argv) != 5:
        print(__doc__)
        print(
            f"Usage: python3 {sys.argv[0]} <root-of-visceral-dataset> <output_folder_path> <organ_id> <size>"
        )
        # A malformed invocation is a failure; report it through the exit status.
        sys.exit(1)

    root_path = pathlib.Path(sys.argv[1])
    output_path = pathlib.Path(sys.argv[2])
    organ_id = sys.argv[3]

    # Convert once: get_mask_cube is annotated to take an int, and the reshape below needs one.
    try:
        size = int(sys.argv[4])
    except ValueError:
        print(f"Error: <size> must be an integer, got {sys.argv[4]!r}")
        sys.exit(1)
    if size <= 0:
        print(f"Error: <size> must be positive, got {size}")
        sys.exit(1)

    filters = [f"_1_CTce_ThAb_{organ_id}", f"_1_CT_wb_{organ_id}", f"_1_{organ_id}"]

    annotations_path = root_path / "annotations"
    mask_paths = [
        annotations_path / "Anatomy2" / "anat2-trainingset" / "Anat2_Segmentations",
        annotations_path
        / "Anatomy3"
        / "Visceral-QC-testset"
        / "qc-anat3-testset-segmentations",
        annotations_path / "SilverCorpus" / "AnatomySilverCorpus" / "BinarySegmentations",
    ]
    volume_path = root_path / "volumes_for_annotation" / "GeoS_oriented_Volumes(Annotators)"

    # One subfolder per cube size, so volumes of different sizes never mix.
    output_size_path = output_path / str(size)
    output_size_path.mkdir(parents=True, exist_ok=True)

    volumes = os.listdir(volume_path)
    volume_ids = [volume.split("_")[0] for volume in volumes]

    masks = get_masks(mask_paths, filters, volume_ids)
    total = len(masks)

    print(f"Found {total} volumes to process...")
    print("Generating volumes...")
    for i, (mask_dir, mask_name) in enumerate(masks):
        if i % 10 == 0:
            print(f"Generated volumes: {i}/{total}")

        output_full_path = output_size_path / mask_name
        if output_full_path.exists():
            continue

        patient_id = f'{mask_name.split("_")[0]}_1'
        full_path = pathlib.Path(mask_dir) / mask_name

        # NOTE(review): the first volume whose name contains "<id>_1" is used; confirm this picks
        # the right modality when a patient has several volumes.
        volume_name = next(
            (volume for volume in volumes if patient_id in volume), None
        )
        if volume_name is None:
            # Previously an uncaught IndexError here aborted the whole run.
            print(f"Error: no volume found for {patient_id} while processing {full_path}")
            continue
        volume_full_path = volume_path / volume_name

        # Best-effort per mask: report the failure and move on to the next one.
        try:
            mask_cube = get_mask_cube(full_path, size)

            volume_image = tio.ScalarImage(volume_full_path)
            volume_data = volume_image.data.squeeze()
            volume_image.unload()

            # get_mask_cube clips the cube at the volume borders. A clipped cube selects fewer
            # than size**3 voxels, so reshape raises and the mask is reported below.
            volume_data = (
                volume_data[mask_cube == 1].reshape(size, size, size).unsqueeze(0)
            )
            cropped_volume = tio.ScalarImage(tensor=volume_data)
            cropped_volume.save(output_full_path)
        except Exception as e:
            print(f"Error: {e} while processing {full_path}")


if __name__ == "__main__":
    main()