Skip to content
Snippets Groups Projects
visceral_make_splits.py 3.38 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Generate visceral default JSON dataset for 3d binary classification tasks in mednet.

Arguments of the script are as follows:
    root-of-preprocessed-visceral-dataset
        Full path to the root of the preprocessed visceral dataset.
        Filenames in the resulting json are relative to this path.
        See output format below.
    output-folder
        Full path to the folder where to output the default.json file containing the default split of data.
    organ_1_id
        Integer representing the ID of the first organ to include in the split dataset.
        This organ will be labeled as 0 for the binary classification task.
        For example, 237 corresponds to bladder.
    organ_2_id
        Integer representing the ID of the second organ to include in the split dataset.
        This organ will be labeled as 1 for the binary classification task.
        For example, 237 corresponds to bladder.
    size
        Name of the sub-folder of root-of-preprocessed-visceral-dataset from which
        the files are read. It is also used as the prefix of every filename
        written to the json file (``<size>/<filename>``).

Output format is the following:

.. code:: json

   {
     "train": [
       [
         "<size>/<filename>",
         # label is one of:
         # 0: organ_1 / 1: organ_2
         <label>
       ],
       ...
     ],
     "validation": [
       # same format as for train
       ...
     ],
     "test": [
       # same format as for train
       ...
     ]
   }
"""

import json
import os
import pathlib
import sys

from sklearn.model_selection import train_test_split


def split_files(
    files: list[str],
    train_size: float = 0.7,
    test_size: float = 0.2,
    validation_size: float = 0.1,
    random_state: "int | None" = None,
) -> tuple[list[str], list[str], list[str]]:
    """Randomly split ``files`` into train, test and validation subsets.

    Parameters
    ----------
    files
        Filenames to split.
    train_size
        Fraction of ``files`` assigned to the training set.
    test_size
        Fraction of ``files`` assigned to the test set.
    validation_size
        Fraction of ``files`` assigned to the validation set.
    random_state
        Seed forwarded to ``train_test_split``. Pass an integer to obtain
        the same split on every run; ``None`` (the default) keeps the
        previous, unseeded behaviour.

    Returns
    -------
        The ``(train_files, test_files, validation_files)`` lists, in that
        order.

    Raises
    ------
    ValueError
        If the three fractions do not sum to 1. Without this check the
        fractions would be silently reinterpreted: train would still get
        ``train_size`` and ``test_size`` would only act as a relative weight
        against ``validation_size``.
    """
    total = train_size + test_size + validation_size
    # Tolerance absorbs float rounding, e.g. 0.7 + 0.2 + 0.1 != 1.0 exactly.
    if abs(total - 1.0) > 1e-9:
        raise ValueError(
            "train_size + test_size + validation_size must sum to 1, "
            f"got {train_size} + {test_size} + {validation_size} = {total}"
        )

    # First carve off the training set; the remainder holds test + validation.
    train_files, remaining_files = train_test_split(
        files, test_size=(1 - train_size), random_state=random_state
    )
    # Then split the remainder so validation receives its share of what is
    # left: validation_size / (test_size + validation_size).
    test_files, validation_files = train_test_split(
        remaining_files,
        test_size=(validation_size / (test_size + validation_size)),
        random_state=random_state,
    )
    return train_files, test_files, validation_files


def save_to_json(
    train_files: list[str],
    test_files: list[str],
    validation_files: list[str],
    output_file: str,
    organ_1_id: str,
):
    """Write the data splits, with binary labels, to a JSON file.

    Each filename becomes a ``[filename, label]`` pair. ``label`` is ``0``
    when ``organ_1_id`` occurs in the file's base name and ``1`` otherwise.
    The JSON object has the keys ``"train"``, ``"test"`` and
    ``"validation"``, in that order.

    Parameters
    ----------
    train_files
        Filenames of the training split (e.g. ``"<size>/<file>"``).
    test_files
        Filenames of the test split.
    validation_files
        Filenames of the validation split.
    output_file
        Destination path (a ``str`` or a :class:`pathlib.Path`); it is
        overwritten if it exists.
    organ_1_id
        Organ identifier whose files receive label ``0``.
    """

    def _label(filename: str) -> int:
        # Only look at the base name. Filenames carry a "<size>/" prefix, and
        # a size folder whose name happens to contain organ_1_id would
        # otherwise make every file match. This mirrors main(), which filters
        # files on their bare name.
        basename = filename.rsplit("/", 1)[-1]
        return 0 if organ_1_id in basename else 1

    splits = (
        ("train", train_files),
        ("test", test_files),
        ("validation", validation_files),
    )
    data = {
        split_name: [[filename, _label(filename)] for filename in members]
        for split_name, members in splits
    }

    with pathlib.Path(output_file).open("w", encoding="utf-8") as json_file:
        json.dump(data, json_file, indent=2)


def main():
    """Command-line entry point: build and write ``default.json``.

    Reads five positional arguments from ``sys.argv``: the dataset root, the
    output folder, ``organ_1_id``, ``organ_2_id`` and ``size``.

    The steps are:

    1. Collect every entry of ``<root>/<size>`` whose name contains either
       organ ID, each prefixed with ``"<size>/"``.
    2. Split the entries with :func:`split_files`.
    3. Write the splits with :func:`save_to_json` to
       ``<output-folder>/default.json``, creating the output folder if
       needed.

    Exits with a non-zero status when the argument count is wrong (after
    printing the help text) or when no entry matches either organ ID.
    """
    if len(sys.argv) != 6:
        print(__doc__)
        print(
            f"Usage: python3 {sys.argv[0]} <root-of-preprocessed-visceral-dataset> <output-folder> <organ_1_id> <organ_2_id> <size>"
        )
        # A wrong argument count is a usage error: report failure to the shell.
        sys.exit(1)

    root_folder = sys.argv[1]
    output_folder = sys.argv[2]
    organ_1_id = sys.argv[3]
    organ_2_id = sys.argv[4]
    size = sys.argv[5]
    output_file = pathlib.Path(output_folder) / "default.json"
    input_folder = pathlib.Path(root_folder) / size
    # os.listdir() returns entries in an arbitrary, filesystem-dependent
    # order; sort so the input handed to split_files does not depend on it.
    files = [
        f"{size}/{file}"
        for file in sorted(os.listdir(input_folder))
        if organ_1_id in file or organ_2_id in file
    ]
    if not files:
        # Fail with a clear message instead of an opaque error from the splitter.
        sys.exit(
            f"No file in {input_folder} matches organ ID {organ_1_id} or {organ_2_id}."
        )
    # NOTE: no seed is passed, so the split presumably differs between runs.
    train_files, test_files, validation_files = split_files(files)

    # Make sure the destination folder exists before writing into it.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    save_to_json(train_files, test_files, validation_files, output_file, organ_1_id)
    print(f"Data saved to {output_file}")


if __name__ == "__main__":
    main()