diff --git a/helpers/visceral_preprocess.py b/helpers/visceral_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..6e1bb9dd964b9ea662a5010eb33a4aec037635de --- /dev/null +++ b/helpers/visceral_preprocess.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +""" Preprocess visceral dataset to prepare volume cubes for organ classification using 3D cnn. + +Example of use: python visceral_preprocess.py /idiap/temp/ojimenez/previous-projects/VISCERAL /idiap/home/ypannatier/test_helpers 1302 16 +""" + +import os +import sys +import torchio as tio +import torch + +def get_mask_cube(mask_path: str, size: int) -> torch.Tensor: + """Create a mask of dimension SIZExSIZExSIZE from the input mask at the center of the non-zero area. + + Parameters + ---------- + mask_path + A string representing the full path to the mask file. + size + An integer representing the size of the cube + + Returns + ------- + Tensor + The mask tensor representing a volume of dimension SIZExSIZExSIZE. + """ + mask_image = tio.ScalarImage(mask_path) + mask_data = mask_image.data.bool().squeeze() + mask_image.unload() + ones_coords = torch.nonzero(mask_data) + center = torch.mean(ones_coords.float(), dim=0).long() + half_size = int(size) // 2 + start_coords = center - half_size + end_coords = center + half_size + result = torch.zeros_like(mask_data) + slices = [slice(max(0, start), min(end, dim)) for start, end, dim in zip(start_coords, end_coords, mask_data.shape)] + result[slices[0], slices[1], slices[2]] = 1 + return result + +def get_masks(mask_paths: list[str], filters: list[str], volume_ids: list[str]) -> list[list[str]]: + """Find the list of usable masks corresponding to the desired organ. + + Parameters + ---------- + mask_paths + A list of strings representing the folders in which to search for masks. + filters + A list of strings corresponding to a substring of mask's file name. A valid mask must match one of the filters. + Each filter should contains the organ id. + Example: To get the white/black CT scans corresponding to an organ: "_1_CT_wb_{organ_id}" + volume_ids + A list containing the list of all patient ids retrieved from the volume dataset. + Valid masks must start by one of entry of volume ids. + Returns + ------- + list[list[str]] + The list of valid masks as a list of strings. Each list contains 2 entries. First, the path to the mask folder, second the mask file name. + """ + masks = [] + for mask_path in mask_paths: + for filter in filters: + available_masks = os.listdir(mask_path) + available_masks = [[mask_path, mask] for mask in available_masks + if mask.endswith('.nii.gz') + and filter in mask + and mask.split('_')[0] in volume_ids] + masks.extend(available_masks) + return masks + + +def main(): + if len(sys.argv) != 5: + print(__doc__) + print(f"Usage: python3 {sys.argv[0]} <root-of-visceral-dataset> <output_folder_path> <organ_id> <size>") + sys.exit(0) + + root_path = sys.argv[1] + output_path = sys.argv[2] + organ_id = sys.argv[3] + size = sys.argv[4] + + FILTERS = [f'_1_CTce_ThAb_{organ_id}', f'_1_CT_wb_{organ_id}',f'_1_{organ_id}'] + + + annot_2_mask_path = os.path.join(root_path, 'annotations/Anatomy2/anat2-trainingset/Anat2_Segmentations') + annot_3_mask_path = os.path.join(root_path, 'annotations/Anatomy3/Visceral-QC-testset/qc-anat3-testset-segmentations') + silver_corpus_mask_path = os.path.join(root_path, 'annotations/SilverCorpus/AnatomySilverCorpus/BinarySegmentations') + mask_paths = [annot_2_mask_path, annot_3_mask_path, silver_corpus_mask_path] + volume_path = os.path.join(root_path, 'volumes_for_annotation/GeoS_oriented_Volumes(Annotators)') + output_size_path = os.path.join(output_path, size) + + # Ensure required output folders exist + if not os.path.exists(output_path): + os.makedirs(output_path) + if not os.path.exists(output_path): + os.makedirs(output_size_path) + if not os.path.exists(output_size_path): + os.makedirs(output_size_path) + + + volumes = os.listdir(volume_path) + volume_ids = [volume.split('_')[0] for volume in volumes] + + masks = get_masks(mask_paths, FILTERS, volume_ids) + + + for mask in masks: + id = f'{mask[1].split("_")[0]}_1' + full_path = os.path.join(mask[0], mask[1]) + mask_cube = get_mask_cube(full_path, size) + volume_name = [volume for volume in volumes if id in volume][0] + volume_full_path = os.path.join(volume_path, volume_name) + output_full_path = os.path.join(output_size_path, mask[1]) + volume_image = tio.ScalarImage(volume_full_path) + volume_data = volume_image.data.squeeze() + volume_image.unload() + + volume_data = volume_data[mask_cube==1].reshape(int(size), int(size), int(size)).unsqueeze(0) + cropped_volume = tio.ScalarImage(tensor=volume_data) + cropped_volume.save(output_full_path) + + print(f'Output file in : {output_full_path}') + +if __name__ == "__main__": + main()