Skip to content
Snippets Groups Projects
make_splits_from_database.py 13.7 KiB
Newer Older
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Converts TBX11k JSON annotation files into simplified JSON datasets for
ptbench.

Requires ``datadir.tbx11k`` to be set in your configuration file
(``ptbench.toml``), or that you are sitting at the root directory of the
database.

Because the test set does not have annotations, we generate train, validation
and test datasets as such:

1. The original validation set becomes the test set.
2. The original training set is split into new training and validation
   sets (validation ratio = 0.203 by default).  The selection of samples is
   stratified (respects class proportions in Özgür's way - see comments
   throughout the code).

Our output format is the following:

.. code:: json

   {
     "train": [
       [
         <filename-from-root>,
         # label is one of:
         # 0: healthy / 1: active-tb / 2: active-and-latent-tb
         # 3: latent-tb / 4: sick (no tb)
         <label>, [
             # bounding-box annotations follow (this list is omitted if
             # the image has none).  Box-labels are:
             # 0: latent-tb sign / 1: active-tb sign
             [<box-label>, <xmin>, <ymin>, <width>, <height>],
             [0, <xmin>, <ymin>, <width>, <height>],
             [1, <xmin>, <ymin>, <width>, <height>],
             ...
         ],
       ],
       ...
     ],
     "validation": [
       # same format as for train
       ...
     ],
     "test": [
       # same format as for train
       ...
     ]
   }
"""

import collections
import json
import os
import pathlib
import sys
import typing

from sklearn.model_selection import StratifiedKFold, train_test_split


def reorder(data: dict) -> list:
    """Reorders data from TBX11K into a sample-based organisation.

    Parameters
    ----------

    data
        The parsed contents of one TBX11K JSON annotation file.  The keys
        ``categories``, ``images`` and ``annotations`` are read; bounding
        boxes are expected as ``[xmin, ymin, width, height]``.


    Returns
    -------

    samples
        A list sorted by file name.  Each entry is
        ``["imgs/<file_name>", -1, [[box-label, xmin, ymin, width, height],
        ...]]``.  The label ``-1`` is a placeholder to be decided later, box
        coordinates are rounded to integers, and the bounding-box list is
        omitted for images without any annotation.


    Raises
    ------

    AssertionError
        If category or image ids, category names or image file names are
        duplicated.

    KeyError
        If a category name is not one of the known TBX11K categories.
    """

    # category id -> category name
    categories = {k["id"]: k["name"] for k in data["categories"]}
    # the dict comprehension silently collapses repeated keys, so compare with
    # the raw number of entries to catch duplicated ids
    assert len(categories) == len(
        data["categories"]
    ), "Category ids are not unique"
    assert len(set(categories.values())) == len(
        categories
    ), "Category names are not unique"

    # reset category values, so latent-tb = 0, and active-tb = 1
    cat_translator = {
        "ActiveTuberculosis": 1,
        "ObsoletePulmonaryTuberculosis": 0,
        "PulmonaryTuberculosis": 2,  # this should NOT exist anywhere!
    }
    categories = {k: cat_translator[v] for k, v in categories.items()}

    # image id -> file name (prefixed with "imgs/" on return)
    images = {k["id"]: k["file_name"] for k in data["images"]}
    # same reasoning as above: a duplicated id would otherwise silently
    # re-route annotations to the wrong image
    assert len(images) == len(data["images"]), "Image ids are not unique"
    assert len(set(images.values())) == len(
        images
    ), "Image file names are not unique"

    # file name -> [label placeholder, list of bounding boxes]
    retval: dict[str, list[typing.Any]] = {
        k["file_name"]: [-1, []] for k in data["images"]
    }

    # we now "consume" all annotations and assign each to an image
    for annotation in data["annotations"]:
        int_bbox: list[int] = [
            categories[annotation["category_id"]],
            *[round(k) for k in annotation["bbox"]],
        ]
        retval[images[annotation["image_id"]]][1].append(int_bbox)

    # remove empty bounding-box entries to save space on final JSON
    for v in retval.values():
        if not v[1]:
            del v[1]

    return sorted([["imgs/" + k, *v] for k, v in retval.items()])


def normalize_labels(data: list) -> list:
    """Decides on the final labels for each sample.

    Categories are decided on the following principles:

    0: healthy, no other bounding box detected, comes from the imgs/health
       subdir
    1: active-tb, no latent tb, comes from the imgs/tb subdir, has one or more
       bounding boxes with label 1, and no bounding box with label 0
    2: active-tb and latent tb, comes from the imgs/tb subdir, has one or more
       bounding boxes with label 1 and one or more with label 0
    3: latent tb, comes from the imgs/tb subdir, has one or more
       bounding boxes with label 0 and no bounding box with label 1
    4: sick (but no tb), comes from the imgs/sick subdir, does not have any
       annotated bounding box.

    Images from the imgs/tb subdir without any bounding box get label -1
    (unknown diagnosis) and a warning is printed.


    Parameters
    ----------

    data
        Samples as produced by :py:func:`reorder`:
        ``[filename, label, [bboxes...]]`` (the bbox list may be absent).


    Returns
    -------

    samples
        A new list, in the same order as ``data``, where each sample's label
        (second element) is replaced by the decided category.  The input is
        not modified.


    Raises
    ------

    AssertionError
        If healthy/sick images carry bounding boxes, or tb images use
        unexpected box labels.

    RuntimeError
        If a file name is not under imgs/health, imgs/sick or imgs/tb.
    """

    def _set_label(s: list) -> int:
        # reorder() drops empty bounding-box lists, but tolerate an explicit
        # empty list as "no annotation" too
        has_boxes = len(s) > 2 and len(s[2]) > 0

        if s[0].startswith("imgs/health"):
            assert (
                not has_boxes
            ), f"Image {s[0]} is healthy, but contains tb bbox annotations"
            return 0  # patient is healthy

        if s[0].startswith("imgs/sick"):
            assert (
                not has_boxes
            ), f"Image {s[0]} is sick (no tb), but contains tb bbox annotations"
            return 4  # patient is sick

        if s[0].startswith("imgs/tb"):
            if not has_boxes:
                print(
                    f"WARNING: Image {s[0]} is from the tb subdir, "
                    f"but contains no tb bbox annotations"
                )
                return -1  # unknown diagnosis
            tb_counts = collections.Counter(k[0] for k in s[2])
            assert 2 not in tb_counts, (
                f"Label 2 (PulmonaryTuberculosis) was used in image {s[0]} "
                f"- please check!"
            )
            if 0 in tb_counts:
                if 1 not in tb_counts:
                    return 3  # patient has latent tb
                print(
                    f"WARNING: Image {s[0]} has bboxes with both "
                    f"active and latent tb."
                )
                return 2  # patient has active and latent tb
            # with 0 and 2 excluded, only active-tb boxes may remain
            assert (
                1 in tb_counts
            ), f"Image {s[0]} has unexpected bbox labels {sorted(tb_counts)}"
            return 1  # patient has only active tb

        raise RuntimeError("Cannot happen - please check")

    # build new sample lists so the input is left untouched
    return [[s[0], _set_label(s), *s[2:]] for s in data]



def print_statistics(d: dict):
    """Prints the label distribution of every subset, then the grand total.

    ``d`` must provide the keys ``train``, ``validation`` and ``test``; each
    maps to a list of samples whose second element is an integer label.
    """

    names = {
        -1: "Unknown",
        0: "Healthy",
        1: "Active TB only",
        2: "Both active and latent TB",
        3: "Latent TB only",
        4: "Sick (but no TB)",
    }

    sections = (
        ("Training set statistics:", "train"),
        ("\nValidation set statistics:", "validation"),
        ("\nTest set statistics:", "test"),
    )
    for header, key in sections:
        print(header)
        samples = d[key]
        tally = collections.Counter(sample[1] for sample in samples)
        for label, count in tally.items():
            print(f"  - {names[label]}: {count}")
        print(f"  - Total: {len(samples)}")

    grand_total = sum(map(len, d.values()))
    print(f"\nTotal samples in database: {grand_total}")


def create_v1_default_split(d: dict, seed: int, validation_size: float) -> dict:
    """Builds the v1 split: active-tb (label 1) against healthy (label 0).

    Only samples labelled healthy (0) or active-tb (1) are kept; every other
    label is discarded.  Since the original test set carries no annotations:

    1. The original validation set becomes the test set.
    2. The original training set is split into new training and validation
       sets, stratified on class labels (reproducing Özgür's heuristic - see
       the comment on the stratification targets).


    Parameters
    ----------

    d
        The original dataset that will be split.  Its ``train`` and
        ``validation`` entries are lists of samples with the label at index 1.

    seed
        The seed to use at the relevant RNG

    validation_size
        The proportion of data when we split the training set to make a
        train and validation sets.


    Returns
    -------

    split
        A dictionary with keys ``train``, ``validation`` and ``test``.
    """

    wanted = (0, 1)
    train_pool = [sample for sample in d["train"] if sample[1] in wanted]
    test_pool = [sample for sample in d["validation"] if sample[1] in wanted]

    # Özgür's heuristic used labels whose sort order is the reverse of ours
    # ("no_tb" in place of 0, "active_tb" in place of 1).  Flipping the labels
    # used for stratification reproduces his sample selection.
    flipped = {0: 1, 1: 0}
    strata = [flipped[sample[1]] for sample in train_pool]

    new_train, new_validation = train_test_split(
        train_pool,
        test_size=validation_size,
        random_state=seed,
        stratify=strata,
    )

    return dict(train=new_train, validation=new_validation, test=test_pool)


def create_v2_default_split(d: dict, seed: int, validation_size: float) -> dict:
    """In the v2 split, we consider active-tb cases against healthy, sick and
    latent-tb cases.

    Because the test set is not annotated we do the following:

    1. The original validation set becomes the test set.
    2. The original training set is split into new training and validation
       sets.  The selection of samples is stratified (respects class
       proportions in Özgür's way - see comments)

    Samples with both active and latent tb (label 2) or with an unknown
    diagnosis (label -1) are excluded.  In the returned split, labels are
    binarised: 1 for active-tb, 0 for everything else.


    Parameters
    ----------

    d
        The original dataset that will be split

    seed
        The seed to use at the relevant RNG

    validation_size
        The proportion of data when we split the training set to make a
        train and validation sets.


    Returns
    -------

    split
        A dictionary with keys ``train``, ``validation`` and ``test``.
    """

    # filter cases: keep 0:healthy, 1:active-tb, 3:latent-tb and 4:sick
    keep = (0, 1, 3, 4)
    use_data = {
        "train": [k for k in d["train"] if k[1] in keep],
        "validation": [k for k in d["validation"] if k[1] in keep],
    }

    # Required to repeat Özgür's heuristic with labels that reverse somehow the
    # sorting for "no_tb" (instead of 0, 3 or 4), and "active_tb" (instead of
    # 1). Reversing the labels used in the stratification process solves this
    # issue.
    targets = {0: 1, 1: 0, 3: 1, 4: 1}

    train, val = train_test_split(
        use_data["train"],
        test_size=validation_size,
        random_state=seed,
        stratify=[targets[k[1]] for k in use_data["train"]],
    )

    # These are the targets that will show up in the split.  We make everything
    # that is not active-tb to be label=0.
    split_targets = {0: 0, 1: 1, 3: 0, 4: 0}

    def _relabel(samples: list) -> list:
        """Returns copies of ``samples`` with binarised (v2) labels."""
        return [[k[0], split_targets[k[1]], *k[2:]] for k in samples]

    return {
        "train": _relabel(train),
        "validation": _relabel(val),
        "test": _relabel(use_data["validation"]),
    }


def create_folds(
    d: dict, n: int, seed: int, validation_size: float
) -> list[dict]:
    """Creates folds from existing splits.

    All samples of the ``train``, ``validation`` and ``test`` subsets of ``d``
    are pooled and re-split with a shuffled, stratified ``n``-fold scheme.
    For each fold, the held-out part becomes the ``test`` subset, and the
    remainder is split again (stratified) into ``train`` and ``validation``.


    Parameters
    ----------

    d
        The original split to consider.  Labels (index 1 of each sample) must
        be binary (0 or 1), as in the v1 and v2 default splits; other values
        raise ``KeyError``.

    n
        The number of folds to produce

    seed
        The seed used for both the fold shuffling and each fold's
        train/validation split

    validation_size
        The proportion of each fold's training data that goes into the
        validation subset


    Returns
    -------

    folds
        All the ``n`` folds, each a dictionary with keys ``train``,
        ``validation`` and ``test``
    """

    X = d["train"] + d["validation"] + d["test"]
    # labels are passed as a column (one single-element list per sample);
    # kept exactly as originally written so the folds remain reproducible
    y = [[k[1]] for k in X]

    # Initializes a StratifiedKFold object with n shuffled folds
    skf = StratifiedKFold(n_splits=n, shuffle=True, random_state=seed)

    # Required to repeat Özgür's heuristic with labels that reverse somehow the
    # sorting for "no_tb" (instead of 0), and "active_tb" (instead of 1).
    # Reversing the labels used in the stratification process solves this
    # issue.
    targets = {0: 1, 1: 0}

    # Loops over the n folds and splits the data
    retval = []
    for train_idx, test_idx in skf.split(X, y):
        # Get the training and test data for this fold
        train_dataset = [X[k] for k in train_idx]
        test_dataset = [X[k] for k in test_idx]

        # Split the training data into training and validation sets
        train_dataset, val_dataset = train_test_split(
            train_dataset,
            test_size=validation_size,
            random_state=seed,
            stratify=[targets[k[1]] for k in train_dataset],
        )

        retval.append(
            {
                "train": train_dataset,
                "validation": val_dataset,
                "test": test_dataset,
            }
        )

    return retval


def main():
    """Generates the TBX11K v1/v2 default splits and their folds.

    Reads the three TBX11K JSON annotation files below ``datadir.tbx11k``
    (from ``ptbench.toml``, defaulting to the current directory) and prints
    dataset statistics.  It then writes the following files to the current
    working directory:

    * ``v1-healthy-vs-atb.json`` and ``v1-fold-<i>.json``
    * ``v2-others-vs-atb.json`` and ``v2-fold-<i>.json``
    """
    if len(sys.argv) != 1:
        print(__doc__)
        print(f"Usage: python3 {sys.argv[0]} ")
        sys.exit(0)

    # program constants used by Özgür
    seed = 42  # used to seed the relevant RNG
    validation_size = 0.203  # proportion for test when splitting
    n_folds = 10  # number of folds to create

    from clapper.rc import UserDefaults

    datadir = pathlib.Path(
        UserDefaults("ptbench.toml").get(
            "datadir.tbx11k", os.path.realpath(os.curdir)
        )
    )
    json_dir = datadir / "annotations" / "json"

    def _load(filename: pathlib.Path, normalize: bool) -> list:
        """Loads one annotation file and reorganises it per sample."""
        with open(filename, encoding="utf-8") as f:
            print(f"Loading {str(filename)}...")
            samples = reorder(json.load(f))
        return normalize_labels(samples) if normalize else samples

    def _dump(obj: typing.Any, filename: str) -> None:
        """Writes ``obj`` as indented JSON to ``filename``."""
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2)

    final_data = {
        "train": _load(json_dir / "TBX11K_train.json", normalize=True),
        "validation": _load(json_dir / "TBX11K_val.json", normalize=True),
        # the test set has no annotations: labels stay -1 (unknown)
        "test": _load(json_dir / "all_test.json", normalize=False),
    }
    print_statistics(final_data)

    for version, filename, make_split in (
        ("v1", "v1-healthy-vs-atb.json", create_v1_default_split),
        ("v2", "v2-others-vs-atb.json", create_v2_default_split),
    ):
        print(f"\nGenerating {version} split...")
        split = make_split(
            final_data, seed=seed, validation_size=validation_size
        )
        print_statistics(split)
        _dump(split, filename)

        # folds for this split
        print(f"\nGenerating {n_folds} {version} split folds...")
        folds = create_folds(
            split, n=n_folds, seed=seed, validation_size=validation_size
        )
        for i, fold in enumerate(folds):
            _dump(fold, f"{version}-fold-{i}.json")


# Script entry point: main() reads the TBX11K annotation files and writes the
# split and fold JSON files.
if __name__ == "__main__":
    main()