Skip to content
Snippets Groups Projects
Commit 2631da1d authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Merge branch '3d-cnn-visceral' into 'main'

3d cnn visceral

See merge request biosignal/software/mednet!51
parents 232ccc51 cfff0324
No related branches found
No related tags found
1 merge request !51: 3d cnn visceral
Pipeline #89232 passed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Generate visceral default JSON dataset for 3d binary classification tasks in mednet.
Arguments of the scripts are as follow:
root-of-preprocessed-visceral-dataset
Full path to the root of the preprocessed visceral dataset.
Filenames in the resulting json are relative to this path.
See output format below.
output-folder
Full path to the folder where to output the default.json file containing the default split of data.
organ_1_id
Integer representing the ID of the first organ to include in the split dataset.
This organ will be labeled as 0 for the binary classification task.
For example, 237 corresponds to bladder.
organ_2_id
Integer representing the ID of the second organ to include in the split dataset.
This organ will be labeled as 1 for the binary classification task.
For example, 237 corresponds to bladder.
Output format is the following:
.. code:: json
{
"train": [
[
"<size>/<filename>",
# label is one of:
# 0: organ_1 / 1: organ_2
<label>,
],
...
],
"validation": [
# same format as for train
...
]
"test": [
# same format as for train
...
]
"""
import json
import os
import pathlib
import sys
from sklearn.model_selection import train_test_split
def split_files(
    files: list[str],
    train_size: float = 0.7,
    test_size: float = 0.2,
    validation_size: float = 0.1,
    random_state: int | None = None,
):
    """Split a list of filenames into train/test/validation subsets.

    Parameters
    ----------
    files
        The filenames to be split.
    train_size
        Fraction of files assigned to the training subset.
    test_size
        Fraction of files assigned to the test subset.
    validation_size
        Fraction of files assigned to the validation subset.  Together with
        ``train_size`` and ``test_size``, this should sum to 1.0.
    random_state
        Optional seed for a reproducible split.  The default (``None``)
        keeps the previous, non-deterministic behaviour, so existing
        callers are unaffected.

    Returns
    -------
    A ``(train_files, test_files, validation_files)`` tuple of lists.
    """
    # First carve out the training subset, then divide the remainder
    # between test and validation, proportionally to their requested sizes.
    train_files, temp_files = train_test_split(
        files, test_size=(1 - train_size), random_state=random_state
    )
    test_files, validation_files = train_test_split(
        temp_files,
        test_size=(validation_size / (test_size + validation_size)),
        random_state=random_state,
    )
    return train_files, test_files, validation_files
def save_to_json(
    train_files: list[str],
    test_files: list[str],
    validation_files: list[str],
    output_file: str,
    organ_1_id: str,
):
    """Write the train/test/validation split to a JSON file.

    Parameters
    ----------
    train_files
        Filenames assigned to the training subset.
    test_files
        Filenames assigned to the test subset.
    validation_files
        Filenames assigned to the validation subset.
    output_file
        Full path of the JSON file to write.
    organ_1_id
        Organ ID labeled as 0: filenames containing this substring get
        label 0, all other filenames get label 1.

    NOTE(review): labeling relies on a substring test, so an organ id that
    happens to occur inside another token of the filename (e.g. the patient
    number) would mislabel the sample -- confirm this cannot happen with
    the current filename scheme.
    """

    def _label(filename: str) -> int:
        # 0 for the first organ, 1 for the second one.
        return 0 if organ_1_id in filename else 1

    # Same [filename, label] pair format for each of the three subsets.
    data = {
        "train": [[f, _label(f)] for f in train_files],
        "test": [[f, _label(f)] for f in test_files],
        "validation": [[f, _label(f)] for f in validation_files],
    }
    with pathlib.Path(output_file).open("w") as json_file:
        json.dump(data, json_file, indent=2)
def main():
    """Entry point: build the default VISCERAL split JSON from CLI arguments."""
    expected_argc = 6
    if len(sys.argv) != expected_argc:
        print(__doc__)
        print(
            f"Usage: python3 {sys.argv[0]} <root-of-preprocessed-visceral-dataset> <output-folder> <organ_1_id> <organ_2_id> <size>"
        )
        sys.exit(0)
    root_folder, output_folder, organ_1_id, organ_2_id, size = sys.argv[1:6]
    output_file = pathlib.Path(output_folder) / "default.json"
    input_folder = pathlib.Path(root_folder) / size
    # Keep only the files belonging to one of the two requested organs.
    files = []
    for entry in os.listdir(input_folder):
        if organ_1_id in entry or organ_2_id in entry:
            files.append(f"{size}/{entry}")
    train_files, test_files, validation_files = split_files(files)
    save_to_json(train_files, test_files, validation_files, output_file, organ_1_id)
    print(f"Data saved to {output_file}")


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Preprocess visceral dataset to prepare volume cubes for organ classification using 3D cnn.
Example of use: 'python visceral_preprocess.py /idiap/temp/ojimenez/previous-projects/VISCERAL /idiap/home/ypannatier/visceral/preprocessed 237 16 '.
Arguments of the scripts are as follow:
root-of-visceral-dataset
Full path to the root of the visceral dataset.
output_folder_path
Full path to the folder where the prepared cubes will be saved.
Note that the script will create a subfolder corresponding to the desired size of cube to avoid mixing volumes of different sizes.
organ_id
Integer representing the ID of the organ to process in the visceral dataset. For example, 237 corresponds to bladder.
size
Integer representing the size of the volume cube that will be output. Each volume will be of dimension SIZExSIZExSIZE.
"""
import os
import pathlib
import sys
import torch
import torchio as tio
def get_mask_cube(mask_path: str, size: int) -> torch.Tensor:
    """Create a cubic binary mask of side ``size`` centered on the organ.

    The cube is centered on the mean coordinate of the non-zero voxels of
    the input mask.  If the cube would overflow a volume border, it is now
    shifted back inside the volume, so it always selects the full number of
    voxels.  (The previous implementation clipped the cube at the border,
    which produced fewer voxels than expected and made the downstream
    fixed-size ``reshape`` fail for organs close to the border.)

    Parameters
    ----------
    mask_path
        A string representing the full path to the mask file.
    size
        An integer representing the size of the cube.  Note that the cube
        effectively spans ``2 * (size // 2)`` voxels per axis, i.e. an odd
        ``size`` is rounded down to the nearest even value.

    Returns
    -------
    Tensor
        The mask tensor, of the same shape as the input mask volume, with a
        cube of ones around the organ center.
    """
    mask_image = tio.ScalarImage(mask_path)
    mask_data = mask_image.data.bool().squeeze()
    mask_image.unload()  # release the image buffer as soon as possible
    ones_coords = torch.nonzero(mask_data)
    # Organ center: mean coordinate of the segmented (non-zero) voxels.
    center = torch.mean(ones_coords.float(), dim=0).long()
    half_size = int(size) // 2
    side = 2 * half_size
    result = torch.zeros_like(mask_data)
    slices = []
    for c, dim in zip(center.tolist(), mask_data.shape):
        # Shift the window inside [0, dim] instead of truncating it when
        # it crosses a border, so the cube keeps its full extent.
        start = min(max(c - half_size, 0), max(dim - side, 0))
        slices.append(slice(start, min(start + side, dim)))
    result[slices[0], slices[1], slices[2]] = 1
    return result
def get_masks(mask_paths: list[str], filters: list[str], volume_ids: list[str]):
    """Find the list of usable masks corresponding to the desired organ.

    Parameters
    ----------
    mask_paths
        A list of strings representing the folders in which to search for masks.
    filters
        A list of strings corresponding to a substring of mask's file name.
        A valid mask must match one of the filters.  Each filter should
        contain the organ id.
        Example: to get the white/black CT scans corresponding to an organ:
        ``"_1_CT_wb_{organ_id}"``.
    volume_ids
        A list containing the list of all patient ids retrieved from the
        volume dataset.  Valid masks must start with one of these ids.

    Returns
    -------
    The list of valid masks as a list of lists.  Each inner list contains 2
    entries: first the path to the mask folder, second the mask file name.
    """
    masks = []
    valid_ids = set(volume_ids)  # O(1) membership tests
    for mask_path in mask_paths:
        # List the folder once per path (previously listed once per
        # path *and* filter, repeating the same directory scan).
        available_masks = os.listdir(mask_path)
        for mask_filter in filters:
            for mask in available_masks:
                if (
                    mask.endswith(".nii.gz")
                    and mask_filter in mask
                    and mask.split("_")[0] in valid_ids
                ):
                    masks.append([mask_path, mask])
    return masks
def main():
    """Preprocess the VISCERAL dataset into fixed-size organ cubes.

    Reads CLI arguments (dataset root, output folder, organ id, cube size),
    locates usable masks in the three annotation sources, and for each mask
    crops a SIZExSIZExSIZE cube from the corresponding volume, saving it
    under ``<output_folder>/<size>/``.
    """
    if len(sys.argv) != 5:
        print(__doc__)
        print(
            f"Usage: python3 {sys.argv[0]} <root-of-visceral-dataset> <output_folder_path> <organ_id> <size>"
        )
        sys.exit(0)
    root_path = sys.argv[1]
    output_path = sys.argv[2]
    organ_id = sys.argv[3]
    size = sys.argv[4]  # NOTE: kept as str; converted with int() where needed
    # One filename filter per annotation naming scheme; each embeds the organ id.
    filters = [f"_1_CTce_ThAb_{organ_id}", f"_1_CT_wb_{organ_id}", f"_1_{organ_id}"]
    # The three annotation sources searched for masks.
    annot_2_mask_path = (
        pathlib.Path(root_path)
        / "annotations"
        / "Anatomy2"
        / "anat2-trainingset"
        / "Anat2_Segmentations"
    )
    annot_3_mask_path = (
        pathlib.Path(root_path)
        / "annotations"
        / "Anatomy3"
        / "Visceral-QC-testset"
        / "qc-anat3-testset-segmentations"
    )
    silver_corpus_mask_path = (
        pathlib.Path(root_path)
        / "annotations"
        / "SilverCorpus"
        / "AnatomySilverCorpus"
        / "BinarySegmentations"
    )
    mask_paths = [annot_2_mask_path, annot_3_mask_path, silver_corpus_mask_path]
    # Folder holding the raw CT volumes the cubes are cropped from.
    volume_path = (
        pathlib.Path(root_path)
        / "volumes_for_annotation"
        / "GeoS_oriented_Volumes(Annotators)"
    )
    output_size_path = pathlib.Path(output_path) / size
    # Ensure required output folders exist
    output_size_path.mkdir(parents=True, exist_ok=True)
    volumes = os.listdir(volume_path)
    # Patient id is the leading underscore-separated token of the volume name.
    volume_ids = [volume.split("_")[0] for volume in volumes]
    masks = get_masks(mask_paths, filters, volume_ids)
    print(f"Found {len(masks)} volumes to process...")
    print("Generating volumes...")
    for i, mask in enumerate(masks):
        if i % 10 == 0:
            print(f"Generated volumes: {i}/{len(masks)}")
        # mask is a [folder, filename] pair (see get_masks).
        patient_id = f'{mask[1].split("_")[0]}_1'
        full_path = pathlib.Path(mask[0]) / mask[1]
        # Assumes get_masks guarantees a matching volume exists; otherwise
        # this [0] indexing would raise IndexError -- TODO confirm.
        volume_name = [volume for volume in volumes if patient_id in volume][0]
        volume_full_path = pathlib.Path(volume_path) / volume_name
        output_full_path = pathlib.Path(output_size_path) / mask[1]
        # Skip cubes generated by a previous run (allows resuming).
        if pathlib.Path.exists(output_full_path):
            continue
        try:
            # Build the cubic selection mask, then crop the volume with it
            # and reshape into a SIZExSIZExSIZE cube with a channel axis.
            mask_cube = get_mask_cube(full_path, size)
            volume_image = tio.ScalarImage(volume_full_path)
            volume_data = volume_image.data.squeeze()
            volume_image.unload()  # free the full volume buffer early
            volume_data = (
                volume_data[mask_cube == 1]
                .reshape(int(size), int(size), int(size))
                .unsqueeze(0)
            )
            cropped_volume = tio.ScalarImage(tensor=volume_data)
            cropped_volume.save(output_full_path)
        except Exception as e:
            # Best effort: report and continue with the remaining masks
            # (e.g. cubes clipped at volume borders fail the reshape above).
            print(f"Error: {e} while processing {full_path}")


if __name__ == "__main__":
    main()
This diff is collapsed.
......@@ -102,6 +102,7 @@ tensorboard = "*"
torchvision = { version = "~=0.17.2", channel = "pytorch" }
tqdm = "*"
versioningit = "*"
torchio = ">=0.19.7,<0.20"
[tool.pixi.feature.self.pypi-dependencies]
mednet = { path = ".", editable = true }
......@@ -245,6 +246,9 @@ alexnet-pretrained = "mednet.config.models.alexnet_pretrained"
densenet = "mednet.config.models.densenet"
densenet-pretrained = "mednet.config.models.densenet_pretrained"
# 3D models
cnn3d = "mednet.config.models.cnn3d"
# lists of data augmentations
affine = "mednet.config.augmentations.affine"
elastic = "mednet.config.augmentations.elastic"
......@@ -409,6 +413,9 @@ nih-cxr14-padchest = "mednet.config.data.nih_cxr14_padchest.idiap"
# montgomery-shenzhen-indian-padchest aggregated dataset
montgomery-shenzhen-indian-padchest = "mednet.config.data.montgomery_shenzhen_indian_padchest.default"
# VISCERAL dataset
visceral = "mednet.config.data.visceral.default"
[tool.ruff]
line-length = 88
target-version = "py310"
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""VISCERAL dataset for 3D organ classification (only lungs and bladders).
Loaded samples are not full scans but 16x16x16 volumes of organs.
Database reference:
"""
import os
import pathlib
import torchio as tio
from ....data.datamodule import CachingDataModule
from ....data.split import make_split
from ....data.typing import RawDataLoader as _BaseRawDataLoader
from ....data.typing import Sample
from ....utils.rc import load_rc
CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
"""Key to search for in the configuration file for the root directory of this
database."""
class RawDataLoader(_BaseRawDataLoader):
    """Raw-data loader specialized for the VISCERAL dataset."""

    datadir: pathlib.Path
    # Base directory where the raw database files are stored.

    def __init__(self) -> None:
        rc = load_rc()
        configured = rc.get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
        self.datadir = pathlib.Path(configured)

    def sample(self, sample: tuple[str, int]) -> Sample:
        """Load a single volume sample from the disk.

        Parameters
        ----------
        sample
            A tuple with the path suffix (relative to the dataset root)
            of the volume to load, and its integer label.

        Returns
        -------
            The sample representation.
        """
        # Clamp intensities first, then rescale based on robust percentiles.
        preprocess = tio.Compose(
            [
                tio.Clamp(out_min=-1000, out_max=2000),
                tio.RescaleIntensity(percentiles=(0.5, 99.5)),
            ]
        )
        volume = tio.ScalarImage(self.datadir / sample[0])
        tensor = preprocess(volume).data
        return tensor, dict(label=sample[1], name=sample[0])

    def label(self, sample: tuple[str, int]) -> int:
        """Return the integer label stored alongside the sample.

        Parameters
        ----------
        sample
            A tuple with the path suffix of the volume and its integer
            label.

        Returns
        -------
        int
            The integer label associated with the sample.
        """
        return sample[1]
class DataModule(CachingDataModule):
    """VISCERAL DataModule for binary 3D organ classification.

    Data specifications:

    * Raw data input (on disk): NIfTI volumes, resolution 16x16x16.
    * Output image:

      * Transforms: load raw NIfTI with :py:mod:`torchio`, clamp and
        rescale intensity, convert to torch tensor.
      * Final specifications: 32-bit floats, cubes 16x16x16 pixels.
      * Labels: 0 (bladder), 1 (lung).

    Parameters
    ----------
    split_filename
        Name of the .json file containing the split to load.
    """

    def __init__(self, split_filename: str):
        database_split = make_split(__package__, split_filename)
        super().__init__(
            database_split,
            raw_data_loader=RawDataLoader(),
            database_name=__package__.split(".")[-1],
            split_name=pathlib.Path(split_filename).stem,
        )
{
"train": [
["16/10000013_1_1302_117.nii.gz",1],
["16/10000164_1_1302_117.nii.gz",1],
["16/10000115_1_1302_117.nii.gz",1],
["16/10000072_1_237_117.nii.gz",0],
["16/10000148_1_CTce_ThAb_1302_6.nii.gz",1],
["16/10000022_1_CT_wb_1302_4.nii.gz",1],
["16/10000163_1_237_117.nii.gz",0],
["16/10000005_1_CT_wb_1302_7.nii.gz",1],
["16/10000113_1_CTce_ThAb_1302_6.nii.gz",1],
["16/10000140_1_CTce_ThAb_237_5.nii.gz",0],
["16/10000137_1_CTce_ThAb_237_7.nii.gz",0],
["16/10000009_1_237_117.nii.gz",0],
["16/10000090_1_CT_wb_1302_6.nii.gz",1],
["16/10000115_1_237_117.nii.gz",0],
["16/10000101_1_237_117.nii.gz",0],
["16/10000168_1_1302_117.nii.gz",1],
["16/10000054_1_1302_117.nii.gz",1],
["16/10000059_1_237_117.nii.gz",0],
["16/10000178_1_237_117.nii.gz",0],
["16/10000053_1_237_117.nii.gz",0],
["16/10000043_1_1302_117.nii.gz",1],
["16/10000018_1_CT_wb_1302_4.nii.gz",1],
["16/10000165_1_1302_117.nii.gz",1],
["16/10000070_1_1302_117.nii.gz",1],
["16/10000113_1_CTce_ThAb_237_6.nii.gz",0],
["16/10000080_1_CT_wb_237_7.nii.gz",0],
["16/10000145_1_CTce_ThAb_237_9.nii.gz",0],
["16/10000140_1_CTce_ThAb_1302_5.nii.gz",1],
["16/10000075_1_1302_117.nii.gz",1],
["16/10000046_1_237_117.nii.gz",0],
["16/10000108_1_CTce_ThAb_237_4.nii.gz",0],
["16/10000109_1_CTce_ThAb_237_6.nii.gz",0],
["16/10000055_1_1302_117.nii.gz",1],
["16/10000004_1_1302_117.nii.gz",1],
["16/10000150_1_237_117.nii.gz",0],
["16/10000040_1_1302_117.nii.gz",1],
["16/10000158_1_237_117.nii.gz",0],
["16/10000183_1_1302_117.nii.gz",1],
["16/10000025_1_CT_wb_237_8.nii.gz",0],
["16/10000064_1_1302_117.nii.gz",1],
["16/10000193_1_1302_117.nii.gz",1],
["16/10000122_1_1302_117.nii.gz",1],
["16/10000105_1_CTce_ThAb_237_4.nii.gz",0],
["16/10000034_1_237_117.nii.gz",0],
["16/10000112_1_CTce_ThAb_1302_8.nii.gz",1],
["16/10000119_1_237_117.nii.gz",0],
["16/10000138_1_CTce_ThAb_1302_8.nii.gz",1],
["16/10000172_1_237_117.nii.gz",0],
["16/10000072_1_1302_117.nii.gz",1],
["16/10000012_1_1302_117.nii.gz",1],
["16/10000078_1_1302_117.nii.gz",1],
["16/10000096_1_CT_wb_1302_8.nii.gz",1],
["16/10000125_1_1302_117.nii.gz",1],
["16/10000191_1_1302_117.nii.gz",1],
["16/10000055_1_237_117.nii.gz",0],
["16/10000049_1_1302_117.nii.gz",1],
["16/10000141_1_CTce_ThAb_237_5.nii.gz",0],
["16/10000006_1_CT_wb_1302_8.nii.gz",1],
["16/10000087_1_CT_wb_1302_4.nii.gz",1],
["16/10000070_1_237_117.nii.gz",0],
["16/10000080_1_CT_wb_1302_7.nii.gz",1],
["16/10000015_1_CT_wb_1302_6.nii.gz",1],
["16/10000011_1_CT_wb_237_7.nii.gz",0],
["16/10000159_1_1302_117.nii.gz",1],
["16/10000077_1_237_117.nii.gz",0],
["16/10000087_1_CT_wb_237_4.nii.gz",0],
["16/10000088_1_CT_wb_1302_4.nii.gz",1],
["16/10000116_1_1302_117.nii.gz",1],
["16/10000077_1_1302_117.nii.gz",1],
["16/10000056_1_237_117.nii.gz",0],
["16/10000054_1_237_117.nii.gz",0],
["16/10000151_1_237_117.nii.gz",0],
["16/10000032_1_1302_117.nii.gz",1],
["16/10000051_1_237_117.nii.gz",0],
["16/10000153_1_1302_117.nii.gz",1],
["16/10000019_1_CT_wb_237_6.nii.gz",0],
["16/10000044_1_237_117.nii.gz",0],
["16/10000175_1_237_117.nii.gz",0],
["16/10000143_1_CTce_ThAb_1302_7.nii.gz",1],
["16/10000160_1_1302_117.nii.gz",1],
["16/10000169_1_1302_117.nii.gz",1],
["16/10000048_1_1302_117.nii.gz",1],
["16/10000203_1_1302_117.nii.gz",1],
["16/10000099_1_CT_wb_1302_5.nii.gz",1],
["16/10000157_1_1302_117.nii.gz",1],
["16/10000046_1_1302_117.nii.gz",1],
["16/10000147_1_CTce_ThAb_1302_8.nii.gz",1],
["16/10000095_1_CT_wb_237_7.nii.gz",0],
["16/10000190_1_237_117.nii.gz",0],
["16/10000179_1_237_117.nii.gz",0],
["16/10000007_1_1302_117.nii.gz",1],
["16/10000189_1_1302_117.nii.gz",1],
["16/10000156_1_1302_117.nii.gz",1],
["16/10000076_1_237_117.nii.gz",0],
["16/10000196_1_237_117.nii.gz",0],
["16/10000150_1_1302_117.nii.gz",1],
["16/10000106_1_CTce_ThAb_237_7.nii.gz",0],
["16/10000200_1_237_117.nii.gz",0],
["16/10000023_1_CT_wb_1302_6.nii.gz",1],
["16/10000168_1_237_117.nii.gz",0],
["16/10000104_1_CTce_ThAb_237_7.nii.gz",0],
["16/10000138_1_CTce_ThAb_237_9.nii.gz",0],
["16/10000068_1_1302_117.nii.gz",1],
["16/10000041_1_237_117.nii.gz",0],
["16/10000154_1_1302_117.nii.gz",1],
["16/10000136_1_CTce_ThAb_237_7.nii.gz",0],
["16/10000189_1_237_117.nii.gz",0],
["16/10000188_1_237_117.nii.gz",0],
["16/10000091_1_1302_117.nii.gz",1],
["16/10000117_1_1302_117.nii.gz",1],
["16/10000179_1_1302_117.nii.gz",1],
["16/10000129_1_CTce_ThAb_1302_5.nii.gz",1],
["16/10000088_1_CT_wb_237_4.nii.gz",0],
["16/10000081_1_CT_wb_1302_6.nii.gz",1],
["16/10000127_1_CTce_ThAb_1302_4.nii.gz",1],
["16/10000161_1_1302_117.nii.gz",1],
["16/10000192_1_237_117.nii.gz",0],
["16/10000162_1_1302_117.nii.gz",1],
["16/10000123_1_237_117.nii.gz",0],
["16/10000194_1_1302_117.nii.gz",1],
["16/10000104_1_CTce_ThAb_1302_7.nii.gz",1],
["16/10000142_1_CTce_ThAb_237_7.nii.gz",0],
["16/10000124_1_1302_117.nii.gz",1],
["16/10000038_1_237_117.nii.gz",0],
["16/10000090_1_CT_wb_237_6.nii.gz",0],
["16/10000126_1_237_117.nii.gz",0],
["16/10000047_1_1302_117.nii.gz",1],
["16/10000026_1_1302_117.nii.gz",1],
["16/10000132_1_CTce_ThAb_237_8.nii.gz",0],
["16/10000037_1_237_117.nii.gz",0],
["16/10000126_1_1302_117.nii.gz",1],
["16/10000148_1_CTce_ThAb_237_6.nii.gz",0],
["16/10000142_1_CTce_ThAb_1302_7.nii.gz",1],
["16/10000170_1_1302_117.nii.gz",1],
["16/10000057_1_237_117.nii.gz",0],
["16/10000040_1_237_117.nii.gz",0],
["16/10000079_1_CT_wb_1302_7.nii.gz",1],
["16/10000133_1_CTce_ThAb_1302_9.nii.gz",1],
["16/10000129_1_CTce_ThAb_237_5.nii.gz",0],
["16/10000201_1_237_117.nii.gz",0],
["16/10000084_1_237_117.nii.gz",0],
["16/10000173_1_1302_117.nii.gz",1],
["16/10000064_1_237_117.nii.gz",0],
["16/10000130_1_CTce_ThAb_1302_5.nii.gz",1],
["16/10000094_1_CT_wb_237_7.nii.gz",0],
["16/10000065_1_1302_117.nii.gz",1],
["16/10000114_1_1302_117.nii.gz",1],
["16/10000042_1_237_117.nii.gz",0],
["16/10000048_1_237_117.nii.gz",0],
["16/10000151_1_1302_117.nii.gz",1],
["16/10000060_1_1302_117.nii.gz",1],
["16/10000018_1_CT_wb_237_4.nii.gz",0],
["16/10000024_1_1302_117.nii.gz",1],
["16/10000187_1_1302_117.nii.gz",1],
["16/10000111_1_CTce_ThAb_237_6.nii.gz",0],
["16/10000050_1_1302_117.nii.gz",1],
["16/10000128_1_CTce_ThAb_1302_5.nii.gz",1],
["16/10000134_1_CTce_ThAb_1302_6.nii.gz",1],
["16/10000119_1_1302_117.nii.gz",1],
["16/10000092_1_CT_wb_1302_8.nii.gz",1],
["16/10000205_1_1302_117.nii.gz",1],
["16/10000169_1_237_117.nii.gz",0],
["16/10000155_1_1302_117.nii.gz",1],
["16/10000044_1_1302_117.nii.gz",1],
["16/10000134_1_CTce_ThAb_237_6.nii.gz",0],
["16/10000016_1_CT_wb_1302_9.nii.gz",1],
["16/10000031_1_237_117.nii.gz",0],
["16/10000020_1_CT_wb_237_8.nii.gz",0],
["16/10000122_1_237_117.nii.gz",0],
["16/10000123_1_1302_117.nii.gz",1],
["16/10000062_1_237_117.nii.gz",0],
["16/10000042_1_1302_117.nii.gz",1],
["16/10000135_1_CTce_ThAb_1302_7.nii.gz",1],
["16/10000106_1_CTce_ThAb_1302_7.nii.gz",1],
["16/10000051_1_1302_117.nii.gz",1],
["16/10000198_1_237_117.nii.gz",0],
["16/10000136_1_CTce_ThAb_1302_7.nii.gz",1],
["16/10000045_1_237_117.nii.gz",0],
["16/10000147_1_CTce_ThAb_237_8.nii.gz",0],
["16/10000086_1_CT_wb_1302_6.nii.gz",1],
["16/10000131_1_CTce_ThAb_1302_4.nii.gz",1],
["16/10000174_1_237_117.nii.gz",0],
["16/10000170_1_237_117.nii.gz",0],
["16/10000049_1_237_117.nii.gz",0],
["16/10000116_1_237_117.nii.gz",0],
["16/10000137_1_CTce_ThAb_1302_7.nii.gz",1],
["16/10000111_1_CTce_ThAb_1302_6.nii.gz",1],
["16/10000180_1_1302_117.nii.gz",1],
["16/10000075_1_237_117.nii.gz",0],
["16/10000161_1_237_117.nii.gz",0],
["16/10000011_1_CT_wb_1302_7.nii.gz",1],
["16/10000094_1_CT_wb_1302_7.nii.gz",1],
["16/10000015_1_CT_wb_237_6.nii.gz",0],
["16/10000145_1_CTce_ThAb_1302_7.nii.gz",1],
["16/10000085_1_CT_wb_1302_6.nii.gz",1],
["16/10000033_1_237_117.nii.gz",0],
["16/10000097_1_1302_117.nii.gz",1],
["16/10000141_1_CTce_ThAb_1302_5.nii.gz",1],
["16/10000130_1_CTce_ThAb_237_5.nii.gz",0],
["16/10000171_1_237_117.nii.gz",0],
["16/10000163_1_1302_117.nii.gz",1],
["16/10000021_1_CT_wb_1302_9.nii.gz",1],
["16/10000184_1_237_117.nii.gz",0],
["16/10000114_1_237_117.nii.gz",0],
["16/10000160_1_237_117.nii.gz",0],
["16/10000071_1_237_117.nii.gz",0],
["16/10000084_1_1302_117.nii.gz",1],
["16/10000135_1_CTce_ThAb_237_7.nii.gz",0],
["16/10000117_1_237_117.nii.gz",0],
["16/10000149_1_CTce_ThAb_1302_8.nii.gz",1],
["16/10000162_1_237_117.nii.gz",0],
["16/10000025_1_CT_wb_1302_8.nii.gz",1],
["16/10000100_1_CTce_ThAb_237_6.nii.gz",0],
["16/10000200_1_1302_117.nii.gz",1],
["16/10000096_1_CT_wb_237_9.nii.gz",0],
["16/10000024_1_237_117.nii.gz",0],
["16/10000118_1_1302_117.nii.gz",1],
["16/10000112_1_CTce_ThAb_237_8.nii.gz",0],
["16/10000159_1_237_117.nii.gz",0],
["16/10000196_1_1302_117.nii.gz",1],
["16/10000204_1_237_117.nii.gz",0],
["16/10000019_1_CT_wb_1302_6.nii.gz",1],
["16/10000047_1_237_117.nii.gz",0],
["16/10000035_1_237_117.nii.gz",0],
["16/10000012_1_237_117.nii.gz",0],
["16/10000128_1_CTce_ThAb_237_5.nii.gz",0],
["16/10000174_1_1302_117.nii.gz",1],
["16/10000013_1_237_117.nii.gz",0],
["16/10000056_1_1302_117.nii.gz",1],
["16/10000020_1_CT_wb_1302_8.nii.gz",1],
["16/10000076_1_1302_117.nii.gz",1],
["16/10000177_1_237_117.nii.gz",0],
["16/10000198_1_1302_117.nii.gz",1],
["16/10000078_1_237_117.nii.gz",0],
["16/10000125_1_237_117.nii.gz",0],
["16/10000082_1_CT_wb_1302_6.nii.gz",1],
["16/10000091_1_237_117.nii.gz",0],
["16/10000085_1_CT_wb_237_6.nii.gz",0],
["16/10000089_1_CT_wb_1302_8.nii.gz",1],
["16/10000110_1_CTce_ThAb_1302_9.nii.gz",1],
["16/10000053_1_1302_117.nii.gz",1]
],
"validation": [
["16/10000203_1_237_117.nii.gz",0],
["16/10000022_1_CT_wb_237_4.nii.gz",0],
["16/10000017_1_CT_wb_237_9.nii.gz",0],
["16/10000166_1_1302_117.nii.gz",1],
["16/10000016_1_CT_wb_237_9.nii.gz",0],
["16/10000057_1_1302_117.nii.gz",1],
["16/10000165_1_237_117.nii.gz",0],
["16/10000021_1_CT_wb_237_9.nii.gz",0],
["16/10000098_1_237_117.nii.gz",0],
["16/10000068_1_237_117.nii.gz",0],
["16/10000095_1_CT_wb_1302_7.nii.gz",1],
["16/10000067_1_CT_wb_237_9.nii.gz",0],
["16/10000069_1_1302_117.nii.gz",1],
["16/10000201_1_1302_117.nii.gz",1],
["16/10000132_1_CTce_ThAb_1302_8.nii.gz",1],
["16/10000181_1_1302_117.nii.gz",1],
["16/10000101_1_1302_117.nii.gz",1],
["16/10000067_1_CT_wb_1302_9.nii.gz",1],
["16/10000027_1_237_117.nii.gz",0],
["16/10000187_1_237_117.nii.gz",0],
["16/10000065_1_237_117.nii.gz",0],
["16/10000186_1_237_117.nii.gz",0],
["16/10000186_1_1302_117.nii.gz",1],
["16/10000185_1_1302_117.nii.gz",1],
["16/10000097_1_237_117.nii.gz",0],
["16/10000023_1_CT_wb_237_6.nii.gz",0],
["16/10000098_1_1302_117.nii.gz",1],
["16/10000099_1_CT_wb_237_5.nii.gz",0],
["16/10000073_1_1302_117.nii.gz",1],
["16/10000184_1_1302_117.nii.gz",1],
["16/10000093_1_1302_117.nii.gz",1],
["16/10000086_1_CT_wb_237_6.nii.gz",0],
["16/10000007_1_237_117.nii.gz",0],
["16/10000038_1_1302_117.nii.gz",1],
["16/10000050_1_237_117.nii.gz",0],
["16/10000093_1_237_117.nii.gz",0],
["16/10000183_1_237_117.nii.gz",0],
["16/10000060_1_237_117.nii.gz",0],
["16/10000089_1_CT_wb_237_8.nii.gz",0],
["16/10000124_1_237_117.nii.gz",0],
["16/10000152_1_237_117.nii.gz",0],
["16/10000158_1_1302_117.nii.gz",1],
["16/10000193_1_237_117.nii.gz",0],
["16/10000167_1_237_117.nii.gz",0],
["16/10000008_1_237_117.nii.gz",0],
["16/10000071_1_1302_117.nii.gz",1],
["16/10000171_1_1302_117.nii.gz",1],
["16/10000120_1_237_117.nii.gz",0],
["16/10000100_1_CTce_ThAb_1302_6.nii.gz",1],
["16/10000154_1_237_117.nii.gz",0],
["16/10000082_1_CT_wb_237_6.nii.gz",0],
["16/10000120_1_1302_117.nii.gz",1],
["16/10000175_1_1302_117.nii.gz",1],
["16/10000026_1_237_117.nii.gz",0],
["16/10000178_1_1302_117.nii.gz",1],
["16/10000192_1_1302_117.nii.gz",1],
["16/10000155_1_237_117.nii.gz",0],
["16/10000152_1_1302_117.nii.gz",1],
["16/10000005_1_CT_wb_237_7.nii.gz",0],
["16/10000191_1_237_117.nii.gz",0],
["16/10000073_1_237_117.nii.gz",0],
["16/10000181_1_237_117.nii.gz",0],
["16/10000014_1_CT_wb_1302_5.nii.gz",1],
["16/10000017_1_CT_wb_1302_9.nii.gz",1],
["16/10000081_1_CT_wb_237_6.nii.gz",0],
["16/10000079_1_CT_wb_237_7.nii.gz",0],
["16/10000153_1_237_117.nii.gz",0],
["16/10000164_1_237_117.nii.gz",0],
["16/10000167_1_1302_117.nii.gz",1]
],
"test": [
["16/10000109_1_CTce_ThAb_1302_6.nii.gz",1],
["16/10000039_1_237_117.nii.gz",0],
["16/10000118_1_237_117.nii.gz",0],
["16/10000190_1_1302_117.nii.gz",1],
["16/10000110_1_CTce_ThAb_237_9.nii.gz",0],
["16/10000180_1_237_117.nii.gz",0],
["16/10000131_1_CTce_ThAb_237_8.nii.gz",0],
["16/10000121_1_1302_117.nii.gz",1],
["16/10000127_1_CTce_ThAb_237_8.nii.gz",0],
["16/10000173_1_237_117.nii.gz",0],
["16/10000166_1_237_117.nii.gz",0],
["16/10000105_1_CTce_ThAb_1302_4.nii.gz",1],
["16/10000177_1_1302_117.nii.gz",1],
["16/10000031_1_1302_117.nii.gz",1],
["16/10000188_1_1302_117.nii.gz",1],
["16/10000069_1_237_117.nii.gz",0],
["16/10000143_1_CTce_ThAb_237_7.nii.gz",0],
["16/10000014_1_CT_wb_237_5.nii.gz",0],
["16/10000133_1_CTce_ThAb_237_8.nii.gz",0],
["16/10000199_1_1302_117.nii.gz",1],
["16/10000205_1_237_117.nii.gz",0],
["16/10000121_1_237_117.nii.gz",0],
["16/10000149_1_CTce_ThAb_237_8.nii.gz",0],
["16/10000199_1_237_117.nii.gz",0],
["16/10000032_1_237_117.nii.gz",0],
["16/10000156_1_237_117.nii.gz",0],
["16/10000092_1_CT_wb_237_4.nii.gz",0],
["16/10000037_1_1302_117.nii.gz",1],
["16/10000172_1_1302_117.nii.gz",1],
["16/10000043_1_237_117.nii.gz",0],
["16/10000204_1_1302_117.nii.gz",1],
["16/10000185_1_237_117.nii.gz",0],
["16/10000108_1_CTce_ThAb_1302_4.nii.gz",1],
["16/10000194_1_237_117.nii.gz",0],
["16/10000059_1_1302_117.nii.gz",1]
]
}
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""VISCERAL dataset for 3D organ classification.
Database reference:
See :py:class:`mednet.config.data.visceral.datamodule.DataModule` for
technical details.
"""
from mednet.config.data.visceral.datamodule import DataModule
datamodule = DataModule("default.json")
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Simple CNN for 3D organ classification, to be trained from scratch."""
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from mednet.models.cnn3d import Conv3DNet
model = Conv3DNet(
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=8e-5),
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import typing
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torch.optim.optimizer
import torch.utils.data
from ..data.typing import TransformSequence
from .model import Model
from .separate import separate
logger = logging.getLogger(__name__)
class Conv3DNet(Model):
    """Implementation of a small 3D CNN with residual-style fused blocks.

    Four convolution blocks, each averaging a two-convolution path with a
    1x1x1 projection of the block input, followed by max-pooling, global
    average pooling, and a two-layer classifier head.

    This network has a linear output. You should use losses with ``WithLogit``
    instead of cross-entropy versions when training.

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.  ``None`` (the default) means no extra
        arguments.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.  ``None`` means none.
    augmentation_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
        loss_arguments: dict[str, typing.Any] | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] | None = None,
        augmentation_transforms: TransformSequence | None = None,
        num_classes: int = 1,
    ):
        # ``None`` defaults avoid the shared-mutable-default pitfall: a
        # ``{}``/``[]`` default is created once per process and shared by
        # every call that omits the argument.
        super().__init__(
            loss_type,
            {} if loss_arguments is None else loss_arguments,
            optimizer_type,
            {} if optimizer_arguments is None else optimizer_arguments,
            [] if augmentation_transforms is None else augmentation_transforms,
            num_classes,
        )
        self.name = "cnn3D"
        self.num_classes = num_classes
        self.model_transforms = []
        # Attribute names are kept exactly as before so that previously
        # saved checkpoints (state dicts) remain loadable.
        # First convolution block
        self.conv3d_1_1 = nn.Conv3d(
            in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1
        )
        self.conv3d_1_2 = nn.Conv3d(
            in_channels=4, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.conv3d_1_3 = nn.Conv3d(
            in_channels=1, out_channels=16, kernel_size=1, stride=1
        )
        self.batch_norm_1_1 = nn.BatchNorm3d(4)
        self.batch_norm_1_2 = nn.BatchNorm3d(16)
        self.batch_norm_1_3 = nn.BatchNorm3d(16)
        # Second convolution block
        self.conv3d_2_1 = nn.Conv3d(
            in_channels=16, out_channels=24, kernel_size=3, stride=1, padding=1
        )
        self.conv3d_2_2 = nn.Conv3d(
            in_channels=24, out_channels=32, kernel_size=3, stride=1, padding=1
        )
        self.conv3d_2_3 = nn.Conv3d(
            in_channels=16, out_channels=32, kernel_size=1, stride=1
        )
        self.batch_norm_2_1 = nn.BatchNorm3d(24)
        self.batch_norm_2_2 = nn.BatchNorm3d(32)
        self.batch_norm_2_3 = nn.BatchNorm3d(32)
        # Third convolution block
        self.conv3d_3_1 = nn.Conv3d(
            in_channels=32, out_channels=40, kernel_size=3, stride=1, padding=1
        )
        self.conv3d_3_2 = nn.Conv3d(
            in_channels=40, out_channels=48, kernel_size=3, stride=1, padding=1
        )
        self.conv3d_3_3 = nn.Conv3d(
            in_channels=32, out_channels=48, kernel_size=1, stride=1
        )
        self.batch_norm_3_1 = nn.BatchNorm3d(40)
        self.batch_norm_3_2 = nn.BatchNorm3d(48)
        self.batch_norm_3_3 = nn.BatchNorm3d(48)
        # Fourth convolution block
        self.conv3d_4_1 = nn.Conv3d(
            in_channels=48, out_channels=56, kernel_size=3, stride=1, padding=1
        )
        self.conv3d_4_2 = nn.Conv3d(
            in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        self.conv3d_4_3 = nn.Conv3d(
            in_channels=48, out_channels=64, kernel_size=1, stride=1
        )
        self.batch_norm_4_1 = nn.BatchNorm3d(56)
        self.batch_norm_4_2 = nn.BatchNorm3d(64)
        self.batch_norm_4_3 = nn.BatchNorm3d(64)
        self.pool = nn.MaxPool3d(2)
        self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def _fused_block(self, x, conv_a, bn_a, conv_b, bn_b, conv_skip, bn_skip):
        # One convolution block: two 3x3x3 convolutions averaged with a
        # 1x1x1 projection of the block input (residual-style fusion).
        out = F.relu(bn_a(conv_a(x)))
        out = F.relu(bn_b(conv_b(out)))
        return (out + F.relu(bn_skip(conv_skip(x)))) / 2

    def forward(self, x):
        x = self.normalizer(x)  # type: ignore
        # Four fused convolution blocks with 2x down-sampling in between.
        x = self._fused_block(
            x,
            self.conv3d_1_1,
            self.batch_norm_1_1,
            self.conv3d_1_2,
            self.batch_norm_1_2,
            self.conv3d_1_3,
            self.batch_norm_1_3,
        )
        x = self.pool(x)
        x = self._fused_block(
            x,
            self.conv3d_2_1,
            self.batch_norm_2_1,
            self.conv3d_2_2,
            self.batch_norm_2_2,
            self.conv3d_2_3,
            self.batch_norm_2_3,
        )
        x = self.pool(x)
        x = self._fused_block(
            x,
            self.conv3d_3_1,
            self.batch_norm_3_1,
            self.conv3d_3_2,
            self.batch_norm_3_2,
            self.conv3d_3_3,
            self.batch_norm_3_3,
        )
        x = self.pool(x)
        x = self._fused_block(
            x,
            self.conv3d_4_1,
            self.batch_norm_4_1,
            self.conv3d_4_2,
            self.batch_norm_4_2,
            self.conv3d_4_3,
            self.batch_norm_4_3,
        )
        # Collapse spatial dimensions, then classify.
        x = self.global_pool(x)
        x = x.view(x.size(0), x.size(1))
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def training_step(self, batch, _):
        images = batch[0]
        labels = batch[1]["label"]
        # Increase label dimension if too low.
        # Allows single and multiclass usage.
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))
        # Forward pass on the network, with data augmentation applied.
        outputs = self(self.augmentation_transforms(images))
        return self._train_loss(outputs, labels.float())

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images = batch[0]
        labels = batch[1]["label"]
        # Increase label dimension if too low.
        # Allows single and multiclass usage.
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))
        # Data forwarding on the existing network (no augmentation).
        outputs = self(images)
        return self._validation_loss(outputs, labels.float())

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Logits -> probabilities, then split the batch into per-sample
        # predictions.
        outputs = self(batch[0])
        probabilities = torch.sigmoid(outputs)
        return separate((probabilities, batch[1]))
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for VISCERAL dataset."""
import pytest
from click.testing import CliRunner
def id_function(val):
    """Render a parametrized value as a readable pytest test id."""
    return str(val) if isinstance(val, dict) else repr(val)
@pytest.mark.parametrize(
    "split,lengths",
    [
        ("default", dict(train=241, validation=69, test=35)),
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers,
    split: str,
    lengths: dict[str, int],
):
    """Check subset lengths, filename prefixes and labels of a split."""
    from mednet.data.split import make_split

    database_checkers.check_split(
        make_split("mednet.config.data.visceral", f"{split}.json"),
        lengths=lengths,
        # NOTE: the trailing comma matters -- ("16/10000") without it is a
        # plain string, not a one-element tuple of prefixes.
        prefixes=("16/10000",),
        possible_labels=(0, 1),
    )
@pytest.mark.skip_if_rc_var_not_set("datadir.visceral")
def test_database_check():
    """Run the database-check CLI on visceral and expect a clean exit."""
    from mednet.scripts.database import check

    result = CliRunner().invoke(check, ["visceral"])
    assert (
        result.exit_code == 0
    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment