diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 0de55fba7ec878cb182f3a6efdb732db1e254cc0..0000000000000000000000000000000000000000 --- a/.flake8 +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# SPDX-FileContributor: Andre Anjos <andre.anjos@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -[flake8] -max-line-length = 80 -ignore = E501,W503,E302,E402,E203 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c305176c0f4da15bffc19a7f6dd34cc207c3437..079c8eae0595ff9c516a89d56b8461993a17d68b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,44 +6,21 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.3 + hooks: + - id: ruff + args: [ --fix ] + - id: ruff-format - repo: https://github.com/numpy/numpydoc rev: v1.6.0 hooks: - id: numpydoc-validation - - repo: https://github.com/psf/black - rev: 23.12.1 - hooks: - - id: black - - repo: https://github.com/pycqa/docformatter - rev: v1.7.5 - hooks: - - id: docformatter - args: [ - --wrap-summaries=0, - ] - - repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.9.0 hooks: - id: mypy - args: [ - --install-types, - --non-interactive, - --no-strict-optional, - --ignore-missing-imports, - ] - - repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 - hooks: - - id: pyupgrade - args: [--py39-plus] + args: [ --install-types, --non-interactive, --no-strict-optional, --ignore-missing-imports ] # - repo: https://github.com/pre-commit/mirrors-prettier # rev: v2.7.1 # hooks: diff --git a/doc/api.rst b/doc/api.rst index 39b4d40f6218f6100b7f7ab50fb826aaf32e541a..9d0b02c0c9aece7a7647ef95c8ea80827948ad87 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -99,7 +99,6 @@ Reusable auxiliary functions. mednet.utils.checkpointer mednet.utils.rc mednet.utils.resources - mednet.utils.summary mednet.utils.tensorboard diff --git a/doc/conf.py b/doc/conf.py index 7cf72ca51fb2b9e6d4f22c2cd8143d63f32893da..f547b0b0938ab095df1a074426d75cb88665d65e 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -2,9 +2,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import os +import pathlib import time - from importlib.metadata import distribution # -- General configuration ----------------------------------------------------- @@ -36,8 +35,9 @@ nitpicky = True nitpick_ignore = [] # Allows the user to override warnings from a separate file -if os.path.exists("nitpick-exceptions.txt"): - for line in open("nitpick-exceptions.txt"): +nitpick_path = pathlib.Path("nitpick-exceptions.txt") +if nitpick_path.exists(): + for line in nitpick_path.open(): if line.strip() == "" or line.startswith("#"): continue dtype, target = line.split(None, 1) @@ -54,9 +54,9 @@ autosummary_generate = True numfig = True # If we are on OSX, the 'dvipng' path maybe different -dvipng_osx = "/Library/TeX/texbin/dvipng" -if os.path.exists(dvipng_osx): - pngmath_dvipng = dvipng_osx +dvipng_osx = pathlib.Path("/Library/TeX/texbin/dvipng") +if dvipng_osx.exists(): + pngmath_dvipng = str(dvipng_osx) # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -71,7 +71,7 @@ main_doc = "index" project = "mednet" package = distribution(project) -copyright = "%s, Idiap Research Institute" % time.strftime("%Y") +copyright = "%s, Idiap Research Institute" % time.strftime("%Y") # noqa: A001 # The short X.Y version. version = package.version diff --git a/src/mednet/config/data/tbx11k/make_splits_from_database.py b/helpers/tbx11k_make_splits.py similarity index 88% rename from src/mednet/config/data/tbx11k/make_splits_from_database.py rename to helpers/tbx11k_make_splits.py index 5c097d1749ceb4e0bd69376941a2034cf02c4725..233f0194d38b7f9e488b566d85951ade544007b1 100644 --- a/src/mednet/config/data/tbx11k/make_splits_from_database.py +++ b/helpers/tbx11k_make_splits.py @@ -74,7 +74,7 @@ def reorder(data: dict) -> list: categories = {k["id"]: k["name"] for k in data["categories"]} assert len(set(categories.values())) == len( - categories + categories, ), "Category ids are not unique" # reset category values, so latent-tb = 0, and active-tb = 1 @@ -142,17 +142,17 @@ def normalize_labels(data: list) -> list: ), f"Image {s[0]} is healthy, but contains tb bbox annotations" return 0 # patient is healthy - elif s[0].startswith("imgs/sick"): + if s[0].startswith("imgs/sick"): assert ( len(s) == 2 ), f"Image {s[0]} is sick (no tb), but contains tb bbox annotations" return 4 # patient is sick - elif s[0].startswith("imgs/tb"): + if s[0].startswith("imgs/tb"): if len(s) == 2: print( f"WARNING: Image {s[0]} is from the tb subdir, " - f"but contains no tb bbox annotations" + f"but contains no tb bbox annotations", ) return -1 # unknown diagnosis bbx_labels: list[int] = [k[0] for k in s[2]] @@ -164,18 +164,17 @@ def normalize_labels(data: list) -> list: if 0 in tb_counts: if 1 not in tb_counts: return 3 # patient has latent tb - else: - print( - f"WARNING: Image {s[0]} has bboxes with both " - f"active and latent tb." - ) - return 2 # patient has active and latent tb - else: # 1 in tb_counts: - assert 0 not in tb_counts # cannot really happen, but check... - return 1 # patient has only active tb - - else: - raise RuntimeError("Cannot happen - please check") + + print( + f"WARNING: Image {s[0]} has bboxes with both " + f"active and latent tb.", + ) + return 2 # patient has active and latent tb + + assert 0 not in tb_counts # cannot really happen, but check... + return 1 # patient has only active tb + + raise RuntimeError("Cannot happen - please check") for k in data: k[1] = _set_label(k) @@ -209,6 +208,7 @@ def print_statistics(d: dict): ds The dataset to print stats for. """ + class_count = collections.Counter([k[1] for k in ds]) for k, v in class_count.items(): print(f" - {label_translations[k]}: {v}") @@ -338,7 +338,10 @@ def create_v2_default_split(d: dict, seed: int, validation_size) -> dict: def create_folds( - d: dict, n: int, seed: int, validation_size: float + d: dict, + n: int, + seed: int, + validation_size: float, ) -> list[dict]: """Create folds from existing splits. @@ -360,8 +363,8 @@ def create_folds( All the ``n`` folds. """ - X = d["train"] + d["validation"] + d["test"] - y = [[k[1]] for k in X] + x = d["train"] + d["validation"] + d["test"] + y = [[k[1]] for k in x] # Initializes a StratifiedKFold object with 10 folds skf = StratifiedKFold(n_splits=n, shuffle=True, random_state=seed) @@ -374,10 +377,10 @@ def create_folds( # Loops over the 10 folds and split the data retval = [] - for train_idx, test_idx in skf.split(X, y): + for train_idx, test_idx in skf.split(x, y): # Get the training and test data for this fold - train_dataset = [X[k] for k in train_idx] - test_dataset = [X[k] for k in test_idx] + train_dataset = [x[k] for k in train_idx] + test_dataset = [x[k] for k in test_idx] # Split the training data into training and validation sets train_dataset, val_dataset = train_test_split( @@ -392,7 +395,7 @@ def create_folds( "train": train_dataset, "validation": val_dataset, "test": test_dataset, - } + }, ) return retval @@ -413,24 +416,25 @@ def main(): datadir = pathlib.Path( UserDefaults("mednet.toml").get( - "datadir.tbx11k", os.path.realpath(os.curdir) - ) + "datadir.tbx11k", + os.path.realpath(os.curdir), + ), ) train_filename = datadir / "annotations" / "json" / "TBX11K_train.json" val_filename = datadir / "annotations" / "json" / "TBX11K_val.json" test_filename = datadir / "annotations" / "json" / "all_test.json" - with open(train_filename) as f: + with train_filename.open() as f: print(f"Loading {str(train_filename)}...") data = json.load(f) train_data = normalize_labels(reorder(data)) - with open(val_filename) as f: + with val_filename.open() as f: print(f"Loading {str(val_filename)}...") data = json.load(f) val_data = normalize_labels(reorder(data)) - with open(test_filename) as f: + with test_filename.open() as f: print(f"Loading {str(test_filename)}...") data = json.load(f) test_data = reorder(data) @@ -448,36 +452,46 @@ def main(): print("\nGenerating v1 split...") v1_split = create_v1_default_split( - final_data, seed=seed, validation_size=validation_size + final_data, + seed=seed, + validation_size=validation_size, ) print_statistics(v1_split) - with open("v1-healthy-vs-atb.json", "w") as v1def: + with pathlib.Path("v1-healthy-vs-atb.json").open("w") as v1def: json.dump(v1_split, v1def, indent=2) # folds for the v1 split print(f"\nGenerating {n_folds} v1 split folds...") v1_folds = create_folds( - v1_split, n=n_folds, seed=seed, validation_size=validation_size + v1_split, + n=n_folds, + seed=seed, + validation_size=validation_size, ) for i, k in enumerate(v1_folds): - with open(f"v1-fold-{i}.json", "w") as v1fold: + with pathlib.Path(f"v1-fold-{i}.json").open("w") as v1fold: json.dump(k, v1fold, indent=2) print("\nGenerating v2 split...") v2_split = create_v2_default_split( - final_data, seed=seed, validation_size=validation_size + final_data, + seed=seed, + validation_size=validation_size, ) print_statistics(v2_split) - with open("v2-others-vs-atb.json", "w") as v2def: + with pathlib.Path("v2-others-vs-atb.json").open("w") as v2def: json.dump(v2_split, v2def, indent=2) # folds for the v2 split print(f"\nGenerating {n_folds} v2 split folds...") v2_folds = create_folds( - v2_split, n=n_folds, seed=seed, validation_size=validation_size + v2_split, + n=n_folds, + seed=seed, + validation_size=validation_size, ) for i, k in enumerate(v2_folds): - with open(f"v2-fold-{i}.json", "w") as v2fold: + with pathlib.Path(f"v2-fold-{i}.json").open("w") as v2fold: json.dump(k, v2fold, indent=2) diff --git a/pyproject.toml b/pyproject.toml index f56c58a3aeb2574c1f049fc9a963141af9beedca..a04dee2c3637386bf88f13344abbc3e6e6c6c3d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -258,14 +258,62 @@ nih-cxr14-padchest = "mednet.config.data.nih_cxr14_padchest.idiap" # montgomery-shenzhen-indian-padchest aggregated dataset montgomery-shenzhen-indian-padchest = "mednet.config.data.montgomery_shenzhen_indian_padchest.default" -[tool.isort] -profile = "black" -line_length = 80 -order_by_type = true -lines_between_types = 1 - -[tool.black] +[tool.ruff] line-length = 80 +target-version = "py310" + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] +select = [ + "A", # https://docs.astral.sh/ruff/rules/#flake8-builtins-a + "COM", # https://docs.astral.sh/ruff/rules/#flake8-commas-com + "D", # https://docs.astral.sh/ruff/rules/#pydocstyle-d + "E", # https://docs.astral.sh/ruff/rules/#error-e + "F", # https://docs.astral.sh/ruff/rules/#pyflakes-f + "I", # https://docs.astral.sh/ruff/rules/#isort-i + "ISC", # https://docs.astral.sh/ruff/rules/#flake8-implicit-str-concat-isc + "LOG", # https://docs.astral.sh/ruff/rules/#flake8-logging-log + "N", # https://docs.astral.sh/ruff/rules/#pep8-naming-n + "PTH", # https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth + "Q", # https://docs.astral.sh/ruff/rules/#flake8-quotes-q + "RET", # https://docs.astral.sh/ruff/rules/#flake8-return-ret + "SLF", # https://docs.astral.sh/ruff/rules/#flake8-self-slf + "T10", # https://docs.astral.sh/ruff/rules/#flake8-debugger-t10 + "T20", # https://docs.astral.sh/ruff/rules/#flake8-print-t20 + "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up + "W", # https://docs.astral.sh/ruff/rules/#warning-w + #"G", # https://docs.astral.sh/ruff/rules/#flake8-logging-format-g + #"ICN", # https://docs.astral.sh/ruff/rules/#flake8-import-conventions-icn + #"NPY", # https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy +] +ignore = [ + "COM812", # https://docs.astral.sh/ruff/rules/missing-trailing-comma/ + "D100", # https://docs.astral.sh/ruff/rules/undocumented-public-module/ + "D102", # https://docs.astral.sh/ruff/rules/undocumented-public-method/ + "D104", # https://docs.astral.sh/ruff/rules/undocumented-public-package/ + "D105", # https://docs.astral.sh/ruff/rules/undocumented-magic-method/ + "D107", # https://docs.astral.sh/ruff/rules/undocumented-public-init/ + "D203", # https://docs.astral.sh/ruff/rules/one-blank-line-before-class/ + "D202", # https://docs.astral.sh/ruff/rules/no-blank-line-after-function/ + "D205", # https://docs.astral.sh/ruff/rules/blank-line-after-summary/ + "D212", # https://docs.astral.sh/ruff/rules/multi-line-summary-first-line/ + "D213", # https://docs.astral.sh/ruff/rules/multi-line-summary-second-line/ + "E302", # https://docs.astral.sh/ruff/rules/blank-lines-top-level/ + "E402", # https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file/ + "E501", # https://docs.astral.sh/ruff/rules/line-too-long/ + "ISC001", # https://docs.astral.sh/ruff/rules/single-line-implicit-string-concatenation/ +] + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + +[tool.ruff.lint.per-file-ignores] +"helpers/*.py" = ["T201", "D103"] +"tests/*.py" = ["D", "E501"] +"doc/conf.py" = ["D"] +"**/scripts/*.py" = ["E501"] [tool.pytest.ini_options] addopts = ["--cov=mednet", "--cov-report=term-missing", "--import-mode=append"] @@ -274,7 +322,7 @@ junit_log_passing_tests = false [tool.numpydoc_validation] checks = [ - "all", # report on all checks, except the below + "all", # report on all checks, except the ones below "ES01", # Not all functions require extended summaries "EX01", # Not all functions require examples "GL01", # Expects text to be on the line after the opening quotes but that is in direct opposition of the sphinx recommendations and conflicts with other pre-commit hooks. diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py index cae64f2c53b2fda0992aea1f662307ddab27d616..b923e6c9b84781cc41e3b42777b3be86d734876b 100644 --- a/src/mednet/config/data/hivtb/datamodule.py +++ b/src/mednet/config/data/hivtb/datamodule.py @@ -6,20 +6,18 @@ Database reference: [HIV-TB-2019]_ """ - import importlib.resources import os +import pathlib import PIL.Image - from torchvision.transforms.functional import to_tensor from ....data.datamodule import CachingDataModule from ....data.image_utils import remove_black_borders from ....data.split import JSONDatabaseSplit -from ....data.typing import DatabaseSplit +from ....data.typing import DatabaseSplit, Sample from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) @@ -30,13 +28,16 @@ database.""" class RawDataLoader(_BaseRawDataLoader): """A specialized raw-data-loader for the HIV-TB dataset.""" - datadir: str + datadir: pathlib.Path """This variable contains the base directory where the database raw data is stored.""" def __init__(self): - self.datadir = load_rc().get( - CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir) + self.datadir = pathlib.Path( + load_rc().get( + CONFIGURATION_KEY_DATADIR, + os.path.realpath(os.curdir), + ), ) def sample(self, sample: tuple[str, int]) -> Sample: @@ -46,16 +47,15 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- The sample representation. """ - image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert( - "L" - ) + + image = PIL.Image.open(self.datadir / sample[0]).convert("L") image = remove_black_borders(image) tensor = to_tensor(image) @@ -73,14 +73,15 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- int The integer label associated with the sample. """ + return sample[1] @@ -98,7 +99,9 @@ def make_split(basename: str) -> DatabaseSplit: """ return JSONDatabaseSplit( - importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + basename, + ), ) @@ -142,5 +145,5 @@ class DataModule(CachingDataModule): database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/indian/datamodule.py b/src/mednet/config/data/indian/datamodule.py index 2fc0567bbbefd65ec4c96fc19592c1ec1db7215c..f9d52a2d332ee58b29c4cb2cadbe6d4be4abc56a 100644 --- a/src/mednet/config/data/indian/datamodule.py +++ b/src/mednet/config/data/indian/datamodule.py @@ -7,7 +7,7 @@ Database reference: [INDIAN-2013]_ """ import importlib.resources -import os +import pathlib from ....config.data.shenzhen.datamodule import RawDataLoader from ....data.datamodule import CachingDataModule @@ -33,7 +33,9 @@ def make_split(basename: str) -> DatabaseSplit: """ return JSONDatabaseSplit( - importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + basename, + ), ) @@ -47,7 +49,8 @@ class DataModule(CachingDataModule): * Database reference: [INDIAN-2013]_ * Original images PNG, 8-bit grayscale, 1024 x 1024 pixels - * Split reference: [INDIAN-2013]_ with 20% of train set for the validation set + * Split reference: [INDIAN-2013]_ with 20% of train set for the validation + set Data specifications: @@ -81,8 +84,8 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader( - config_variable=CONFIGURATION_KEY_DATADIR + config_variable=CONFIGURATION_KEY_DATADIR, ), database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py index 5ed7fa50e325810f3b8523a883aa5f9f8c1c301b..5465077351001eb9e61b8f4b18a98764719f930e 100644 --- a/src/mednet/config/data/montgomery/datamodule.py +++ b/src/mednet/config/data/montgomery/datamodule.py @@ -8,17 +8,16 @@ Database reference: [MONTGOMERY-SHENZHEN-2014]_ import importlib.resources import os +import pathlib import PIL.Image - from torchvision.transforms.functional import to_tensor from ....data.datamodule import CachingDataModule from ....data.image_utils import remove_black_borders from ....data.split import JSONDatabaseSplit -from ....data.typing import DatabaseSplit +from ....data.typing import DatabaseSplit, Sample from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) @@ -29,13 +28,16 @@ database.""" class RawDataLoader(_BaseRawDataLoader): """A specialized raw-data-loader for the Montgomery dataset.""" - datadir: str + datadir: pathlib.Path """This variable contains the base directory where the database raw data is stored.""" def __init__(self): - self.datadir = load_rc().get( - CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir) + self.datadir = pathlib.Path( + load_rc().get( + CONFIGURATION_KEY_DATADIR, + os.path.realpath(os.curdir), + ), ) def sample(self, sample: tuple[str, int]) -> Sample: @@ -45,16 +47,17 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- The sample representation. """ + # N.B.: Montgomery images are encoded as grayscale PNGs, so no need to # convert them again with Image.convert("L"). - image = PIL.Image.open(os.path.join(self.datadir, sample[0])) + image = PIL.Image.open(self.datadir / sample[0]) image = remove_black_borders(image) tensor = to_tensor(image) @@ -72,14 +75,15 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- int The integer label associated with the sample. """ + return sample[1] @@ -97,19 +101,23 @@ def make_split(basename: str) -> DatabaseSplit: """ return JSONDatabaseSplit( - importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + basename, + ), ) class DataModule(CachingDataModule): """Montgomery DataModule for TB detection. - The standard digital image database for Tuberculosis was created by the National - Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s - Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from + The standard digital image database for Tuberculosis was created by the + National Library of Medicine, Maryland, USA in collaboration with Shenzhen + No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China. The + Chest X-rays are from * Database reference: [MONTGOMERY-SHENZHEN-2014]_ - * Original resolution (height x width or width x height): 4020x4892 px or 4892x4020 px + * Original resolution (height x width or width x height): 4020x4892 px or + 4892x4020 px Data specifications: @@ -144,5 +152,5 @@ class DataModule(CachingDataModule): database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/config/data/montgomery_shenzhen/datamodule.py index 6df353ad70d701d8585be9b46dc4ea540adccdb4..2ca4582bfba64e09c247b29717b6c6cc9b962b2c 100644 --- a/src/mednet/config/data/montgomery_shenzhen/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen/datamodule.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Aggregated DataModule composed of Montgomery and Shenzhen databases.""" -import os +import pathlib from ....data.datamodule import ConcatDataModule from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader @@ -43,5 +43,5 @@ class DataModule(ConcatDataModule): ], }, database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py index fc9af897b9501adcf40e46a0c841d06aefec214b..9029f5aa8e623b32bc3596e3bfe2f75bec01a6e9 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py @@ -1,9 +1,11 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian databases.""" +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian +databases. +""" -import os +import pathlib from ....data.datamodule import ConcatDataModule from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR @@ -52,5 +54,5 @@ class DataModule(ConcatDataModule): ], }, database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py index bbfda89bd3d54d695f82004d79dc59c8566af21a..7d81d193e94dd6f13fd713be75caa6d496708255 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py @@ -1,9 +1,11 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets.""" +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest +datasets. +""" -import os +import pathlib from ....data.datamodule import ConcatDataModule from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR @@ -63,8 +65,8 @@ class DataModule(ConcatDataModule): }, database_name=__package__.split(".")[-1], split_name=( - os.path.splitext(split_filename)[0] + pathlib.Path(split_filename).stem + "+" - + os.path.splitext(padchest_split_filename)[0] + + pathlib.Path(padchest_split_filename).stem ), ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index f8f82c967289219ad0adac7646874d01dcffc1c3..e916462d85b24b58c651344c25e094eabef03132 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -1,9 +1,11 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets.""" +"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k +datasets. +""" -import os +import pathlib from ....data.datamodule import ConcatDataModule from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR @@ -62,8 +64,8 @@ class DataModule(ConcatDataModule): }, database_name=__package__.split(".")[-1], split_name=( - os.path.splitext(split_filename)[0] + pathlib.Path(split_filename).stem + "+" - + os.path.splitext(tbx11k_split_filename)[0] + + pathlib.Path(tbx11k_split_filename).stem ), ) diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py index 26596b74c427e712d7ef3b2b141d4508d1eae395..4dc2a1f132abb3318a2bb5137b5717951c5781f3 100644 --- a/src/mednet/config/data/nih_cxr14/datamodule.py +++ b/src/mednet/config/data/nih_cxr14/datamodule.py @@ -8,25 +8,25 @@ Database reference: [NIH-CXR14-2017]_ import importlib.resources import os +import pathlib import PIL.Image - from torchvision.transforms.functional import to_tensor from ....data.datamodule import CachingDataModule from ....data.split import JSONDatabaseSplit -from ....data.typing import DatabaseSplit +from ....data.typing import DatabaseSplit, Sample from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) """Key to search for in the configuration file for the root directory of this -database.""" +database. +""" CONFIGURATION_KEY_IDIAP_FILESTRUCTURE = ( - __name__.rsplit(".", 2)[-2] -) + ".idiap_folder_structure" + (__name__.rsplit(".", 2)[-2]) + ".idiap_folder_structure" +) """Key to search for in the configuration file indicating if the loader should use standard or idiap-based file organisation structure. @@ -39,7 +39,7 @@ different folder structure, that was adapted to Idiap's requirements class RawDataLoader(_BaseRawDataLoader): """A specialized raw-data-loader for the NIH CXR-14 dataset.""" - datadir: str + datadir: pathlib.Path """This variable contains the base directory where the database raw data is stored.""" @@ -55,11 +55,12 @@ class RawDataLoader(_BaseRawDataLoader): def __init__(self): rc = load_rc() - self.datadir = rc.get( - CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir) + self.datadir = pathlib.Path( + rc.get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)), ) self.idiap_file_organisation = rc.get( - CONFIGURATION_KEY_IDIAP_FILESTRUCTURE, False + CONFIGURATION_KEY_IDIAP_FILESTRUCTURE, + False, ) def sample(self, sample: tuple[str, list[int]]) -> Sample: @@ -69,28 +70,24 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- The sample representation. """ - file_path = sample[0] # default + + file_path = pathlib.Path(sample[0]) # default if self.idiap_file_organisation: # for folder lookup efficiency, data is split into subfolders # each original file is on the subfolder `f[:5]/f`, where f # is the original file basename - basename = os.path.basename(sample[0]) - file_path = os.path.join( - os.path.dirname(sample[0]), - basename[:5], - basename, - ) + file_path = file_path.parent / file_path.name[:5] / file_path.name # N.B.: some NIH CXR-14 images are encoded as color PNGs with an alpha # channel. Most, are grayscale PNGs - image = PIL.Image.open(os.path.join(self.datadir, file_path)) + image = PIL.Image.open(self.datadir / file_path) image = image.convert("L") # required for some images tensor = to_tensor(image) @@ -116,6 +113,7 @@ class RawDataLoader(_BaseRawDataLoader): list[int] The integer labels associated with the sample. """ + return sample[1] @@ -133,7 +131,9 @@ def make_split(basename: str) -> DatabaseSplit: """ return JSONDatabaseSplit( - importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + basename, + ), ) @@ -165,7 +165,8 @@ class DataModule(CachingDataModule): * Final specifications: - * RGB, encoded as a 3-plane tensor, 32-bit floats, square (1024x1024 px) + * RGB, encoded as a 3-plane tensor, 32-bit floats, square + (1024x1024 px) * Labels in order: * cardiomegaly * emphysema @@ -193,5 +194,5 @@ class DataModule(CachingDataModule): database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py index 6cc383405093db4adca141ad6001bab9572fefa9..f4d13d5b3bfcd3e331e46a8d1cdaf86d0ec266b4 100644 --- a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py +++ b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Aggregated DataModule composed of NIH-CXR-14 and PadChest databases.""" -import os +import pathlib from ....data.datamodule import ConcatDataModule from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader @@ -48,8 +48,8 @@ class DataModule(ConcatDataModule): }, database_name=__package__.split(".")[-1], split_name=( - os.path.splitext(cxr14_split_filename)[0] + pathlib.Path(cxr14_split_filename).stem + "+" - + os.path.splitext(padchest_split_filename)[0] + + pathlib.Path(padchest_split_filename).stem ), ) diff --git a/src/mednet/config/data/nih_cxr14_padchest/idiap.py b/src/mednet/config/data/nih_cxr14_padchest/idiap.py index 6ac62f99a2097a4add1349e47652c65e9eb8a913..7ba5b9bca91784f8b26a986df92220d9f998022d 100644 --- a/src/mednet/config/data/nih_cxr14_padchest/idiap.py +++ b/src/mednet/config/data/nih_cxr14_padchest/idiap.py @@ -1,7 +1,9 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized) datasets (no-tb-idiap split).""" +"""Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized) datasets +(no-tb-idiap split). +""" from mednet.config.data.nih_cxr14_padchest.datamodule import DataModule diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py index d146fc0357659134151084d765945b96a8f8a305..6790dcbfb1511416659578c6a76e49e7c1a750cc 100644 --- a/src/mednet/config/data/padchest/datamodule.py +++ b/src/mednet/config/data/padchest/datamodule.py @@ -8,18 +8,17 @@ Database reference: [PADCHEST-2019]_ import importlib.resources import os +import pathlib import numpy import PIL.Image - from torchvision.transforms.functional import to_tensor from ....data.datamodule import CachingDataModule from ....data.image_utils import remove_black_borders from ....data.split import JSONDatabaseSplit -from ....data.typing import DatabaseSplit +from ....data.typing import DatabaseSplit, Sample from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) @@ -30,14 +29,16 @@ database.""" class RawDataLoader(_BaseRawDataLoader): """A specialized raw-data-loader for the PadChest dataset.""" - datadir: str + datadir: pathlib.Path """This variable contains the base directory where the database raw data is stored.""" def __init__(self): - rc = load_rc() - self.datadir = rc.get( - CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir) + self.datadir = pathlib.Path( + load_rc().get( + CONFIGURATION_KEY_DATADIR, + os.path.realpath(os.curdir), + ), ) def sample(self, sample: tuple[str, int | list[int]]) -> Sample: @@ -47,15 +48,16 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- The sample representation. """ + # N.B.: PadChest images are encoded as 16-bit grayscale images - image = PIL.Image.open(os.path.join(self.datadir, sample[0])) + image = PIL.Image.open(self.datadir / sample[0]) image = remove_black_borders(image) array = numpy.array(image).astype(numpy.float32) / 65535 tensor = to_tensor(array) @@ -74,14 +76,15 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- list[int] The integer labels associated with the sample. """ + return sample[1] @@ -99,7 +102,9 @@ def make_split(basename: str) -> DatabaseSplit: """ return JSONDatabaseSplit( - importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + basename, + ), ) @@ -342,5 +347,5 @@ class DataModule(CachingDataModule): database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py index 81e48f9b70dc887a97a312a0d8062afc809682e2..f2ced7295f59e13da787b4fe0ccb0a3f09b7a338 100644 --- a/src/mednet/config/data/shenzhen/datamodule.py +++ b/src/mednet/config/data/shenzhen/datamodule.py @@ -8,17 +8,16 @@ Database reference: [MONTGOMERY-SHENZHEN-2014]_ import importlib.resources import os +import pathlib import PIL.Image - from torchvision.transforms.functional import to_tensor from ....data.datamodule import CachingDataModule from ....data.image_utils import remove_black_borders from ....data.split import JSONDatabaseSplit -from ....data.typing import DatabaseSplit +from ....data.typing import DatabaseSplit, Sample from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) @@ -32,19 +31,19 @@ class RawDataLoader(_BaseRawDataLoader): Parameters ---------- config_variable - Key to search for in the configuration file for the root directory of this - database. + Key to search for in the configuration file for the root directory of + this database. """ - datadir: str + datadir: pathlib.Path """This variable contains the base directory where the database raw data is stored.""" # config_variable: required so this loader can be used for the Indian # database as well. def __init__(self, config_variable: str = CONFIGURATION_KEY_DATADIR): - self.datadir = load_rc().get( - config_variable, os.path.realpath(os.curdir) + self.datadir = pathlib.Path( + load_rc().get(config_variable, os.path.realpath(os.curdir)), ) def sample(self, sample: tuple[str, int]) -> Sample: @@ -54,18 +53,17 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- The sample representation. """ + # N.B.: Image.convert("L") is required to normalize grayscale back to # normal (instead of inverted). - image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert( - "L" - ) + image = PIL.Image.open(self.datadir / sample[0]).convert("L") image = remove_black_borders(image) tensor = to_tensor(image) @@ -83,14 +81,15 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- int The integer label associated with the sample. """ + return sample[1] @@ -108,18 +107,20 @@ def make_split(basename: str) -> DatabaseSplit: """ return JSONDatabaseSplit( - importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + basename, + ), ) class DataModule(CachingDataModule): """Shenzhen DataModule for computer-aided diagnosis. - The standard digital image database for Tuberculosis was created by the National - Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s - Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from - out-patient clinics, and were captured as part of the daily routine using - Philips DR Digital Diagnose systems. + The standard digital image database for Tuberculosis was created by the + National Library of Medicine, Maryland, USA in collaboration with Shenzhen + No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China. The + Chest X-rays are from out-patient clinics, and were captured as part of the + daily routine using Philips DR Digital Diagnose systems. * Database reference: [MONTGOMERY-SHENZHEN-2014]_ @@ -156,5 +157,5 @@ class DataModule(CachingDataModule): database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py index 14b09e7f23e7c72779f43f7f9bb8503ca9954414..6322248c423747b3a9257f9ceea3c327e77ef6b1 100644 --- a/src/mednet/config/data/tbpoc/datamodule.py +++ b/src/mednet/config/data/tbpoc/datamodule.py @@ -4,17 +4,16 @@ import importlib.resources import os +import pathlib import PIL.Image - from torchvision.transforms.functional import to_tensor from ....data.datamodule import CachingDataModule from ....data.image_utils import remove_black_borders from ....data.split import JSONDatabaseSplit -from ....data.typing import DatabaseSplit +from ....data.typing import DatabaseSplit, Sample from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) @@ -25,13 +24,16 @@ database.""" class RawDataLoader(_BaseRawDataLoader): """A specialized raw-data-loader for the Shenzen dataset.""" - datadir: str + datadir: pathlib.Path """This variable contains the base directory where the database raw data is stored.""" def __init__(self): - self.datadir = load_rc().get( - CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir) + self.datadir = pathlib.Path( + load_rc().get( + CONFIGURATION_KEY_DATADIR, + os.path.realpath(os.curdir), + ), ) def sample(self, sample: tuple[str, int]) -> Sample: @@ -41,16 +43,17 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- The sample representation. """ + # images from TBPOC are encoded as grayscale JPEGs, no need to # call convert("L") here. - image = PIL.Image.open(os.path.join(self.datadir, sample[0])) + image = PIL.Image.open(self.datadir / sample[0]) image = remove_black_borders(image) tensor = to_tensor(image) @@ -68,13 +71,14 @@ class RawDataLoader(_BaseRawDataLoader): ---------- sample A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + where to find the image to be loaded, and an integer, representing + the sample label. Returns ------- The integer label associated with the sample """ + return sample[1] @@ -92,7 +96,9 @@ def make_split(basename: str) -> DatabaseSplit: """ return JSONDatabaseSplit( - importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + basename, + ), ) @@ -137,5 +143,5 @@ class DataModule(CachingDataModule): database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py index 1735607ee72cb7f1298bca1e1adca57bb5d41d6a..73e390b75fd4346370eed59a858a41c6cc178e41 100644 --- a/src/mednet/config/data/tbx11k/datamodule.py +++ b/src/mednet/config/data/tbx11k/datamodule.py @@ -6,19 +6,18 @@ import collections.abc import dataclasses import importlib.resources import os +import pathlib import typing import PIL.Image import typing_extensions - from torch.utils.data._utils.collate import default_collate_fn_map from torchvision.transforms.functional import to_tensor from ....data.datamodule import CachingDataModule from ....data.split import JSONDatabaseSplit -from ....data.typing import DatabaseSplit +from ....data.typing import DatabaseSplit, Sample from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) @@ -55,14 +54,29 @@ class BoundingBox: int The area in square-pixels. """ + return self.width * self.height @property def xmax(self) -> int: + """Return ``x`` coordinate of rightmost side of bounding box. + + Returns + ------- + int + The rightmost side of the bounding box. + """ return self.xmin + self.width - 1 @property def ymax(self) -> int: + """Return ``y`` coordinate of the lower side of the box. + + Returns + ------- + int + The ``y`` coordinate of the lower side of the box. + """ return self.ymin + self.height - 1 def intersection(self, other: typing_extensions.Self) -> int: @@ -87,6 +101,7 @@ class BoundingBox: The area intersection between this and the other bounding-box in square pixels. """ + dx = min(self.xmax, other.xmax) - max(self.xmin, other.xmin) + 1 dy = min(self.ymax, other.ymax) - max(self.ymin, other.ymin) + 1 @@ -119,9 +134,14 @@ class BoundingBoxes(collections.abc.Sequence[BoundingBox]): # explained at: # https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate def _collate_boundingboxes_fn( - batch, *, collate_fn_map=None + batch, + *, + collate_fn_map=None, ): # numpydoc ignore=PR01 - """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes objects. + """Collate samples that includes bounding boxes. + + Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes + objects. Returns ------- @@ -151,13 +171,16 @@ finding locations, as described above. class RawDataLoader(_BaseRawDataLoader): """A specialized raw-data-loader for the TBX11k dataset.""" - datadir: str + datadir: pathlib.Path """This variable contains the base directory where the database raw data is stored.""" def __init__(self): - self.datadir = load_rc().get( - CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir) + self.datadir = pathlib.Path( + load_rc().get( + CONFIGURATION_KEY_DATADIR, + os.path.realpath(os.curdir), + ), ) def sample(self, sample: DatabaseSample) -> Sample: @@ -175,7 +198,8 @@ class RawDataLoader(_BaseRawDataLoader): ------- The sample representation. """ - image = PIL.Image.open(os.path.join(self.datadir, sample[0])) + + image = PIL.Image.open(self.datadir / sample[0]) tensor = to_tensor(image) # use the code below to view generated images @@ -205,6 +229,7 @@ class RawDataLoader(_BaseRawDataLoader): int The integer label associated with the sample. """ + return sample[1] def bounding_boxes(self, sample: DatabaseSample) -> BoundingBoxes: @@ -223,6 +248,7 @@ class RawDataLoader(_BaseRawDataLoader): BoundingBoxes Bounding box annotations, if any available with the sample. """ + if len(sample) > 2: return BoundingBoxes([BoundingBox(*k) for k in sample[2]]) # type: ignore @@ -243,7 +269,9 @@ def make_split(basename: str) -> DatabaseSplit: """ return JSONDatabaseSplit( - importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + basename, + ), ) @@ -356,5 +384,5 @@ class DataModule(CachingDataModule): database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), database_name=__package__.split(".")[-1], - split_name=os.path.splitext(split_filename)[0], + split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/data/augmentations.py b/src/mednet/data/augmentations.py index 8dbb4520c91c3338c7d3c9bb5e60c84b91f1a709..2ed9fa1fce0ee0bdb90a426bbf894c62e219ead8 100644 --- a/src/mednet/data/augmentations.py +++ b/src/mednet/data/augmentations.py @@ -21,7 +21,6 @@ import typing import numpy.random import numpy.typing import torch - from scipy.ndimage import gaussian_filter, map_coordinates logger = logging.getLogger(__name__) @@ -116,7 +115,7 @@ def _elastic_deformation_on_image( indices, order=spline_order, mode=mode, - ).reshape(img_shape) + ).reshape(img_shape), ) # wraps numpy array as tensor, move to destination device if need-be @@ -166,6 +165,7 @@ def _elastic_deformation_on_batch( tensor A batch of images with elastic deformation applied, as a tensor on the CPU. """ + # transforms our custom functions into simpler callables partial = functools.partial( _elastic_deformation_on_image, @@ -257,6 +257,7 @@ class ElasticDeformation: int The multiprocessing type. """ + return self._parallel @parallel.setter @@ -266,7 +267,7 @@ class ElasticDeformation: if value >= 0: instances = value or multiprocessing.cpu_count() logger.info( - f"Applying data-augmentation using {instances} processes..." + f"Applying data-augmentation using {instances} processes...", ) self._mp_pool = multiprocessing.get_context("spawn").Pool(instances) else: @@ -278,7 +279,7 @@ class ElasticDeformation: # auto-tunning on first batch instances = min(img.shape[0], multiprocessing.cpu_count()) self._mp_pool = multiprocessing.get_context("spawn").Pool( - instances + instances, ) return _elastic_deformation_on_batch( @@ -291,7 +292,7 @@ class ElasticDeformation: self._mp_pool, ).to(img.device) - elif len(img.shape) == 3: + if len(img.shape) == 3: return _elastic_deformation_on_image( img.cpu(), self.alpha, @@ -304,5 +305,5 @@ class ElasticDeformation: raise RuntimeError( f"This transform accepts only images with 3 dimensions," f"or batches of images with 4 dimensions. However, I got " - f"an image with {img.ndim} dimensions." + f"an image with {img.ndim} dimensions.", ) diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 71fead5a2d275b2aa33f28dcfc9c894ac687a2b2..c73bb27932c52e2061134a7f007256c7f6292161 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -58,6 +58,7 @@ def _sample_size_bytes(s: Sample) -> int: int The size of the Tensor in bytes. """ + return int(t.element_size() * torch.prod(torch.tensor(t.shape))) size = sys.getsizeof(s[0]) # tensor metadata @@ -105,7 +106,7 @@ class _DelayedLoadingDataset(Dataset): first_sample = self[0] logger.info( f"Delayed loading dataset (first tensor): " - f"{list(first_sample[0].shape)}@{first_sample[0].dtype}" + f"{list(first_sample[0].shape)}@{first_sample[0].dtype}", ) sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0) logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") @@ -118,6 +119,7 @@ class _DelayedLoadingDataset(Dataset): list[int | list[int]] The integer labels for all samples in the dataset. """ + return [self.loader.label(k) for k in self.raw_dataset] def __getitem__(self, key: int) -> Sample: @@ -156,6 +158,7 @@ def _apply_loader_and_transforms( Sample The loaded and transformed sample. """ + sample = load(info) return model_transform(sample[0]), sample[1] @@ -208,19 +211,20 @@ class _CachedDataset(Dataset): with multiprocessing.Pool(instances) as p: self.data = list( tqdm.tqdm( - p.imap(self.loader, raw_dataset), total=len(raw_dataset) - ) + p.imap(self.loader, raw_dataset), + total=len(raw_dataset), + ), ) # Estimates memory occupance logger.info( f"Cached dataset (first tensor): " - f"{list(self.data[0][0].shape)}@{self.data[0][0].dtype}" + f"{list(self.data[0][0].shape)}@{self.data[0][0].dtype}", ) sample_size_mb = _sample_size_bytes(self.data[0]) / (1024.0 * 1024.0) logger.info( f"Estimated RAM occupance (sample / dataset): " - f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb" + f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb", ) def labels(self) -> list[int | list[int]]: @@ -231,6 +235,7 @@ class _CachedDataset(Dataset): list[int | list[int]] The integer labels for all samples in the dataset. """ + return [k[1]["label"] for k in self.data] def __getitem__(self, key: int) -> Sample: @@ -269,6 +274,7 @@ class _ConcatDataset(Dataset): list[int | list[int]] The integer labels for all samples in the dataset. """ + return list(itertools.chain(*[k.labels() for k in self._datasets])) def __getitem__(self, key: int) -> Sample: @@ -370,12 +376,13 @@ def _make_balanced_random_sampler( # There are two possible cases: targets/no-targets metadata_example = dataset.datasets[0][0][1] if target in metadata_example and isinstance( - metadata_example[target], int + metadata_example[target], + int, ): # there are integer targets, let's balance with those logger.info( f"Balancing sample selection probabilities **and** " - f"concatenated-datasets using metadata targets `{target}`" + f"concatenated-datasets using metadata targets `{target}`", ) targets = [ k @@ -386,7 +393,7 @@ def _make_balanced_random_sampler( else: logger.warning( f"Balancing samples **and** concatenated-datasets " - f"by using dataset totals as `{target}: int` is not true" + f"by using dataset totals as `{target}: int` is not true", ) weights = [ k @@ -400,21 +407,24 @@ def _make_balanced_random_sampler( else: metadata_example = dataset[0][1] if target in metadata_example and isinstance( - metadata_example[target], int + metadata_example[target], + int, ): logger.info( f"Balancing samples from dataset using metadata " - f"targets `{target}`" + f"targets `{target}`", ) weights = _calculate_weights(dataset.labels()) # type: ignore else: raise RuntimeError( f"Cannot balance samples with multiple class labels " - f"({target}: list[int]) or without metadata targets `{target}`" + f"({target}: list[int]) or without metadata targets `{target}`", ) return torch.utils.data.WeightedRandomSampler( - weights, len(weights), replacement=True + weights, + len(weights), + replacement=True, ) @@ -537,7 +547,7 @@ class ConcatDataModule(lightning.LightningDataModule): count = sum([len(k) for k, _ in split_loaders]) logger.info( f"Dataset `{dataset_name}` (`{database_name}`/`{split_name}`) " - f"contains {count} samples" + f"contains {count} samples", ) self.cache_samples = cache_samples @@ -573,7 +583,7 @@ class ConcatDataModule(lightning.LightningDataModule): The mapping between the command-line interface ``parallel`` setting works like this: - .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation + .. list-table:: Relationship between ``parallel`` and DataLoader parameters :widths: 15 15 70 :header-rows: 1 @@ -614,9 +624,9 @@ class ConcatDataModule(lightning.LightningDataModule): self._dataloader_multiproc["num_workers"] = num_workers if num_workers > 0 and sys.platform == "darwin": - self._dataloader_multiproc[ - "multiprocessing_context" - ] = multiprocessing.get_context("spawn") + self._dataloader_multiproc["multiprocessing_context"] = ( + multiprocessing.get_context("spawn") + ) # keep workers hanging around if we have multiple if value >= 0: @@ -653,7 +663,7 @@ class ConcatDataModule(lightning.LightningDataModule): logger.warning( f"Resetting {len(self._datasets)} loaded datasets due " "to changes in model-transform properties. If you were caching " - "data loading, this will (eventually) trigger a reload." + "data loading, this will (eventually) trigger a reload.", ) self._datasets = {} @@ -686,7 +696,7 @@ class ConcatDataModule(lightning.LightningDataModule): if "train" not in self._datasets: self._setup_dataset("train") self._train_sampler = _make_balanced_random_sampler( - self._datasets["train"] + self._datasets["train"], ) else: self._train_sampler = None @@ -720,12 +730,11 @@ class ConcatDataModule(lightning.LightningDataModule): batch-chunk-count pieces, and gradients are accumulated to complete each batch. """ - # validation if batch_size % batch_chunk_count != 0: raise RuntimeError( f"batch_size ({batch_size}) must be divisible by " - f"batch_chunk_size ({batch_chunk_count})." + f"batch_chunk_size ({batch_chunk_count}).", ) self._batch_size = batch_size @@ -733,7 +742,7 @@ class ConcatDataModule(lightning.LightningDataModule): self._chunk_size = self._batch_size // self._batch_chunk_count def _setup_dataset(self, name: str) -> None: - """Set-up a single dataset from the input data split. + """Set up a single dataset from the input data split. Parameters ---------- @@ -745,13 +754,13 @@ class ConcatDataModule(lightning.LightningDataModule): raise RuntimeError( "Parameter `model_transforms` has not yet been " "set. If you do not have model transforms, then " - "set it to an empty list." + "set it to an empty list.", ) if name in self._datasets: logger.info( f"Dataset `{name}` is already setup. " - f"Not re-instantiating it." + f"Not re-instantiating it.", ) return @@ -759,7 +768,7 @@ class ConcatDataModule(lightning.LightningDataModule): if self.cache_samples: logger.info( f"Loading dataset:`{name}` into memory (caching)." - f" Trade-off: CPU RAM usage: more | Disk I/O: less" + f" Trade-off: CPU RAM usage: more | Disk I/O: less", ) for split, loader in self.splits[name]: datasets.append( @@ -768,12 +777,12 @@ class ConcatDataModule(lightning.LightningDataModule): loader, self.parallel, self.model_transforms, - ) + ), ) else: logger.info( f"Loading dataset:`{name}` without caching." - f" Trade-off: CPU RAM usage: less | Disk I/O: more" + f" Trade-off: CPU RAM usage: less | Disk I/O: more", ) for split, loader in self.splits[name]: datasets.append( @@ -781,7 +790,7 @@ class ConcatDataModule(lightning.LightningDataModule): split, loader, self.model_transforms, - ) + ), ) if len(datasets) == 1: @@ -797,6 +806,7 @@ class ConcatDataModule(lightning.LightningDataModule): list[str] The list of validation dataset names. """ + return ["validation"] + [ k for k in self.splits.keys() if k.startswith("monitor-") ] @@ -861,6 +871,7 @@ class ConcatDataModule(lightning.LightningDataModule): * ``test``: uses only the test dataset * ``predict``: uses only the test dataset """ + super().teardown(stage) def train_dataloader(self) -> DataLoader: @@ -915,7 +926,8 @@ class ConcatDataModule(lightning.LightningDataModule): return { k: torch.utils.data.DataLoader( - self._datasets[k], **validation_loader_opts + self._datasets[k], + **validation_loader_opts, ) for k in self._val_dataset_keys() } @@ -936,7 +948,7 @@ class ConcatDataModule(lightning.LightningDataModule): drop_last=self.drop_incomplete_batch, pin_memory=self.pin_memory, **self._dataloader_multiproc, - ) + ), ) def predict_dataloader(self) -> dict[str, DataLoader]: diff --git a/src/mednet/data/image_utils.py b/src/mednet/data/image_utils.py index c0e0ff6a6e9c64b994c54f9e0842cfb1af05df8a..502244a8eea7084c1e1991e29828fc027b8ca622 100644 --- a/src/mednet/data/image_utils.py +++ b/src/mednet/data/image_utils.py @@ -8,17 +8,18 @@ import PIL.Image def remove_black_borders( - img: PIL.Image.Image, threshold: int = 0 + img: PIL.Image.Image, + threshold: int = 0, ) -> PIL.Image.Image: """Remove black borders of CXR. Parameters ---------- - img - A PIL image. - threshold - Threshold value from which borders are considered black. - Defaults to 0. + img + A PIL image. + threshold + Threshold value from which borders are considered black. + Defaults to 0. Returns ------- @@ -30,18 +31,17 @@ def remove_black_borders( if len(img_array.shape) == 2: # single channel mask = numpy.asarray(img_array) > threshold return PIL.Image.fromarray( - img_array[numpy.ix_(mask.any(1), mask.any(0))] + img_array[numpy.ix_(mask.any(1), mask.any(0))], ) - elif len(img_array.shape) == 3 and img_array.shape[2] == 3: + if len(img_array.shape) == 3 and img_array.shape[2] == 3: r_mask = img_array[:, :, 0] > threshold g_mask = img_array[:, :, 1] > threshold b_mask = img_array[:, :, 2] > threshold mask = r_mask | g_mask | b_mask return PIL.Image.fromarray( - img_array[numpy.ix_(mask.any(1), mask.any(0))] + img_array[numpy.ix_(mask.any(1), mask.any(0))], ) - else: - raise NotImplementedError + raise NotImplementedError diff --git a/src/mednet/data/split.py b/src/mednet/data/split.py index 5d70aafd551b18fca856f846d91818442bb63c03..e50384c1a91fae7bc1ea6d5782f6866be35e4383 100644 --- a/src/mednet/data/split.py +++ b/src/mednet/data/split.py @@ -101,7 +101,6 @@ class JSONDatabaseSplit(DatabaseSplit): return iter(self._datasets) def __len__(self) -> int: - """The number of datasets we currently have.""" return len(self._datasets) @@ -138,7 +137,8 @@ class CSVDatabaseSplit(DatabaseSplit): """ def __init__( - self, directory: pathlib.Path | str | importlib.abc.Traversable + self, + directory: pathlib.Path | str | importlib.abc.Traversable, ): if isinstance(directory, str): directory = pathlib.Path(directory) @@ -174,12 +174,12 @@ class CSVDatabaseSplit(DatabaseSplit): retval[dataset.name[: -len(".csv")]] = [k for k in reader] else: logger.debug( - f"Ignoring file {dataset} in CSVDatabaseSplit readout" + f"Ignoring file {dataset} in CSVDatabaseSplit readout", ) return retval def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: - """Accesse dataset ``key`` from this split.""" + """Access dataset ``key`` from this split.""" return self._datasets[key] def __iter__(self): @@ -187,7 +187,7 @@ class CSVDatabaseSplit(DatabaseSplit): return iter(self._datasets) def __len__(self) -> int: - """The number of datasets we currently have.""" + """Return number of datasets we currently have.""" return len(self._datasets) @@ -220,8 +220,9 @@ def check_database_split_loading( int Number of errors found. """ + logger.info( - "Checking if all samples in all datasets of this split can be loaded..." + "Checking if all samples in all datasets of this split can be loaded...", ) errors = 0 for dataset, samples in database_split.items(): @@ -232,7 +233,7 @@ def check_database_split_loading( assert isinstance(data, torch.Tensor) except Exception as e: logger.info( - f"Found error loading entry {pos} in dataset `{dataset}`: {e}" + f"Found error loading entry {pos} in dataset `{dataset}`: {e}", ) errors += 1 return errors diff --git a/src/mednet/data/typing.py b/src/mednet/data/typing.py index e5b68d982af5ae421ce94b442dbbfa5617843d4f..93bbfc970b50549ae03a125ef89b785e62a161d0 100644 --- a/src/mednet/data/typing.py +++ b/src/mednet/data/typing.py @@ -32,6 +32,7 @@ class RawDataLoader: _ Information about the sample to load. Implementation dependent. """ + raise NotImplementedError("You must implement the `sample()` method") def label(self, k: typing.Any) -> int | list[int]: @@ -51,6 +52,7 @@ class RawDataLoader: int | list[int] The label corresponding to the specified sample. """ + return self.sample(k)[1]["label"] @@ -64,7 +66,8 @@ TransformSequence: typing.TypeAlias = typing.Sequence[Transform] """A sequence of transforms.""" DatabaseSplit: typing.TypeAlias = collections.abc.Mapping[ - str, typing.Sequence[typing.Any] + str, + typing.Sequence[typing.Any], ] """The definition of a database split. diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py index fd0da8843b5f32baf407d098eb9751dcaf1b4080..d9e757bf36664c9c2b10295ecd728e219cbee611 100644 --- a/src/mednet/engine/callbacks.py +++ b/src/mednet/engine/callbacks.py @@ -61,7 +61,7 @@ class LoggingCallback(lightning.pytorch.Callback): trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ): - """Callback to be executed **before** the whole training starts. + """Execute actions when training starts (lightning callback). This method is executed whenever you *start* training a module. @@ -79,9 +79,9 @@ class LoggingCallback(lightning.pytorch.Callback): trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ) -> None: - """Callback to be executed **before** every training batch starts. + """Execute actions when a training epoch starts (lightning callback). - This method is executed whenever a training batch starts. Presumably, + This method is executed whenever a training epoch starts. Presumably, batches happen as often as possible. You want to make this code very fast. Do not log things to the terminal or the such, or do complicated (lengthy) calculations. @@ -98,6 +98,7 @@ class LoggingCallback(lightning.pytorch.Callback): pl_module The lightning module that is being trained. """ + # summarizes resource usage since the last checkpoint # clears internal buffers and starts accumulating again. self._train_resource_monitor.checkpoint(remove_last_n=-1) @@ -109,7 +110,7 @@ class LoggingCallback(lightning.pytorch.Callback): trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ): - """Callback to be executed **after** every training epoch ends. + """Execute actions after a training epoch ends (lightning callback). This method is executed whenever a training epoch ends. Presumably, epochs happen as often as possible. You want to make this code @@ -137,7 +138,7 @@ class LoggingCallback(lightning.pytorch.Callback): logger.warning( "Unable to fetch monitoring information from " "resource monitor. CPU/GPU utilisation will be " - "missing." + "missing.", ) overall_cycle_time = time.time() - self._start_training_epoch_time @@ -166,7 +167,7 @@ class LoggingCallback(lightning.pytorch.Callback): batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]], batch_idx: int, ) -> None: - """Callback to be executed **after** every training batch ends. + """Execute actions after a training batch ends (lightning callback). This method is executed whenever a training batch ends. Presumably, batches happen as often as possible. You want to make this code very @@ -191,6 +192,7 @@ class LoggingCallback(lightning.pytorch.Callback): batch_idx The relative number of the batch. """ + pl_module.log( "loss/train", outputs["loss"].item(), @@ -205,7 +207,7 @@ class LoggingCallback(lightning.pytorch.Callback): trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ) -> None: - """Callback to be executed **before** every validation batch starts. + """Execute actions before a validation batch starts (lightning callback). This method is executed whenever a validation batch starts. Presumably, batches happen as often as possible. You want to make this code very @@ -224,6 +226,7 @@ class LoggingCallback(lightning.pytorch.Callback): pl_module The lightning module that is being trained. """ + # required because the validation epoch is started **within** the # training epoch START/END. # @@ -239,7 +242,7 @@ class LoggingCallback(lightning.pytorch.Callback): trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ) -> None: - """Callback to be executed **after** every validation epoch ends. + """Execute actions after a validation batch ends (lightning callback). This method is executed whenever a validation epoch ends. Presumably, epochs happen as often as possible. You want to make this code @@ -268,7 +271,7 @@ class LoggingCallback(lightning.pytorch.Callback): logger.warning( "Unable to fetch monitoring information from " "resource monitor. CPU/GPU utilisation will be " - "missing." + "missing.", ) self._to_log["step"] = float(trainer.current_epoch) @@ -287,7 +290,7 @@ class LoggingCallback(lightning.pytorch.Callback): batch_idx: int, dataloader_idx: int = 0, ) -> None: - """Callback to be executed **after** every validation batch ends. + """Execute actions after a validation after ends (lightning callback). This method is executed whenever a validation batch ends. Presumably, batches happen as often as possible. You want to make this code very diff --git a/src/mednet/engine/device.py b/src/mednet/engine/device.py index b008c4cbe952c33bf4ace294426e1362ff15b205..c2467c416c8fe21a733cf4066abe56c91937a6f3 100644 --- a/src/mednet/engine/device.py +++ b/src/mednet/engine/device.py @@ -21,7 +21,8 @@ SupportedPytorchDevice: typing.TypeAlias = typing.Literal[ def _split_int_list(s: str) -> list[int]: - """Split a list of integers encoded in a string (e.g. "1,2,3") into a Python list of integers (e.g. ``[1, 2, 3]``). + """Split a list of integers encoded in a string (e.g. "1,2,3") into a + Python list of integers (e.g. ``[1, 2, 3]``). Parameters ---------- @@ -33,15 +34,15 @@ def _split_int_list(s: str) -> list[int]: list[int] A Python list of integers. """ + return [int(k.strip()) for k in s.split(",")] class DeviceManager: - """This class is used to manage the Lightning Accelerator and Pytorch - Devices. + r"""Manage Lightning Accelerator and Pytorch Devices. It takes the user input, in the form of a string defined by - ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``), and can + ``[\S+][:\d[,\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``), and can translate to the right incarnation of Pytorch devices or Lightning Accelerators to interface with the various frameworks. @@ -52,7 +53,7 @@ class DeviceManager: ---------- name The name of the device to use, in the form of a string defined by - ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``). In + ``[\S+][:\d[,\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``). In the specific case of ``cuda``, one can also specify a device to use either by adding ``:N``, where N is the zero-indexed board number on the computer, or by setting the environment variable @@ -67,7 +68,8 @@ class DeviceManager: if parts[0] not in typing.get_args(SupportedPytorchDevice): raise ValueError(f"Unsupported device-type `{parts[0]}`") self.device_type: SupportedPytorchDevice = typing.cast( - SupportedPytorchDevice, parts[0] + SupportedPytorchDevice, + parts[0], ) self.device_ids: list[int] = [] @@ -81,7 +83,7 @@ class DeviceManager: if self.device_ids and visible != self.device_ids: logger.warning( f"${{CUDA_VISIBLE_DEVICES}}={visible} and name={name} " - f"- overriding environment with value set on `name`" + f"- overriding environment with value set on `name`", ) else: self.device_ids = visible @@ -89,20 +91,20 @@ class DeviceManager: # make sure that it is consistent with the environment if self.device_ids: os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( - [str(k) for k in self.device_ids] + [str(k) for k in self.device_ids], ) if self.device_type not in typing.get_args(SupportedPytorchDevice): raise RuntimeError( f"Unsupported device type `{self.device_type}`. " f"Supported devices types are " - f"`{', '.join(typing.get_args(SupportedPytorchDevice))}`" + f"`{', '.join(typing.get_args(SupportedPytorchDevice))}`", ) if self.device_ids and self.device_type in ("cpu", "mps"): logger.warning( f"Cannot pin device ids if using cpu or mps backend. " - f"Setting `name` to {name} is non-sensical. Ignoring..." + f"Setting `name` to {name} is non-sensical. Ignoring...", ) # check if the device_type that was set has support compiled in @@ -138,15 +140,16 @@ class DeviceManager: if self.device_type in ("cpu", "mps"): return torch.device(self.device_type) - elif self.device_type == "cuda": + + if self.device_type == "cuda": if not self.device_ids: return torch.device(self.device_type) - else: - return torch.device(self.device_type, self.device_ids[0]) + + return torch.device(self.device_type, self.device_ids[0]) # if you get to this point, this is an unexpected RuntimeError raise RuntimeError( - f"Unexpected device type {self.device_type} lacks support" + f"Unexpected device type {self.device_type} lacks support", ) def lightning_accelerator(self) -> tuple[str, int | list[int] | str]: diff --git a/src/mednet/engine/evaluator.py b/src/mednet/engine/evaluator.py index bfd32cf94b5fb10a9fec000bc78711509211c155..42b1308a6714336b9b4f01d392d4500317342043 100644 --- a/src/mednet/engine/evaluator.py +++ b/src/mednet/engine/evaluator.py @@ -8,7 +8,6 @@ import itertools import json import logging import typing - from collections.abc import Iterable, Iterator import credible.curves @@ -17,7 +16,6 @@ import numpy import numpy.typing import sklearn.metrics import tabulate - from matplotlib import pyplot as plt from ..models.typing import BinaryPrediction @@ -39,6 +37,7 @@ def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float: float The EER threshold value. """ + from scipy.interpolate import interp1d from scipy.optimize import brentq @@ -52,7 +51,8 @@ def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float: def _get_centered_maxf1( - f1_scores: numpy.typing.NDArray, thresholds: numpy.typing.NDArray + f1_scores: numpy.typing.NDArray, + thresholds: numpy.typing.NDArray, ) -> tuple[float, float]: """Return the centered max F1 score threshold when multiple thresholds give the same max F1 score. @@ -69,6 +69,7 @@ def _get_centered_maxf1( tuple(float, float) A tuple with the maximum F1-score and the "centered" threshold. """ + maxf1 = f1_scores.max() maxf1_indices = numpy.where(f1_scores == maxf1)[0] @@ -97,17 +98,22 @@ def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float: The threshold value leading to the maximum F1-score on the provided set of predictions. """ + y_scores = [k[2] for k in predictions] y_labels = [k[1] for k in predictions] precision, recall, thresholds = sklearn.metrics.precision_recall_curve( - y_labels, y_scores + y_labels, + y_scores, ) numerator = 2 * recall * precision denom = recall + precision f1_scores = numpy.divide( - numerator, denom, out=numpy.zeros_like(denom), where=(denom != 0) + numerator, + denom, + out=numpy.zeros_like(denom), + where=(denom != 0), ) _, maxf1_threshold = _get_centered_maxf1(f1_scores, thresholds) @@ -140,6 +146,7 @@ def score_plot( A single (matplotlib) plot containing the score distribution, ready to be saved to disk or displayed. """ + from matplotlib.ticker import MaxNLocator fig, ax = plt.subplots(1, 1) @@ -244,7 +251,7 @@ def run_binary( logger.warning( f"User did not pass an *a priori* threshold for the evaluation " f"of split `{name}`. Using threshold a posteriori (biased) with value " - f"`{use_threshold:.4f}`" + f"`{use_threshold:.4f}`", ) y_predictions = numpy.where(y_scores >= use_threshold, pos_label, neg_label) @@ -255,19 +262,29 @@ def run_binary( threshold=use_threshold, threshold_a_posteriori=(threshold_a_priori is None), precision=sklearn.metrics.precision_score( - y_labels, y_predictions, pos_label=pos_label + y_labels, + y_predictions, + pos_label=pos_label, ), recall=sklearn.metrics.recall_score( - y_labels, y_predictions, pos_label=pos_label + y_labels, + y_predictions, + pos_label=pos_label, ), f1_score=sklearn.metrics.f1_score( - y_labels, y_predictions, pos_label=pos_label + y_labels, + y_predictions, + pos_label=pos_label, ), average_precision_score=sklearn.metrics.average_precision_score( - y_labels, y_scores, pos_label=pos_label + y_labels, + y_scores, + pos_label=pos_label, ), specificity=sklearn.metrics.recall_score( - y_labels, y_predictions, pos_label=neg_label + y_labels, + y_predictions, + pos_label=neg_label, ), auc_score=sklearn.metrics.roc_auc_score( y_labels, @@ -282,15 +299,19 @@ def run_binary( zip( ("fpr", "tpr", "thresholds"), sklearn.metrics.roc_curve( - y_labels, y_scores, pos_label=pos_label + y_labels, + y_scores, + pos_label=pos_label, ), - ) + ), ), precision_recall=dict( zip( ("precision", "recall", "thresholds"), sklearn.metrics.precision_recall_curve( - y_labels, y_scores, pos_label=pos_label + y_labels, + y_scores, + pos_label=pos_label, ), ), ), @@ -304,9 +325,11 @@ def run_binary( zip( ("hist", "bin_edges"), numpy.histogram( - y_scores[y_labels == pos_label], bins=binning, range=(0, 1) + y_scores[y_labels == pos_label], + bins=binning, + range=(0, 1), ), - ) + ), ), negatives=dict( zip( @@ -316,7 +339,7 @@ def run_binary( bins=binning, range=(0, 1), ), - ) + ), ), ) @@ -324,7 +347,8 @@ def run_binary( def tabulate_results( - data: typing.Mapping[str, typing.Mapping[str, typing.Any]], fmt: str + data: typing.Mapping[str, typing.Mapping[str, typing.Any]], + fmt: str, ) -> str: """Tabulate summaries from multiple splits. @@ -378,6 +402,7 @@ def aggregate_roc( matplotlib.figure.Figure A figure, containing the aggregated ROC plot. """ + fig, ax = plt.subplots(1, 1) assert isinstance(fig, matplotlib.figure.Figure) @@ -424,7 +449,10 @@ def aggregate_roc( style = next(linecycler) (line,) = ax.plot( - elements["fpr"], elements["tpr"], color=color, linestyle=style + elements["fpr"], + elements["tpr"], + color=color, + linestyle=style, ) legend.append((line, label)) @@ -557,7 +585,7 @@ def aggregate_pr( for name, elements in data.items(): _ap = credible.curves.average_metric( - (elements["precision"], elements["recall"]) + (elements["precision"], elements["recall"]), ) label = f"{name} (AP={_ap:.2f})" color = next(colorcycler) diff --git a/src/mednet/engine/loggers.py b/src/mednet/engine/loggers.py index 8e14f774a44e1feaeadc79ecb8b9c0b62ba91927..db046cc04c0478e5a9a47af13c74308c24422976 100644 --- a/src/mednet/engine/loggers.py +++ b/src/mednet/engine/loggers.py @@ -75,4 +75,4 @@ class CustomTensorboardLogger(TensorBoardLogger): @property def log_dir(self) -> str: - return os.path.join(self.save_dir, self.name) + return os.path.join(self.save_dir, self.name) # noqa: PTH118 diff --git a/src/mednet/engine/predictor.py b/src/mednet/engine/predictor.py index ae9ef271a8e0d72275a3e70fbc57f24a3e4978f3..1ba6f05a12bc16b4c9a48def96b70a0472a4646a 100644 --- a/src/mednet/engine/predictor.py +++ b/src/mednet/engine/predictor.py @@ -88,22 +88,26 @@ def run( return [sample for batch in p for sample in batch] dataloaders = datamodule.predict_dataloader() + if isinstance(dataloaders, torch.utils.data.DataLoader): logger.info("Running prediction on a single dataloader...") return _flatten(trainer.predict(model, dataloaders)) # type: ignore - elif isinstance(dataloaders, list): + + if isinstance(dataloaders, list): retval_list = [] for k, dataloader in enumerate(dataloaders): logger.info(f"Running prediction on split `{k}`...") retval_list.append(_flatten(trainer.predict(model, dataloader))) # type: ignore return retval_list - elif isinstance(dataloaders, dict): + + if isinstance(dataloaders, dict): retval_dict = {} for name, dataloader in dataloaders.items(): logger.info(f"Running prediction on `{name}` split...") retval_dict[name] = _flatten(trainer.predict(model, dataloader)) # type: ignore return retval_dict - elif dataloaders is None: + + if dataloaders is None: logger.warning("Datamodule did not return any prediction dataloaders!") return None @@ -112,5 +116,5 @@ def run( raise TypeError( f"Datamodule returned strangely typed prediction " f"dataloaders: `{type(dataloaders)}` - Please write code " - f"to support this use-case." + f"to support this use-case.", ) diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py index 197822ddf28ddc60c14aa039c457e3d4453c5652..314370bf53623ed6f566eefc37e0235c6ae890a3 100644 --- a/src/mednet/engine/saliency/completeness.py +++ b/src/mednet/engine/saliency/completeness.py @@ -11,7 +11,6 @@ import lightning.pytorch import numpy as np import torch import tqdm - from pytorch_grad_cam.metrics.road import ( ROADLeastRelevantFirstAverage, ROADMostRelevantFirstAverage, @@ -26,6 +25,14 @@ logger = logging.getLogger(__name__) class SigmoidClassifierOutputTarget(torch.nn.Module): + """Consider output to be a sigmoid. + + Parameters + ---------- + category + The category. + """ + def __init__(self, category): self.category = category @@ -73,15 +80,17 @@ def _calculate_road_scores( most-relevant-first average score (``morf``), least-relevant-first average score (``lerf``) and the combined value (``(lerf-morf)/2``). """ + saliency_map = saliency_map_callable( - input_tensor=images, targets=[ClassifierOutputTarget(output_num)] + input_tensor=images, + targets=[ClassifierOutputTarget(output_num)], ) - cam_metric_ROADMoRF_avg = ROADMostRelevantFirstAverage( - percentiles=percentiles + cam_metric_roadmorf_avg = ROADMostRelevantFirstAverage( + percentiles=percentiles, ) - cam_metric_ROADLeRF_avg = ROADLeastRelevantFirstAverage( - percentiles=percentiles + cam_metric_roadlerf_avg = ROADLeastRelevantFirstAverage( + percentiles=percentiles, ) # Calculate ROAD scores for all percentiles and average - this is NOT the @@ -91,14 +100,14 @@ def _calculate_road_scores( # ``metrics.road``). metric_target = [SigmoidClassifierOutputTarget(output_num)] - MoRF_scores = cam_metric_ROADMoRF_avg( + morf_scores = cam_metric_roadmorf_avg( input_tensor=images, cams=saliency_map, model=model, targets=metric_target, ) - LeRF_scores = cam_metric_ROADLeRF_avg( + lerf_scores = cam_metric_roadlerf_avg( input_tensor=images, cams=saliency_map, model=model, @@ -106,9 +115,9 @@ def _calculate_road_scores( ) return ( - float(MoRF_scores.item()), - float(LeRF_scores.item()), - float(LeRF_scores.item() - MoRF_scores.item()) / 2.0, + float(morf_scores.item()), + float(lerf_scores.item()), + float(lerf_scores.item() - morf_scores.item()) / 2.0, ) @@ -121,8 +130,9 @@ def _process_sample( positive_only: bool, percentiles: typing.Sequence[int], ) -> list: - """Helper function to :py:func:`run` to be used in multiprocessing - contexts. + """Process a single sample. + + Helper function to :py:func:`run` to be used in multiprocessing contexts. Parameters ---------- @@ -290,7 +300,7 @@ def run( if saliency_map_algorithm == "fullgrad": raise ValueError( "Fullgrad saliency map algorithm is not supported for the " - "Pasa model." + "Pasa model.", ) target_layers = [model.fc14] # Last non-1x1 Conv2d layer elif isinstance(model, Densenet): @@ -309,7 +319,7 @@ def run( f"). The current implementation can only handle a single GPU. " f"Either disable GPU usage, set the number of " f"multiprocessing instances to one, or disable multiprocessing " - "entirely (ie. set it to -1)." + "entirely (ie. set it to -1).", ) # prepares model for evaluation, cast to target device @@ -342,10 +352,13 @@ def run( if parallel < 0: logger.info( f"Computing ROAD scores for dataset `{k}` in the current " - f"process context..." + f"process context...", ) for sample in tqdm.tqdm( - v, desc="samples", leave=False, disable=None + v, + desc="samples", + leave=False, + disable=None, ): retval[k].append(_process(sample)) @@ -353,7 +366,7 @@ def run( instances = parallel or multiprocessing.cpu_count() logger.info( f"Computing ROAD scores for dataset `{k}` using {instances} " - f"processes..." + f"processes...", ) with multiprocessing.Pool(instances) as p: retval[k] = list(tqdm.tqdm(p.imap(_process, v), total=len(v))) diff --git a/src/mednet/engine/saliency/evaluator.py b/src/mednet/engine/saliency/evaluator.py index 0941dc60731d9e7950f95de1aaae0e51ae2cd5a6..fdfe1d7783885ec0e8747a0a223924daf7be4dc1 100644 --- a/src/mednet/engine/saliency/evaluator.py +++ b/src/mednet/engine/saliency/evaluator.py @@ -16,7 +16,7 @@ def _reconcile_metrics( completeness: list, interpretability: list, ) -> list[tuple[str, int, float, float, float]]: - """Summarize samples into a new table containing the most important scores. + r"""Summarize samples into a new table containing the most important scores. It returns a list containing a table with completeness and ROAD scores per sample, for the selected dataset. Only samples for which a completness and @@ -42,9 +42,10 @@ def _reconcile_metrics( .. math:: - \\text{ROAD-WeightedPropEng} = \\max(0, \\text{AvgROAD}) \\cdot - \\text{ProportionalEnergy} + \text{ROAD-WeightedPropEng} = \max(0, \text{AvgROAD}) \cdot + \text{ProportionalEnergy} """ + retval: list[tuple[str, int, float, float, float]] = [] retval = [] @@ -68,7 +69,7 @@ def _reconcile_metrics( aopc_combined, prop_energy, road_weighted_prop_energy, - ) + ), ) return retval @@ -128,18 +129,27 @@ def _make_histogram( # draw median and quartiles quartile = numpy.percentile(values, [25, 50, 75]) ax.axvline( - quartile[0], color="green", linestyle="--", label="Q1", alpha=0.5 + quartile[0], + color="green", + linestyle="--", + label="Q1", + alpha=0.5, ) ax.axvline(quartile[1], color="red", label="median", alpha=0.5) ax.axvline( - quartile[2], color="green", linestyle="--", label="Q3", alpha=0.5 + quartile[2], + color="green", + linestyle="--", + label="Q3", + alpha=0.5, ) return fig # type: ignore def summary_table( - summary: dict[SaliencyMapAlgorithm, dict[str, typing.Any]], fmt: str + summary: dict[SaliencyMapAlgorithm, dict[str, typing.Any]], + fmt: str, ) -> str: """Tabulate various summaries into one table. @@ -297,7 +307,7 @@ def run( ) d["road-normalised-proportional-energy-average"] = sum( - retval["road-weighted-proportional-energy"]["val"] + retval["road-weighted-proportional-energy"]["val"], ) / sum([max(0, k) for k in retval["aopc-combined"]["val"]]) retval[dataset] = d diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index df0f9ab123f2ead0b17d361bb5dce73196438184..545235b4b748db95764427b631dc9ab55fff8838 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -44,66 +44,82 @@ def _create_saliency_map_callable( match algo_type: case "gradcam": return pytorch_grad_cam.GradCAM( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "scorecam": return pytorch_grad_cam.ScoreCAM( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "fullgrad": return pytorch_grad_cam.FullGrad( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "randomcam": return pytorch_grad_cam.RandomCAM( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "hirescam": return pytorch_grad_cam.HiResCAM( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "gradcamelementwise": return pytorch_grad_cam.GradCAMElementWise( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "gradcam++" | "gradcamplusplus": return pytorch_grad_cam.GradCAMPlusPlus( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "xgradcam": return pytorch_grad_cam.XGradCAM( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "ablationcam": assert ( target_layers is not None ), "AblationCAM cannot have target_layers=None" return pytorch_grad_cam.AblationCAM( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "eigencam": return pytorch_grad_cam.EigenCAM( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "eigengradcam": return pytorch_grad_cam.EigenGradCAM( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case "layercam": return pytorch_grad_cam.LayerCAM( - model=model, target_layers=target_layers + model=model, + target_layers=target_layers, ) case _: raise ValueError( f"Saliency map algorithm `{algo_type}` is not currently " - f"supported." + f"supported.", ) def _save_saliency_map( - output_folder: pathlib.Path, name: str, saliency_map: torch.Tensor + output_folder: pathlib.Path, + name: str, + saliency_map: torch.Tensor, ) -> None: - """Helper function to save a saliency map to disk. + """Save a saliency map to permanent storage (disk). + + Helper function to save a saliency map to disk. Parameters ---------- @@ -158,6 +174,7 @@ def run( Where to save all the saliency maps (this path should exist before this function is called). """ + from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget from ...models.densenet import Densenet @@ -167,7 +184,7 @@ def run( if saliency_map_algorithm == "fullgrad": raise ValueError( "Fullgrad saliency map algorithm is not supported for the " - "Pasa model." + "Pasa model.", ) target_layers = [model.fc14] # Last non-1x1 Conv2d layer elif isinstance(model, Densenet): @@ -190,14 +207,16 @@ def run( for k, v in datamodule.predict_dataloader().items(): logger.info( - f"Generating saliency maps for dataset `{k}` via `{saliency_map_algorithm}`..." + f"Generating saliency maps for dataset `{k}` via " + f"`{saliency_map_algorithm}`...", ) for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None): name = sample[1]["name"][0] label = sample[1]["label"].item() image = sample[0].to( - device=device, non_blocking=torch.cuda.is_available() + device=device, + non_blocking=torch.cuda.is_available(), ) # in binary classification systems, negative labels may be skipped diff --git a/src/mednet/engine/saliency/interpretability.py b/src/mednet/engine/saliency/interpretability.py index aeca5b6f1049a0a0d363375595452ce56d51b389..4fc48861fce18aee5e0a3ef4333914802ee8f7e4 100644 --- a/src/mednet/engine/saliency/interpretability.py +++ b/src/mednet/engine/saliency/interpretability.py @@ -12,7 +12,6 @@ import numpy.typing import skimage.measure import torch import torchvision.ops - from tqdm import tqdm from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes @@ -93,6 +92,7 @@ def _extract_bounding_box( BoundingBox A bounding box. """ + x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[ 0 ] @@ -122,6 +122,7 @@ def _compute_max_iou_and_ioda( tuple[float, float] The max iou and ioda values. """ + detected_area = detected_box.area() if detected_area == 0: return 0.0, 0.0 @@ -420,12 +421,15 @@ def run( # substantially speed this up. for dataset_name, dataset_loader in datamodule.predict_dataloader().items(): logger.info( - f"Estimating interpretability metrics for dataset `{dataset_name}`..." + f"Estimating interpretability metrics for dataset `{dataset_name}`...", ) retval[dataset_name] = [] for sample in tqdm( - dataset_loader, desc="batches", leave=False, disable=None + dataset_loader, + desc="batches", + leave=False, + disable=None, ): name = str(sample[1]["name"][0]) label = int(sample[1]["label"].item()) @@ -440,14 +444,15 @@ def run( # regions of interest. We need to abstract from this to support more # datasets and other ways to annotate. bboxes: BoundingBoxes = sample[1].get( - "bounding_boxes", BoundingBoxes() + "bounding_boxes", + BoundingBoxes(), ) if not bboxes: logger.warning( f"Sample `{name}` does not contain bounding-box information. " f"No localization metrics can be calculated in this case. " - f"Skipping..." + f"Skipping...", ) # we add the entry for dataset completeness retval[dataset_name].append([name, label]) @@ -462,10 +467,10 @@ def run( bboxes[0], numpy.load( input_folder - / pathlib.Path(name).with_suffix(".npy") + / pathlib.Path(name).with_suffix(".npy"), ), ), - ] + ], ) return retval diff --git a/src/mednet/engine/saliency/viewer.py b/src/mednet/engine/saliency/viewer.py index 0047c4e1ad7ce6ee7f9a3cdd6482a234afb11756..e1a67d723a9bb20324b29724cf98ea733240cf8a 100644 --- a/src/mednet/engine/saliency/viewer.py +++ b/src/mednet/engine/saliency/viewer.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later import logging -import os import pathlib import typing @@ -15,7 +14,6 @@ import PIL.Image import PIL.ImageColor import PIL.ImageDraw import torchvision.transforms.functional - from tqdm import tqdm from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes @@ -174,11 +172,14 @@ def _process_sample( # the top, draw rectangles and other annotations in coulour. So, we force # it right up front. retval = torchvision.transforms.functional.to_pil_image(raw_data).convert( - "RGB" + "RGB", ) retval = _overlay_saliency_map( - retval, saliencies, colormap="plasma", image_weight=0.5 + retval, + saliencies, + colormap="plasma", + image_weight=0.5, ) for k in ground_truth: @@ -220,13 +221,17 @@ def run( for dataset_name, dataset_loader in datamodule.predict_dataloader().items(): logger.info( - f"Generating visualisations for samples at dataset `{dataset_name}`..." + f"Generating visualisations for samples at dataset `{dataset_name}`...", ) for sample in tqdm( - dataset_loader, desc="batches", leave=False, disable=None + dataset_loader, + desc="batches", + leave=False, + disable=None, ): - # WARNING: following code assumes a batch size of 1. Will break if not the case. + # WARNING: following code assumes a batch size of 1. Will break if + # not the case. name = str(sample[1]["name"][0]) label = int(sample[1]["label"].item()) data = sample[0][0] @@ -236,7 +241,7 @@ def run( continue saliencies = numpy.load( - input_folder / pathlib.Path(name).with_suffix(".npy") + input_folder / pathlib.Path(name).with_suffix(".npy"), ) saliencies[saliencies < (threshold * saliencies.max())] = 0 @@ -259,7 +264,7 @@ def run( # Save image output_file_path = output_folder / pathlib.Path(name).with_suffix( - ".png" + ".png", ) - os.makedirs(output_file_path.parent, exist_ok=True) + output_file_path.parent.mkdir(parents=True, exist_ok=True) image.save(output_file_path) diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index 7fd8f2a7a4d30fcf1beae5ed55870d3b2b056f33..57cdfe8e197e7773d4f6a8b2717552c3cb3aa74b 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -70,7 +70,7 @@ def run( Path to an optional checkpoint file to load. """ - os.makedirs(output_folder, exist_ok=True) + output_folder.mkdir(parents=True, exist_ok=True) from .loggers import CustomTensorboardLogger @@ -82,7 +82,7 @@ def run( logger.info( f"Monitor training with `tensorboard serve " f"--logdir={output_folder}/{log_dir}/`. " - f"Then, open a browser on the printed address." + f"Then, open a browser on the printed address.", ) train_resource_monitor = ResourceMonitor( @@ -109,7 +109,7 @@ def run( monitor="loss/validation", mode="min", save_on_train_epoch_end=True, - every_n_epochs=validation_period, # frequency at which it would check the "monitor" + every_n_epochs=validation_period, # frequency at which it checks the "monitor" enable_version_counter=False, # no versioning of aliased checkpoints ) checkpoint_minvalloss_callback.CHECKPOINT_NAME_LAST = CHECKPOINT_ALIASES[ # type: ignore @@ -128,7 +128,8 @@ def run( log_every_n_steps=len(datamodule.train_dataloader()), callbacks=[ LoggingCallback( - train_resource_monitor, validation_resource_monitor + train_resource_monitor, + validation_resource_monitor, ), checkpoint_minvalloss_callback, ], diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index d2635970f35497f6274b5ea87488f61e8c6a3a30..b4b9e723ad2a327f56e3051522654c3b6166ac89 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -87,7 +87,7 @@ class Alexnet(pl.LightningModule): self._optimizer_arguments = optimizer_arguments self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms + augmentation_transforms, ) self.pretrained = pretrained @@ -107,13 +107,12 @@ class Alexnet(pl.LightningModule): def forward(self, x): x = self.normalizer(x) # type: ignore - - x = self.model_ft(x) - - return x + return self.model_ft(x) def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning when saving a checkpoint to give you a chance to + """Perform actions during checkpoint saving (called by lightning). + + Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save. Use on_load_checkpoint() to restore what additional data is saved here. @@ -122,10 +121,11 @@ class Alexnet(pl.LightningModule): checkpoint The checkpoint to save. """ + checkpoint["normalizer"] = self.normalizer def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning to restore your model. + """Perform actions during model loading (called by lightning). If you saved something with on_save_checkpoint() this is your chance to restore this. @@ -135,6 +135,7 @@ class Alexnet(pl.LightningModule): checkpoint The loaded checkpoint. """ + logger.info("Restoring normalizer from checkpoint.") self.normalizer = checkpoint["normalizer"] @@ -150,13 +151,14 @@ class Alexnet(pl.LightningModule): A torch Dataloader from which to compute the mean and std. Will not be used if the model is pretrained. """ + if self.pretrained: from .normalizer import make_imagenet_normalizer logger.warning( f"ImageNet pre-trained {self.name} model - NOT " f"computing z-norm factors from train dataloader. " - f"Using preset factors from torchvision." + f"Using preset factors from torchvision.", ) self.normalizer = make_imagenet_normalizer() else: @@ -164,7 +166,7 @@ class Alexnet(pl.LightningModule): logger.info( f"Uninitialised {self.name} model - " - f"computing z-norm factors from train dataloader." + f"computing z-norm factors from train dataloader.", ) self.normalizer = make_z_normalizer(dataloader) @@ -209,5 +211,6 @@ class Alexnet(pl.LightningModule): def configure_optimizers(self): return self._optimizer_type( - self.parameters(), **self._optimizer_arguments + self.parameters(), + **self._optimizer_arguments, ) diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index 91f7c336457d4bda68342be49cb77bb2b22d47c3..e29e128cd0c5a9146ff96a3205ff8428577f6906 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -85,7 +85,7 @@ class Densenet(pl.LightningModule): self._optimizer_arguments = optimizer_arguments self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms + augmentation_transforms, ) self.pretrained = pretrained @@ -101,18 +101,18 @@ class Densenet(pl.LightningModule): # Adapt output features self.model_ft.classifier = torch.nn.Linear( - self.model_ft.classifier.in_features, self.num_classes + self.model_ft.classifier.in_features, + self.num_classes, ) def forward(self, x): x = self.normalizer(x) # type: ignore - - x = self.model_ft(x) - - return x + return self.model_ft(x) def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning when saving a checkpoint to give you a chance to + """Perform actions during checkpoint saving (called by lightning). + + Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save. Use on_load_checkpoint() to restore what additional data is saved here. @@ -121,10 +121,11 @@ class Densenet(pl.LightningModule): checkpoint The checkpoint to save. """ + checkpoint["normalizer"] = self.normalizer def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning to restore your model. + """Perform actions during model loading (called by lightning). If you saved something with on_save_checkpoint() this is your chance to restore this. @@ -134,6 +135,7 @@ class Densenet(pl.LightningModule): checkpoint The loaded checkpoint. """ + logger.info("Restoring normalizer from checkpoint.") self.normalizer = checkpoint["normalizer"] @@ -149,13 +151,14 @@ class Densenet(pl.LightningModule): A torch Dataloader from which to compute the mean and std. Will not be used if the model is pretrained. """ + if self.pretrained: from .normalizer import make_imagenet_normalizer logger.warning( f"ImageNet pre-trained {self.name} model - NOT " f"computing z-norm factors from train dataloader. " - f"Using preset factors from torchvision." + f"Using preset factors from torchvision.", ) self.normalizer = make_imagenet_normalizer() else: @@ -163,7 +166,7 @@ class Densenet(pl.LightningModule): logger.info( f"Uninitialised {self.name} model - " - f"computing z-norm factors from train dataloader." + f"computing z-norm factors from train dataloader.", ) self.normalizer = make_z_normalizer(dataloader) @@ -202,5 +205,6 @@ class Densenet(pl.LightningModule): def configure_optimizers(self): return self._optimizer_type( - self.parameters(), **self._optimizer_arguments + self.parameters(), + **self._optimizer_arguments, ) diff --git a/src/mednet/models/logistic_regression.py b/src/mednet/models/logistic_regression.py index fd3281ec720ce4b926e5ecdc7b8208365769422e..f203e35221671f6abb8a9db87f12d3d82f87201f 100644 --- a/src/mednet/models/logistic_regression.py +++ b/src/mednet/models/logistic_regression.py @@ -65,7 +65,7 @@ class LogisticRegression(pl.LightningModule): return self.linear(x) def training_step(self, batch, batch_idx): - input = batch[1] + _input = batch[1] labels = batch[2] # Increase label dimension if too low @@ -74,7 +74,7 @@ class LogisticRegression(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - output = self(input) + output = self(_input) # Manually move criterion to selected device, since not part of the model. self._train_loss = self._train_loss.to(self.device) @@ -83,7 +83,7 @@ class LogisticRegression(pl.LightningModule): return {"loss": training_loss} def validation_step(self, batch, batch_idx, dataloader_idx=0): - input = batch[1] + _input = batch[1] labels = batch[2] # Increase label dimension if too low @@ -92,7 +92,7 @@ class LogisticRegression(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # data forwarding on the existing network - output = self(input) + output = self(_input) # Manually move criterion to selected device, since not part of the model. self._validation_loss = self._validation_loss.to(self.device) @@ -100,8 +100,8 @@ class LogisticRegression(pl.LightningModule): if dataloader_idx == 0: return {"validation_loss": validation_loss} - else: - return {f"extra_validation_loss_{dataloader_idx}": validation_loss} + + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) @@ -110,5 +110,6 @@ class LogisticRegression(pl.LightningModule): def configure_optimizers(self): return self._optimizer_type( - self.parameters(), **self._optimizer_arguments + self.parameters(), + **self._optimizer_arguments, ) diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py index 8051f63b27ad7d43b04ca0316738a48c336326a3..bf965790cade10d68e19c0bf372c9fa7bf4d5409 100644 --- a/src/mednet/models/loss_weights.py +++ b/src/mednet/models/loss_weights.py @@ -36,7 +36,7 @@ def _get_label_weights( """ targets = torch.tensor( - [sample for batch in dataloader for sample in batch[1]["label"]] + [sample for batch in dataloader for sample in batch[1]["label"]], ) # Binary labels @@ -48,7 +48,7 @@ def _get_label_weights( # Divide negatives by positives positive_weights = torch.tensor( - [class_sample_count[0] / class_sample_count[1]] + [class_sample_count[0] / class_sample_count[1]], ).reshape(-1) # Multiclass labels diff --git a/src/mednet/models/mlp.py b/src/mednet/models/mlp.py index 831b7385d71dc43d696ae07d62c923a76bb23653..e8e4b2904264d8fc9485ad28adbf4b90213f51eb 100644 --- a/src/mednet/models/mlp.py +++ b/src/mednet/models/mlp.py @@ -69,7 +69,7 @@ class MultiLayerPerceptron(pl.LightningModule): return self.fc2(self.relu(self.fc1(x))) def training_step(self, batch, batch_idx): - input = batch[1] + _input = batch[1] labels = batch[2] # Increase label dimension if too low @@ -78,7 +78,7 @@ class MultiLayerPerceptron(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - output = self(input) + output = self(_input) # Manually move criterion to selected device, since not part of the model. self._train_loss = self._train_loss.to(self.device) @@ -87,7 +87,7 @@ class MultiLayerPerceptron(pl.LightningModule): return {"loss": training_loss} def validation_step(self, batch, batch_idx, dataloader_idx=0): - input = batch[1] + _input = batch[1] labels = batch[2] # Increase label dimension if too low @@ -96,7 +96,7 @@ class MultiLayerPerceptron(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # data forwarding on the existing network - output = self(input) + output = self(_input) # Manually move criterion to selected device, since not part of the model. self._validation_loss = self._validation_loss.to(self.device) @@ -104,8 +104,8 @@ class MultiLayerPerceptron(pl.LightningModule): if dataloader_idx == 0: return {"validation_loss": validation_loss} - else: - return {f"extra_validation_loss_{dataloader_idx}": validation_loss} + + return {f"extra_validation_loss_{dataloader_idx}": validation_loss} def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) @@ -114,5 +114,6 @@ class MultiLayerPerceptron(pl.LightningModule): def configure_optimizers(self): return self._optimizer_type( - self.parameters(), **self._optimizer_arguments + self.parameters(), + **self._optimizer_arguments, ) diff --git a/src/mednet/models/normalizer.py b/src/mednet/models/normalizer.py index 6e09f51e3972a98e4f4c3e4cc359e44b4516aada..fc2992c57bb336ec9625152e4054211293cfee25 100644 --- a/src/mednet/models/normalizer.py +++ b/src/mednet/models/normalizer.py @@ -66,5 +66,6 @@ def make_imagenet_normalizer() -> torchvision.transforms.Normalize: """ return torchvision.transforms.Normalize( - (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + (0.485, 0.456, 0.406), + (0.229, 0.224, 0.225), ) diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index e285bf7baa36a9ab548de01451622ee61ae1c111..16a71f73c93bc9ab4dd8a65d2c9da7d96d27a36e 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -8,7 +8,7 @@ import typing import lightning.pytorch as pl import torch import torch.nn -import torch.nn.functional as F +import torch.nn.functional as F # noqa: N812 import torch.optim.optimizer import torch.utils.data import torchvision.transforms @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class Pasa(pl.LightningModule): - """Implementation of CNN by Pasa. + """Implementation of CNN by Pasa and others. Simple CNN for classification based on paper by [PASA-2019]_. @@ -90,7 +90,7 @@ class Pasa(pl.LightningModule): self._optimizer_arguments = optimizer_arguments self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms + augmentation_transforms, ) # First convolution block @@ -106,7 +106,10 @@ class Pasa(pl.LightningModule): self.fc4 = torch.nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1)) self.fc5 = torch.nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1)) self.fc6 = torch.nn.Conv2d( - 16, 32, (1, 1), (1, 1) + 16, + 32, + (1, 1), + (1, 1), ) # Original stride (2, 2) self.batchNorm2d_24 = torch.nn.BatchNorm2d(24) @@ -117,7 +120,10 @@ class Pasa(pl.LightningModule): self.fc7 = torch.nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1)) self.fc8 = torch.nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1)) self.fc9 = torch.nn.Conv2d( - 32, 48, (1, 1), (1, 1) + 32, + 48, + (1, 1), + (1, 1), ) # Original stride (2, 2) self.batchNorm2d_40 = torch.nn.BatchNorm2d(40) @@ -128,7 +134,10 @@ class Pasa(pl.LightningModule): self.fc10 = torch.nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1)) self.fc11 = torch.nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1)) self.fc12 = torch.nn.Conv2d( - 48, 64, (1, 1), (1, 1) + 48, + 64, + (1, 1), + (1, 1), ) # Original stride (2, 2) self.batchNorm2d_56 = torch.nn.BatchNorm2d(56) @@ -139,7 +148,10 @@ class Pasa(pl.LightningModule): self.fc13 = torch.nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1)) self.fc14 = torch.nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1)) self.fc15 = torch.nn.Conv2d( - 64, 80, (1, 1), (1, 1) + 64, + 80, + (1, 1), + (1, 1), ) # Original stride (2, 2) self.batchNorm2d_72 = torch.nn.BatchNorm2d(72) @@ -147,10 +159,12 @@ class Pasa(pl.LightningModule): self.batchNorm2d_80_2 = torch.nn.BatchNorm2d(80) self.pool2d = torch.nn.MaxPool2d( - (3, 3), (2, 2) + (3, 3), + (2, 2), ) # Pool after conv. block self.dense = torch.nn.Linear( - 80, self.num_classes + 80, + self.num_classes, ) # Fully connected layer def forward(self, x): @@ -195,14 +209,14 @@ class Pasa(pl.LightningModule): x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2) # Dense layer - x = self.dense(x) + return self.dense(x) # x = F.log_softmax(x, dim=1) # 0 is batch size - return x - def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning when saving a checkpoint to give you a chance to + """Perform actions during checkpoint saving (called by lightning). + + Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save. Use on_load_checkpoint() to restore what additional data is saved here. @@ -211,10 +225,11 @@ class Pasa(pl.LightningModule): checkpoint The checkpoint to save. """ + checkpoint["normalizer"] = self.normalizer def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Called by Lightning to restore your model. + """Perform actions during model loading (called by lightning). If you saved something with on_save_checkpoint() this is your chance to restore this. @@ -224,6 +239,7 @@ class Pasa(pl.LightningModule): checkpoint The loaded checkpoint. """ + logger.info("Restoring normalizer from checkpoint.") self.normalizer = checkpoint["normalizer"] @@ -235,11 +251,12 @@ class Pasa(pl.LightningModule): dataloader A torch Dataloader from which to compute the mean and std. """ + from .normalizer import make_z_normalizer logger.info( f"Uninitialised {self.name} model - " - f"computing z-norm factors from train dataloader." + f"computing z-norm factors from train dataloader.", ) self.normalizer = make_z_normalizer(dataloader) @@ -278,5 +295,6 @@ class Pasa(pl.LightningModule): def configure_optimizers(self): return self._optimizer_type( - self.parameters(), **self._optimizer_arguments + self.parameters(), + **self._optimizer_arguments, ) diff --git a/src/mednet/models/separate.py b/src/mednet/models/separate.py index 9e5c0ee7b1c604521b2ce868e96837373fa967e3..6f3dfaedddbffc71102271aa8d9045ae8a6524ab 100644 --- a/src/mednet/models/separate.py +++ b/src/mednet/models/separate.py @@ -27,6 +27,7 @@ def _as_predictions( list[BinaryPrediction | MultiClassPrediction] A list of typed predictions that can be saved to disk. """ + return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples] diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py index 2a4f933ae4f541ea71015176e0a048e99bdfd4d8..fb85933ca270ea0ee80320210cabbdf77bba7fcf 100644 --- a/src/mednet/models/transforms.py +++ b/src/mednet/models/transforms.py @@ -15,14 +15,12 @@ def square_center_pad(img: torch.Tensor) -> torch.Tensor: Parameters ---------- - img The tensor to be transformed. Expected to be in the form: ``[..., [1,3], H, W]`` (i.e. arbitrary number of leading dimensions). Returns ------- - Transformed tensor, guaranteed to be square (ie. equal height and width). """ @@ -37,7 +35,10 @@ def square_center_pad(img: torch.Tensor) -> torch.Tensor: bottom = maxdim - height - top return torchvision.transforms.functional.pad( - img, [left, top, right, bottom], 0, "constant" + img, + [left, top, right, bottom], + 0, + "constant", ) @@ -60,17 +61,19 @@ def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: torch.Tensor Transformed tensor with 3 identical color channels. """ + if img.ndim < 3: raise TypeError( f"Input image tensor should have at least 3 dimensions, " f"but found {img.ndim}. If a grayscale image was provided, " - f"ensure to include a channel dimension of size 1 ( i.e: [1, height, width])." + f"ensure to include a channel dimension of size 1 ( i.e: " + f"[1, height, width]).", ) if img.shape[-3] not in (1, 3): raise TypeError( f"Input image tensor should have 1 or 3 color channels," - f"but found {img.shape[-3]}." + f"but found {img.shape[-3]}.", ) if img.shape[-3] == 3: @@ -107,17 +110,19 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: torch.Tensor Transformed tensor with a single (grayscale) color channel. """ + if img.ndim < 3: raise TypeError( f"Input image tensor should have at least 3 dimensions, " f"but found {img.ndim}. If a grayscale image was provided, " - f"ensure to include a channel dimension of size 1 ( i.e: [1, height, width])." + f"ensure to include a channel dimension of size 1 ( i.e: " + f"[1, height, width]).", ) if img.shape[-3] not in (1, 3): raise TypeError( f"Input image tensor should have 1 or 3 planes," - f"but found {img.shape[-3]}" + f"but found {img.shape[-3]}", ) if img.shape[-3] == 1: @@ -128,7 +133,9 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: class SquareCenterPad(torch.nn.Module): - """Transform to a squared version of the image, centered on a canvas padded with zeros.""" + """Transform to a squared version of the image, centered on a canvas padded + with zeros. + """ def __init__(self): super().__init__() @@ -138,7 +145,9 @@ class SquareCenterPad(torch.nn.Module): class RGB(torch.nn.Module): - """Wrapper class around :py:func:`.grayscale_to_rgb` to be used as a model transform.""" + """Wrapper class around :py:func:`.grayscale_to_rgb` to be used as a model + transform. + """ def __init__(self): super().__init__() @@ -148,7 +157,9 @@ class RGB(torch.nn.Module): class Grayscale(torch.nn.Module): - """Wrapper class around :py:func:`rgb_to_grayscale` to be used as a model transform.""" + """Wrapper class around :py:func:`rgb_to_grayscale` to be used as a model + transform. + """ def __init__(self): super().__init__() diff --git a/src/mednet/models/typing.py b/src/mednet/models/typing.py index 883811acd439ebfb1fc40a581c18f3b225267d00..9d1e7f32c1b08684cf324233694518855841eaa0 100644 --- a/src/mednet/models/typing.py +++ b/src/mednet/models/typing.py @@ -12,17 +12,21 @@ BinaryPrediction: typing.TypeAlias = tuple[str, int, float] """The sample name, the target, and the predicted value.""" MultiClassPrediction: typing.TypeAlias = tuple[ - str, typing.Sequence[int], typing.Sequence[float] + str, + typing.Sequence[int], + typing.Sequence[float], ] """The sample name, the target, and the predicted value.""" BinaryPredictionSplit: typing.TypeAlias = typing.Mapping[ - str, typing.Sequence[BinaryPrediction] + str, + typing.Sequence[BinaryPrediction], ] """A series of predictions for different database splits.""" MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[ - str, typing.Sequence[MultiClassPrediction] + str, + typing.Sequence[MultiClassPrediction], ] """A series of predictions for different database splits.""" diff --git a/src/mednet/scripts/cli.py b/src/mednet/scripts/cli.py index 3c44ea0a847f9c7ee17e67ad0dc19ee628a6bd51..fb3f605739811c9a5a5455cf6d8c8a89919c59f3 100644 --- a/src/mednet/scripts/cli.py +++ b/src/mednet/scripts/cli.py @@ -5,7 +5,6 @@ import importlib import click - from clapper.click import AliasedGroup @@ -20,18 +19,21 @@ def cli(): cli.add_command(importlib.import_module("..config", package=__name__).config) cli.add_command( - importlib.import_module("..database", package=__name__).database + importlib.import_module("..database", package=__name__).database, ) cli.add_command( - importlib.import_module("..evaluate", package=__name__).evaluate + importlib.import_module("..evaluate", package=__name__).evaluate, ) cli.add_command( - importlib.import_module("..experiment", package=__name__).experiment + importlib.import_module("..experiment", package=__name__).experiment, ) cli.add_command(importlib.import_module("..predict", package=__name__).predict) cli.add_command(importlib.import_module("..train", package=__name__).train) cli.add_command( - importlib.import_module("..train_analysis", package=__name__).train_analysis + importlib.import_module( + "..train_analysis", + package=__name__, + ).train_analysis, ) @@ -40,28 +42,30 @@ cli.add_command( context_settings=dict(help_option_names=["-?", "-h", "--help"]), ) def saliency(): - """The sub-commands to generate, evaluate and view saliency maps.""" + """Generate, evaluate and view saliency maps.""" pass cli.add_command(saliency) saliency.add_command( - importlib.import_module("..saliency.generate", package=__name__).generate + importlib.import_module("..saliency.generate", package=__name__).generate, ) saliency.add_command( importlib.import_module( - "..saliency.completeness", package=__name__ - ).completeness + "..saliency.completeness", + package=__name__, + ).completeness, ) saliency.add_command( importlib.import_module( - "..saliency.interpretability", package=__name__ - ).interpretability + "..saliency.interpretability", + package=__name__, + ).interpretability, ) saliency.add_command( - importlib.import_module("..saliency.evaluate", package=__name__).evaluate + importlib.import_module("..saliency.evaluate", package=__name__).evaluate, ) saliency.add_command( - importlib.import_module("..saliency.view", package=__name__).view + importlib.import_module("..saliency.view", package=__name__).view, ) diff --git a/src/mednet/scripts/click.py b/src/mednet/scripts/click.py index 8cb7c392e50c074c4b6b2523bee17f0953d75acb..9e3c41e432ab9bb6454ccdac21536294401dbbac 100644 --- a/src/mednet/scripts/click.py +++ b/src/mednet/scripts/click.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later import click - from clapper.click import ConfigCommand as _BaseConfigCommand @@ -11,16 +10,18 @@ class ConfigCommand(_BaseConfigCommand): """A click command-class that has the properties of :py:class:`clapper.click.ConfigCommand` and adds verbatim epilog formatting.""" def format_epilog( - self, _: click.core.Context, formatter: click.formatting.HelpFormatter + self, + _: click.core.Context, + formatter: click.formatting.HelpFormatter, ) -> None: """Format the command epilog during --help. Parameters ---------- - _ - The current parsing context. - formatter - The formatter to use for printing text. + _ + The current parsing context. + formatter + The formatter to use for printing text. """ if self.epilog: diff --git a/src/mednet/scripts/config.py b/src/mednet/scripts/config.py index 5091cd28d9faf62f9e7dcc988a23868e85768dec..a6acc885c917634d75c95a59a19545c219f78969 100644 --- a/src/mednet/scripts/config.py +++ b/src/mednet/scripts/config.py @@ -4,10 +4,10 @@ import importlib.metadata import inspect +import pathlib import typing import click - from clapper.click import AliasedGroup, verbosity_option from clapper.logging import setup @@ -21,6 +21,7 @@ def config(): @config.command( + name="list", epilog="""Examples: \b @@ -39,13 +40,13 @@ def config(): mednet config list -v -""" +""", ) @verbosity_option(logger=logger) -def list(verbose) -> None: # numpydoc ignore=PR01 +def list_(verbose) -> None: # numpydoc ignore=PR01 """List configuration files installed.""" entry_points = importlib.metadata.entry_points().select( - group="mednet.config" + group="mednet.config", ) entry_point_dict = {k.name: k for k in entry_points} @@ -71,7 +72,7 @@ def list(verbose) -> None: # numpydoc ignore=PR01 # 79 - 4 spaces = 75 (see string above) description_leftover = 75 - longest_name_length - print(f"module: {config_type}") + click.echo(f"module: {config_type}") for name in sorted(entry_points_by_module[config_type]): ep = entry_point_dict[name] @@ -91,7 +92,7 @@ def list(verbose) -> None: # numpydoc ignore=PR01 else summary ) - print(print_string % (name, summary)) + click.echo(print_string % (name, summary)) @config.command( @@ -113,7 +114,7 @@ def list(verbose) -> None: # numpydoc ignore=PR01 mednet config describe montgomery -v -""" +""", ) @click.argument( "name", @@ -124,7 +125,7 @@ def list(verbose) -> None: # numpydoc ignore=PR01 def describe(name, verbose) -> None: # numpydoc ignore=PR01 """Describe a specific configuration file.""" entry_points = importlib.metadata.entry_points().select( - group="mednet.config" + group="mednet.config", ) entry_point_dict = {k.name: k for k in entry_points} @@ -133,19 +134,19 @@ def describe(name, verbose) -> None: # numpydoc ignore=PR01 logger.error("Cannot find configuration resource '%s'", k) continue ep = entry_point_dict[k] - print(f"Configuration: {ep.name}") - print(f"Python Module: {ep.module}") - print("") + click.echo(f"Configuration: {ep.name}") + click.echo(f"Python Module: {ep.module}") + click.echo("") mod = ep.load() if verbose >= 1: fname = inspect.getfile(mod) - print("Contents:") - with open(fname) as f: - print(f.read()) + click.echo("Contents:") + with pathlib.Path(fname).open() as f: + click.echo(f.read()) else: # only output documentation - print("Documentation:") - print(inspect.getdoc(mod)) + click.echo("Documentation:") + click.echo(inspect.getdoc(mod)) @config.command( @@ -159,7 +160,7 @@ def describe(name, verbose) -> None: # numpydoc ignore=PR01 $ mednet config copy montgomery -vvv newdataset.py -""" +""", ) @click.argument( "source", @@ -177,7 +178,7 @@ def copy(source, destination) -> None: # numpydoc ignore=PR01 import shutil entry_points = importlib.metadata.entry_points().select( - group="mednet.config" + group="mednet.config", ) entry_point_dict = {k.name: k for k in entry_points} diff --git a/src/mednet/scripts/database.py b/src/mednet/scripts/database.py index e968ef26e8f8b0c6422c28814fb8828c0507b45c..606dbdd487d9c8095f6dae917a0033dd58cd3500 100644 --- a/src/mednet/scripts/database.py +++ b/src/mednet/scripts/database.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later import click - from clapper.click import AliasedGroup, verbosity_option from clapper.logging import setup @@ -24,6 +23,7 @@ def _get_raw_databases() -> dict[str, dict[str, str]]: * ``datadir``: points to the user-configured data directory for the current dataset, if set, or ``None`` otherwise. """ + import importlib import pkgutil @@ -35,18 +35,19 @@ def _get_raw_databases() -> dict[str, dict[str, str]]: retval = {} for k in pkgutil.iter_modules(data.__path__): for j in pkgutil.iter_modules( - [next(iter(data.__path__)) + f"/{k.name}"] + [next(iter(data.__path__)) + f"/{k.name}"], ): if j.name == "datamodule": # this is a submodule that can read raw data files module = importlib.import_module( - f".{j.name}", data.__package__ + f".{k.name}" + f".{j.name}", + data.__package__ + f".{k.name}", ) if hasattr(module, "CONFIGURATION_KEY_DATADIR"): retval[k.name] = dict( module=module.__name__.rsplit(".", 1)[0], datadir=user_configuration.get( - module.CONFIGURATION_KEY_DATADIR + module.CONFIGURATION_KEY_DATADIR, ), ) else: @@ -62,6 +63,7 @@ def database() -> None: @database.command( + name="list", epilog="""Examples: \b @@ -89,7 +91,7 @@ def database() -> None: """, ) @verbosity_option(logger=logger, expose_value=False) -def list(): +def list_(): """List all supported and configured databases.""" config = _get_raw_databases() @@ -175,7 +177,7 @@ def check(split, limit): # numpydoc ignore=PR01 break logger.info( f"{batch[1]['name'][0]}: " - f"{[s for s in batch[0][0].shape]}@{batch[0][0].dtype}" + f"{[s for s in batch[0][0].shape]}@{batch[0][0].dtype}", ) loader_limit -= 1 except Exception: @@ -189,5 +191,6 @@ def check(split, limit): # numpydoc ignore=PR01 ) else: click.secho( - f"Found {errors} errors loading DataModule `{split}`.", fg="red" + f"Found {errors} errors loading DataModule `{split}`.", + fg="red", ) diff --git a/src/mednet/scripts/evaluate.py b/src/mednet/scripts/evaluate.py index f37a6d26a7bd8f74d742b85d63d74fbb9474bf1a..47264b58f2b9dd8a8e4d28a3b0d7ff3fb3e72e40 100644 --- a/src/mednet/scripts/evaluate.py +++ b/src/mednet/scripts/evaluate.py @@ -5,7 +5,6 @@ import pathlib import click - from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup @@ -41,7 +40,10 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") help="Directory in which predictions are currently stored", required=True, type=click.Path( - file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path + file_okay=True, + dir_okay=False, + writable=True, + path_type=pathlib.Path, ), cls=ResourceOption, ) @@ -54,7 +56,10 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") default="evaluation.json", cls=ResourceOption, type=click.Path( - file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path + file_okay=True, + dir_okay=False, + writable=True, + path_type=pathlib.Path, ), ) @click.option( @@ -108,7 +113,6 @@ def evaluate( **_, # ignored ) -> None: # numpydoc ignore=PR01 """Evaluate predictions (from a model) on a classification task.""" - import json import typing @@ -148,7 +152,7 @@ def evaluate( raise click.BadParameter( f"""The value of --threshold=`{threshold}` does not match one of the database split names ({', '.join(predict_data.keys())}) - or can not be converted to a float. Check your input.""" + or can not be converted to a float. Check your input.""", ) results: dict[str, dict[str, typing.Any]] = dict() @@ -179,7 +183,7 @@ def evaluate( table_path = output.with_suffix(".rst") logger.info( - f"Saving evaluation results in table format at `{table_path}`..." + f"Saving evaluation results in table format at `{table_path}`...", ) with table_path.open("w") as f: f.write(table) diff --git a/src/mednet/scripts/experiment.py b/src/mednet/scripts/experiment.py index e19f59efd4ec7120fbde8957b8b9d112c39ebe02..74f2abad84a8920bbb9f76dc3a2673c07d7d98f6 100644 --- a/src/mednet/scripts/experiment.py +++ b/src/mednet/scripts/experiment.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later import click - from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup @@ -51,7 +50,7 @@ def experiment( balance_classes, **_, ): # numpydoc ignore=PR01 - """Run a complete experiment, from training, to prediction and evaluation. + r"""Run a complete experiment, from training, to prediction and evaluation. This script is just a wrapper around the individual scripts for training, running prediction, and evaluating. It organises the output in a preset way:: diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py index e1f38d48fccbb3f07b57439f034cd04d7ad6423a..b8ecf91c86d5c554337a40ccf7264162692e2d64 100644 --- a/src/mednet/scripts/predict.py +++ b/src/mednet/scripts/predict.py @@ -5,7 +5,6 @@ import pathlib import click - from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup @@ -43,7 +42,10 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") default="predictions.json", cls=ResourceOption, type=click.Path( - file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path + file_okay=True, + dir_okay=False, + writable=True, + path_type=pathlib.Path, ), ) @click.option( @@ -126,7 +128,6 @@ def predict( **_, ) -> None: # numpydoc ignore=PR01 """Run inference (generates scores) on all input images, using a pre-trained model.""" - import json import shutil import typing @@ -164,7 +165,7 @@ def predict( database_name=datamodule.database_name, database_split=datamodule.split_name, model_name=model.name, - ) + ), ) json_data.update(model_summary(model)) json_data = {k.replace("_", "-"): v for k, v in json_data.items()} @@ -177,7 +178,7 @@ def predict( backup = output.parent / (output.name + "~") logger.warning( f"Output predictions file `{str(output)}` exists - " - f"backing it up to `{str(backup)}`..." + f"backing it up to `{str(backup)}`...", ) shutil.copy(output, backup) diff --git a/src/mednet/scripts/saliency/completeness.py b/src/mednet/scripts/saliency/completeness.py index a6e30bbea94e3420927c439bd5710ebfeb5853a6..453618332cd06a4a54570c70b65f7559b0900390 100644 --- a/src/mednet/scripts/saliency/completeness.py +++ b/src/mednet/scripts/saliency/completeness.py @@ -6,7 +6,6 @@ import pathlib import typing import click - from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup @@ -117,7 +116,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") "-s", help="""Saliency map algorithm to be used.""", type=click.Choice( - typing.get_args(SaliencyMapAlgorithm), case_sensitive=False + typing.get_args(SaliencyMapAlgorithm), + case_sensitive=False, ), default="gradcam", show_default=True, @@ -211,7 +211,7 @@ def completeness( f"you asked to use a GPU (device = `{device}`). The currently " f"implementation can only handle a single GPU. Either disable GPU " f"utilisation or set the number of multiprocessing instances to " - f"one, or disable multiprocessing entirely (ie. set it to -1)." + f"one, or disable multiprocessing entirely (ie. set it to -1).", ) device_manager = DeviceManager(device) @@ -235,7 +235,7 @@ def completeness( logger.info( f"Evaluating RemOve And Debias (ROAD) average scores for " f"algorithm `{saliency_map_algorithm}` with percentiles " - f"`{', '.join([str(k) for k in percentile])}`..." + f"`{', '.join([str(k) for k in percentile])}`...", ) results = run( model=model, diff --git a/src/mednet/scripts/saliency/evaluate.py b/src/mednet/scripts/saliency/evaluate.py index 26a9b8d1f6ce75b3a4202240541266e0fafa38c8..cf8868afc6f9d667498f2a4fe33e86d2e6680614 100644 --- a/src/mednet/scripts/saliency/evaluate.py +++ b/src/mednet/scripts/saliency/evaluate.py @@ -6,7 +6,6 @@ import pathlib import typing import click - from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup @@ -45,7 +44,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") f"{'|'.join(typing.get_args(SaliencyMapAlgorithm))}", type=( click.Choice( - typing.get_args(SaliencyMapAlgorithm), case_sensitive=False + typing.get_args(SaliencyMapAlgorithm), + case_sensitive=False, ), click.Path( exists=True, @@ -110,5 +110,5 @@ def evaluate( pdf.savefig( summary[dataset]["road-weighted-proportional-energy"][ "plot" - ] + ], ) diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py index 583c250c7061181298d6a5ed0bdb27f12826211d..5a9ca8b6955158355e333e8e178a7a5fdd3b814c 100644 --- a/src/mednet/scripts/saliency/generate.py +++ b/src/mednet/scripts/saliency/generate.py @@ -6,7 +6,6 @@ import pathlib import typing import click - from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup @@ -118,7 +117,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") "-s", help="""Saliency map algorithm to be used.""", type=click.Choice( - typing.get_args(SaliencyMapAlgorithm), case_sensitive=False + typing.get_args(SaliencyMapAlgorithm), + case_sensitive=False, ), default="gradcam", show_default=True, @@ -168,7 +168,6 @@ def generate( The quality of saliency information depends on the saliency map algorithm and trained model. """ - from ...engine.device import DeviceManager from ...engine.saliency.generator import run from ...utils.checkpointer import get_checkpoint_to_run_inference diff --git a/src/mednet/scripts/saliency/interpretability.py b/src/mednet/scripts/saliency/interpretability.py index 14a99040b87798863d2d97d69b97b46d14bceaae..83d8cc82836d0bf3e3a595a16d146efd849df22d 100644 --- a/src/mednet/scripts/saliency/interpretability.py +++ b/src/mednet/scripts/saliency/interpretability.py @@ -5,7 +5,6 @@ import pathlib import click - from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup @@ -119,7 +118,6 @@ def interpretability( proportional energy measure in the sense that it does not need explicit thresholding. """ - import json from ...engine.saliency.interpretability import run diff --git a/src/mednet/scripts/saliency/view.py b/src/mednet/scripts/saliency/view.py index 542d177bb757a7984c0e3b89308bcb998b304aa3..bb1a5c9d91f27d5d46d4001d0a3b35cc5fb3dc38 100644 --- a/src/mednet/scripts/saliency/view.py +++ b/src/mednet/scripts/saliency/view.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import os import pathlib import click - from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup @@ -99,7 +97,6 @@ def view( **_, ) -> None: # numpydoc ignore=PR01 """Generate heatmaps for input CXRs based on existing saliency maps.""" - from ...engine.saliency.viewer import run assert ( @@ -107,11 +104,11 @@ def view( ), "Output folder must not be the same as the input folder." assert not str(output_folder).startswith( - str(input_folder) + str(input_folder), ), "Output folder must not be a subdirectory of the input folder." logger.info(f"Output folder: {output_folder}") - os.makedirs(output_folder, exist_ok=True) + output_folder.mkdir(parents=True, exist_ok=True) datamodule.set_chunk_size(1, 1) datamodule.drop_incomplete_batch = False diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py index 9280f1dd02c454feb8212e06f29b47117c1b7660..68a4e7e721f412fbed430774491147f9c1130577 100644 --- a/src/mednet/scripts/train.py +++ b/src/mednet/scripts/train.py @@ -7,7 +7,6 @@ import pathlib import typing import click - from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup @@ -17,8 +16,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def reusable_options(f): - """The options that can be re-used by top-level scripts (i.e. - ``experiment``). + """Wrap reusable training script options (for ``experiment``). This decorator equips the target function ``f`` with all (reusable) ``train`` script options. @@ -258,10 +256,7 @@ def train( can be resumed. """ - import os - import torch - from lightning.pytorch import seed_everything from ..engine.device import DeviceManager @@ -275,14 +270,14 @@ def train( ) checkpoint_file = None - if os.path.isdir(output_folder): + if output_folder.is_dir(): try: checkpoint_file = get_checkpoint_to_resume_training(output_folder) except FileNotFoundError: logger.info( f"Folder {output_folder} already exists, but I did not" f" find any usable checkpoint file to resume training" - f" from. Starting from scratch..." + f" from. Starting from scratch...", ) seed_everything(seed) @@ -307,7 +302,7 @@ def train( # model.balance_losses_by_class(datamodule) else: logger.info( - "Skipping sample class/dataset ownership balancing on user request" + "Skipping sample class/dataset ownership balancing on user request", ) logger.info(f"Training for at most {epochs} epochs.") @@ -327,7 +322,7 @@ def train( else: logger.warning( f"Model {model.name} has no `set_normalizer` method. " - "Skipping normalization setup (unsupported external model)." + "Skipping normalization setup (unsupported external model).", ) else: # Normalizer will be loaded during model.on_load_checkpoint @@ -335,7 +330,7 @@ def train( start_epoch = checkpoint["epoch"] logger.info( f"Resuming from epoch {start_epoch} " - f"(checkpoint file: `{str(checkpoint_file)}`)..." + f"(checkpoint file: `{str(checkpoint_file)}`)...", ) device_manager = DeviceManager(device) @@ -358,7 +353,7 @@ def train( monitoring_interval=monitoring_interval, balance_classes=balance_classes, model_name=model.name, - ) + ), ) json_data.update(model_summary(model)) json_data = {k.replace("_", "-"): v for k, v in json_data.items()} diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py index 0dc23f324c1a71b5705137f001c30a0fb8c7403d..46898ed24079375ff5a9fd53274cff5f8a00cb28 100644 --- a/src/mednet/scripts/train_analysis.py +++ b/src/mednet/scripts/train_analysis.py @@ -5,7 +5,6 @@ import pathlib import click - from clapper.click import verbosity_option from clapper.logging import setup @@ -20,7 +19,7 @@ def create_figures( groups: list[str] = [ "loss/*", "learning-rate", - "memory-used-GB/cpu/*" "rss-GB/cpu/*", + "memory-used-GB/cpu/*rss-GB/cpu/*", "vms-GB/cpu/*", "num-open-files/cpu/*", "num-processes/cpu/*", @@ -55,11 +54,11 @@ def create_figures( list List of matplotlib figures, one per metric. """ + import fnmatch import typing import matplotlib.pyplot as plt - from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.ticker import MaxNLocator @@ -134,9 +133,7 @@ def train_analysis( output: pathlib.Path, ) -> None: # numpydoc ignore=PR01 """Create a plot for each metric in the training logs and saves them in a .pdf file.""" - import matplotlib.pyplot as plt - from matplotlib.backends.backend_pdf import PdfPages from ..utils.tensorboard import scalars_to_dict diff --git a/src/mednet/scripts/utils.py b/src/mednet/scripts/utils.py index d05973d3313c6501fedcaf5b3716fe113d64df72..e958e2a48b2f46aad637e9891a02be4b389fbbc4 100644 --- a/src/mednet/scripts/utils.py +++ b/src/mednet/scripts/utils.py @@ -35,7 +35,7 @@ def model_summary( """ s = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore - model + model, ) return dict( @@ -81,7 +81,7 @@ def device_properties( return retval -def execution_metadata() -> dict[str, int | float | str]: +def execution_metadata() -> dict[str, int | float | str | dict[str, str]]: """Produce metadata concerning the running script, in the form of a dictionary. @@ -141,14 +141,14 @@ def execution_metadata() -> dict[str, int | float | str]: f"({current_version}) and actual version returned by " f"versioningit ({actual_version}). This typically happens " f"when you commit changes locally and do not re-install the " - f"package. Run `pip install -e .` or equivalent to fix this." + f"package. Run `pip install -e .` or equivalent to fix this.", ) except Exception as e: # not in a git repo? logger.debug(f"Error {e}") pass - data = { + return { "datetime": datetime, "package-name": __package__.split(".")[0], "package-version": current_version, @@ -161,8 +161,6 @@ def execution_metadata() -> dict[str, int | float | str]: "platform": sys.platform, } - return data - def save_json_with_backup(path: pathlib.Path, data: dict | list) -> None: """Save a dictionary into a JSON file with path checking and backup. diff --git a/src/mednet/utils/checkpointer.py b/src/mednet/utils/checkpointer.py index 51d9192fdd6e7379a75879c98af0d1f78244da6a..cf9e607c93a86b42eb23cec5dfd511de72c89e62 100644 --- a/src/mednet/utils/checkpointer.py +++ b/src/mednet/utils/checkpointer.py @@ -62,11 +62,11 @@ def _get_checkpoint_from_alias( # otherwise, we see if we are looking for a template instead, in which case # we must pick the latest. assert "{epoch}" in str( - template + template, ), f"Template `{str(template)}` does not contain the keyword `{{epoch}}`" pattern = re.compile( - template.name.replace("{epoch}", r"epoch=(?P<epoch>\d+)") + template.name.replace("{epoch}", r"epoch=(?P<epoch>\d+)"), ) highest = -1 for f in template.parent.iterdir(): @@ -78,11 +78,11 @@ def _get_checkpoint_from_alias( if highest != -1: return template.with_name( - template.name.replace("{epoch}", f"epoch={highest}") + template.name.replace("{epoch}", f"epoch={highest}"), ) raise FileNotFoundError( - f"A file matching `{str(template)}` specifications was not found" + f"A file matching `{str(template)}` specifications was not found", ) @@ -138,7 +138,7 @@ def get_checkpoint_to_run_inference( except FileNotFoundError: logger.error( "Did not find lowest-validation-loss model to run inference " - "from. Trying to search for the last periodically saved model..." + "from. Trying to search for the last periodically saved model...", ) return _get_checkpoint_from_alias(path, "periodic") diff --git a/src/mednet/utils/rc.py b/src/mednet/utils/rc.py index fcc1659d1f5304c44e0561a7e1e70c6bd905db7a..f1bdbc5d554d35d1e921bb7fef98fd5ef681576c 100644 --- a/src/mednet/utils/rc.py +++ b/src/mednet/utils/rc.py @@ -12,4 +12,5 @@ def load_rc() -> UserDefaults: ------- The user defaults read from the user .toml configuration file. """ + return UserDefaults("mednet.toml") diff --git a/src/mednet/utils/resources.py b/src/mednet/utils/resources.py index 5f71f57473c7d5768c00c9d08770c568cdcf4090..aa676bc5ff4866eb54efaafb76a9c1f140c9c395 100644 --- a/src/mednet/utils/resources.py +++ b/src/mednet/utils/resources.py @@ -50,6 +50,7 @@ def run_nvidia_smi( information is left alone, memory information is transformed to gigabytes (floating-point). """ + if _nvidia_smi is None: return None @@ -74,7 +75,8 @@ def run_nvidia_smi( def run_powermetrics( - time_window_ms: int = 500, key: str | None = None + time_window_ms: int = 500, + key: str | None = None, ) -> dict[str, typing.Any] | None: """Return GPU information from the system. @@ -127,8 +129,8 @@ def run_powermetrics( suitable: `yourusername ALL=(ALL) NOPASSWD:SETENV: /usr/bin/powermetrics`. Replace `yourusername` by your actual username on the machine. Test the setup running the command - `{' '.join(cmd)}` by hand.""" - ) + `{' '.join(cmd)}` by hand.""", + ), ) return None @@ -151,6 +153,7 @@ def cuda_constants() -> dict[str, str | int | float] | None: * ``memory.total``, as ``gpu_memory_total`` (transformed to gigabytes, :py:class:`float`) """ + retval = run_nvidia_smi(("gpu_name", "driver_version", "memory.total")) if retval is None: return retval @@ -178,7 +181,7 @@ def mps_constants() -> dict[str, str | int | float] | None: """ raw_bytes = subprocess.check_output( - ["/usr/sbin/system_profiler", "-xml", "SPDisplaysDataType"] + ["/usr/sbin/system_profiler", "-xml", "SPDisplaysDataType"], ) data = plistlib.loads(raw_bytes) name = data[0]["_items"][0]["_name"] @@ -214,7 +217,7 @@ def cuda_log() -> dict[str, float] | None: """ result = run_nvidia_smi( - ("memory.total", "memory.used", "memory.free", "utilization.gpu") + ("memory.total", "memory.used", "memory.free", "utilization.gpu"), ) if result is None: @@ -268,6 +271,7 @@ def cpu_constants() -> dict[str, int | float]: in gigabytes 1. ``cpu_count`` (:py:class:`int`): number of logical CPUs available. """ + return { "memory-total-GB/cpu": psutil.virtual_memory().total / GB, "number-of-cores/cpu": psutil.cpu_count(logical=True), @@ -312,6 +316,7 @@ class CPULogger: 5. ``cpu_open_files`` (:py:class:`int`): total number of open files by self and children """ + # check all cluster components and update process list # done so we can keep the cpu_percent() initialization stored_children = set(self.cluster[1:]) @@ -385,7 +390,7 @@ class _InformationGatherer: case "cpu": logger.info( f"Pytorch device-type `{device_type}`: " - f"no GPU logging will be performed " + f"no GPU logging will be performed ", ) case "cuda": example = cuda_log() @@ -398,7 +403,7 @@ class _InformationGatherer: case _: logger.warning( f"Unsupported device-type `{device_type}`: " - f"no GPU logging will be performed " + f"no GPU logging will be performed ", ) self.data: dict[str, list[int | float]] = {k: [] for k in keys} @@ -437,6 +442,7 @@ class _InformationGatherer: dict[str, list[int | float]] A dictionary with a list of resources and their corresponding values. """ + if len(next(iter(self.data.values()))) == 0: self.logger.error("CPU/GPU logger was not able to collect any data") return self.data @@ -451,10 +457,10 @@ def _monitor_worker( queue: queue.Queue, logging_level: int, ): - """A monitoring worker that measures resources and returns lists. + """Monitor worker that measures resources and returns lists. Parameters - ========== + ---------- interval Number of seconds to wait between each measurement (maybe a floating point number as accepted by :py:func:`time.sleep`). @@ -472,6 +478,7 @@ def _monitor_worker( logging_level The logging level to use for logging from launched processes. """ + logger = multiprocessing.log_to_stderr(level=logging_level) ra = _InformationGatherer(device_type, main_pid, logger) @@ -488,7 +495,7 @@ def _monitor_worker( except Exception: logger.exception( "Iterative CPU/GPU logging did not work properly." - " Exception follows. Retrying..." + " Exception follows. Retrying...", ) time.sleep(0.5) # wait half a second, and try again! @@ -522,9 +529,9 @@ class ResourceMonitor: self.main_pid = main_pid self.stop_event = multiprocessing.Event() self.summary_event = multiprocessing.Event() - self.q: multiprocessing.Queue[ - dict[str, list[int | float]] - ] = multiprocessing.Queue() + self.q: multiprocessing.Queue[dict[str, list[int | float]]] = ( + multiprocessing.Queue() + ) self.logging_level = logging_level self.monitor = multiprocessing.Process( @@ -558,17 +565,18 @@ class ResourceMonitor: least one entry is kept. Useful to remove spurious observations by the end of a period. """ + self.summary_event.set() try: data: dict[str, list[int | float]] = self.q.get( - timeout=2 * self.interval + timeout=2 * self.interval, ) except queue.Empty: logger.warning( f"CPU/GPU resource monitor did not provide anything when " f"joined (even after a {2*self.interval}-second timeout - " f"this is normally due to exceptions on the monitoring process. " - f"Check above for other exceptions." + f"Check above for other exceptions.", ) self.data = None else: @@ -593,5 +601,5 @@ class ResourceMonitor: if self.monitor.exitcode != 0: logger.error( f"CPU/GPU resource monitor process exited with code " - f"{self.monitor.exitcode}. Check logs for errors!" + f"{self.monitor.exitcode}. Check logs for errors!", ) diff --git a/src/mednet/utils/summary.py b/src/mednet/utils/summary.py deleted file mode 100644 index bff705e30b557a0314dfef929535671e3dad7f81..0000000000000000000000000000000000000000 --- a/src/mednet/utils/summary.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488 - -from functools import reduce - -import torch - -from torch.nn.modules.module import _addindent - - -# ignore this space! -def _repr(model: torch.nn.Module) -> tuple[str, int]: - # We treat the extra repr like the sub-module, one item per line - extra_lines = [] - extra_repr = model.extra_repr() - # empty string will be split into list [''] - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - total_params = 0 - for key, module in model._modules.items(): - mod_str, num_params = _repr(module) - mod_str = _addindent(mod_str, 2) - child_lines.append("(" + key + "): " + mod_str) - total_params += num_params - lines = extra_lines + child_lines - - for _, p in model._parameters.items(): - if hasattr(p, "dtype"): - total_params += reduce(lambda x, y: x * y, p.shape) - - main_str = model._get_name() + "(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - main_str += f", {total_params:,} params" - return main_str, total_params - - -def summary(model: torch.nn.Module) -> tuple[str, int]: - """Count the number of parameters in each model layer. - - Parameters - ---------- - model - Model to summarize. - - Returns - ------- - tuple[int, str] - A tuple containing a multiline string representation of the network and the number of parameters. - """ - return _repr(model) diff --git a/src/mednet/utils/tensorboard.py b/src/mednet/utils/tensorboard.py index 56a81c4497bd8552b15ce4468a4fb4fa8a0e1ce9..7c87bfba0519151941642f28fc0cce2cc6d2bb5e 100644 --- a/src/mednet/utils/tensorboard.py +++ b/src/mednet/utils/tensorboard.py @@ -31,6 +31,7 @@ def scalars_to_dict( values were taken), when the monitored values themselves. The lists are pre-sorted by epoch number. """ + retval: dict[str, tuple[list[int], list[float]]] = {} for logfile in sorted(logdir.glob("events.out.tfevents.*")): diff --git a/tests/conftest.py b/tests/conftest.py index b51eafa681ca9c95be145a2c5eedea0b1fe5a5b8..3b3e1ee95200db30fc73f83c4385c6ffc38eb204 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,14 +8,14 @@ import typing import numpy import pytest import torch - from mednet.data.split import JSONDatabaseSplit from mednet.data.typing import DatabaseSplit @pytest.fixture def datadir(request) -> pathlib.Path: - """Return the directory in which the test is sitting. Check the pytest documentation for more information. + """Return the directory in which the test is sitting. Check the pytest + documentation for more information. Parameters ---------- @@ -27,6 +27,7 @@ def datadir(request) -> pathlib.Path: pathlib.Path The directory in which the test is sitting. """ + return pathlib.Path(request.module.__file__).parents[0] / "data" @@ -36,8 +37,10 @@ def pytest_configure(config): Parameters ---------- config - Configuration values. Check the pytest documentation for more information. + Configuration values. Check the pytest documentation for more + information. """ + config.addinivalue_line( "markers", "skip_if_rc_var_not_set(name): this mark skips the test if a certain " @@ -56,8 +59,10 @@ def pytest_runtest_setup(item): Parameters ---------- item - A test invocation item. Check the pytest documentation for more information. + A test invocation item. Check the pytest documentation for more + information. """ + from mednet.utils.rc import load_rc rc = load_rc() @@ -76,7 +81,7 @@ def pytest_runtest_setup(item): if any(missing): pytest.skip( f"Test skipped because {', '.join(missing)} is **not** " - f"set in ~/.config/mednet.toml" + f"set in ~/.config/mednet.toml", ) @@ -112,10 +117,9 @@ class DatabaseCheckers: split An instance of DatabaseSplit. lengths - - A dictionary that contains keys matching those of the split (this will - be checked). The values of the dictionary should correspond to the - sizes of each of the datasets in the split. + A dictionary that contains keys matching those of the split (this + will be checked). The values of the dictionary should correspond + to the sizes of each of the datasets in the split. prefixes Each file named in a split should start with at least one of these prefixes. @@ -148,7 +152,7 @@ class DatabaseCheckers: prefixes: typing.Sequence[str], possible_labels: typing.Sequence[int], expected_num_labels: int, - expected_image_shape: typing.Optional[tuple[int, ...]] = None, + expected_image_shape: tuple[int, ...] | None = None, ): """Check the consistency of an individual (loaded) batch. @@ -179,7 +183,7 @@ class DatabaseCheckers: if expected_image_shape: assert all( - [data.shape == expected_image_shape for data in batch[0]] + [data.shape == expected_image_shape for data in batch[0]], ) assert isinstance(batch[1], dict) # metadata @@ -193,7 +197,10 @@ class DatabaseCheckers: assert "name" in batch[1] assert all( - [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] + [ + any([k.startswith(j) for j in prefixes]) + for k in batch[1]["name"] + ], ) # use the code below to view generated images @@ -213,8 +220,9 @@ class DatabaseCheckers: for split_name in ref_histogram_splits: raw_samples = datamodule.splits[split_name][0][0] - # It is not possible to get a sample from a Dataset by name/path, only by index. - # This creates a dict of sample name to dataset index. + # It is not possible to get a sample from a Dataset by name/path, + # only by index. This creates a dict of sample name to dataset + # index. raw_samples_indices = {} for idx, rs in enumerate(raw_samples): raw_samples_indices[rs[0]] = idx @@ -222,27 +230,32 @@ class DatabaseCheckers: for ref_hist_path, ref_hist_data in ref_histogram_splits[ split_name ]: - # Get index in the dataset that will return the data corresponding to the specified sample name + # Get index in the dataset that will return the data + # corresponding to the specified sample name dataset_sample_index = raw_samples_indices[ref_hist_path] - image_tensor = datamodule._datasets[split_name][ + image_tensor = datamodule._datasets[split_name][ # noqa: SLF001 dataset_sample_index ][0] histogram = [] for color_channel in image_tensor: color_channel = numpy.multiply( - color_channel.numpy(), 255 + color_channel.numpy(), + 255, ).astype(int) histogram.extend( numpy.histogram( - color_channel, bins=256, range=(0, 256) - )[0].tolist() + color_channel, + bins=256, + range=(0, 256), + )[0].tolist(), ) if compare_type == "statistical": - # Compute pearson coefficients between histogram and reference - # and check the similarity within a certain threshold + # Compute pearson coefficients between histogram and + # reference and check the similarity within a certain + # threshold pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) assert ( 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 diff --git a/tests/test_cli.py b/tests/test_cli.py index 381ef539bc7a4dec82e468456124fa42086be7e5..aa04ec07067766e5d7174db4cdd892783c7b5de6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,7 +7,6 @@ import contextlib import re import pytest - from click.testing import CliRunner @@ -48,25 +47,25 @@ def test_config_help(): def test_config_list_help(): - from mednet.scripts.config import list + from mednet.scripts.config import list_ - _check_help(list) + _check_help(list_) def test_config_list(): - from mednet.scripts.config import list + from mednet.scripts.config import list_ runner = CliRunner() - result = runner.invoke(list) + result = runner.invoke(list_) _assert_exit_0(result) assert "module: mednet.config.data" in result.output assert "module: mednet.config.models" in result.output def test_config_list_v(): - from mednet.scripts.config import list + from mednet.scripts.config import list_ - result = CliRunner().invoke(list, ["--verbose"]) + result = CliRunner().invoke(list_, ["--verbose"]) _assert_exit_0(result) assert "module: mednet.config.data" in result.output assert "module: mednet.config.models" in result.output @@ -95,16 +94,16 @@ def test_database_help(): def test_datamodule_list_help(): - from mednet.scripts.database import list + from mednet.scripts.database import list_ - _check_help(list) + _check_help(list_) def test_datamodule_list(): - from mednet.scripts.database import list + from mednet.scripts.database import list_ runner = CliRunner() - result = runner.invoke(list) + result = runner.invoke(list_) _assert_exit_0(result) assert result.output.startswith("Available databases:") @@ -331,7 +330,8 @@ def test_predict_pasa_montgomery(temporary_basedir): with stdout_logging() as buf: output = temporary_basedir / "predictions.json" last = _get_checkpoint_from_alias( - temporary_basedir / "results", "periodic" + temporary_basedir / "results", + "periodic", ) assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) result = runner.invoke( @@ -447,9 +447,9 @@ def test_experiment(temporary_basedir): len( list( (output_folder / "model").glob( - "model-at-lowest-validation-loss-epoch=*.ckpt" - ) - ) + "model-at-lowest-validation-loss-epoch=*.ckpt", + ), + ), ) == 1 ) @@ -457,8 +457,10 @@ def test_experiment(temporary_basedir): assert ( len( list( - (output_folder / "model" / "logs").glob("events.out.tfevents.*") - ) + (output_folder / "model" / "logs").glob( + "events.out.tfevents.*", + ), + ), ) == 1 ) @@ -471,9 +473,9 @@ def test_experiment(temporary_basedir): len( list( (output_folder / "gradcam" / "saliencies" / "CXR_png").glob( - "MCUCXR_*.npy" - ) - ) + "MCUCXR_*.npy", + ), + ), ) == 138 ) @@ -482,9 +484,9 @@ def test_experiment(temporary_basedir): len( list( (output_folder / "gradcam" / "visualizations" / "CXR_png").glob( - "MCUCXR_*.png" - ) - ) + "MCUCXR_*.png", + ), + ), ) == 58 ) diff --git a/tests/test_database_split.py b/tests/test_database_split.py index f7d8d3403f20f9e2fbd609e9e438b442bfbdb9e4..8981fa00d817e94e5203c24dd04f8ea6185d964a 100644 --- a/tests/test_database_split.py +++ b/tests/test_database_split.py @@ -32,11 +32,11 @@ def test_json_loading(datadir): assert len(database_split["train"]) == 75 for k in database_split["train"]: for f in range(4): - assert isinstance(k[f], (int, float)) + assert isinstance(k[f], int | float) assert isinstance(k[4], str) assert len(database_split["test"]) == 75 for k in database_split["test"]: for f in range(4): - assert isinstance(k[f], (int, float)) + assert isinstance(k[f], int | float) assert isinstance(k[4], str) diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index 8a557278b2abc552cd095eec505fcab146814897..589d8507e309cf5debfbf37b1068b1339ab4cb14 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -41,7 +41,10 @@ def test_run_binary_1(): ] results = run_binary( - "test", predictions, binning=10, threshold_a_priori=0.5 + "test", + predictions, + binning=10, + threshold_a_priori=0.5, ) assert results["num_samples"] == 4 @@ -50,10 +53,12 @@ def test_run_binary_1(): assert numpy.isclose(results["precision"], 1 / 2) # tp / (tp + fp) assert numpy.isclose(results["recall"], 1 / 2) # tp / (tp + fn) assert numpy.isclose( - results["f1_score"], 2 * (1 / 2 * 1 / 2) / (1 / 2 + 1 / 2) + results["f1_score"], + 2 * (1 / 2 * 1 / 2) / (1 / 2 + 1 / 2), ) # 2 * (prec. * recall) / (prec. + recall) assert numpy.isclose( - results["accuracy"], (1 + 1) / (1 + 1 + 1 + 1) + results["accuracy"], + (1 + 1) / (1 + 1 + 1 + 1), ) # (tp + tn) / (tp + fn + tn + fp) assert numpy.isclose(results["specificity"], 1 / 2) # tn / (tn + fp) @@ -94,7 +99,10 @@ def test_run_binary_2(): # a change in the threshold should not affect auc and average precision scores results = run_binary( - "test", predictions, binning=10, threshold_a_priori=0.3 + "test", + predictions, + binning=10, + threshold_a_priori=0.3, ) assert results["num_samples"] == 4 @@ -104,10 +112,12 @@ def test_run_binary_2(): assert numpy.isclose(results["precision"], 2 / 3) # tp / (tp + fp) assert numpy.isclose(results["recall"], 2 / 2) # tp / (tp + fn) assert numpy.isclose( - results["f1_score"], 2 * (2 / 3 * 2 / 2) / (2 / 3 + 2 / 2) + results["f1_score"], + 2 * (2 / 3 * 2 / 2) / (2 / 3 + 2 / 2), ) # 2 * (prec. * recall) / (prec. + recall) assert numpy.isclose( - results["accuracy"], (2 + 1) / (2 + 0 + 1 + 1) + results["accuracy"], + (2 + 1) / (2 + 0 + 1 + 1), ) # (tp + tn) / (tp + fn + tn + fp) assert numpy.isclose(results["specificity"], 1 / (1 + 1)) # tn / (tn + fp) diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py index efc11f40650eda91ec0d47a26f94214a33fc3419..14bdfc37e23ce15d7237f5d5f9517da41e4d0fbf 100644 --- a/tests/test_hivtb.py +++ b/tests/test_hivtb.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -33,7 +32,9 @@ def id_function(val): ids=id_function, # just changes how pytest prints it ) def test_protocol_consistency( - database_checkers, split: str, lenghts: dict[str, int] + database_checkers, + split: str, + lenghts: dict[str, int], ): from mednet.config.data.hivtb.datamodule import make_split @@ -82,7 +83,8 @@ def test_database_check(): ) def test_loading(database_checkers, name: str, dataset: str): datamodule = importlib.import_module( - f".{name}", "mednet.config.data.hivtb" + f".{name}", + "mednet.config.data.hivtb", ).datamodule datamodule.model_transforms = [] # should be done before setup() @@ -108,11 +110,12 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( - datadir / "histograms/raw_data/histograms_hivtb_fold_0.json" + datadir / "histograms/raw_data/histograms_hivtb_fold_0.json", ) datamodule = importlib.import_module( - ".fold_0", "mednet.config.data.hivtb" + ".fold_0", + "mednet.config.data.hivtb", ).datamodule datamodule.model_transforms = [] diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py index 57317c22a108c0b953f348b5adf7a450ff5e88d6..5fee144ca554a749e90f4ca82285deb939d7c8d4 100644 --- a/tests/test_image_utils.py +++ b/tests/test_image_utils.py @@ -5,7 +5,6 @@ import numpy import PIL.Image - from mednet.data.image_utils import remove_black_borders diff --git a/tests/test_indian.py b/tests/test_indian.py index 1a179bd81dd275eb31bbe5f206f71d50b9c4a57b..18bb73ca76edf9309d51d20a9ff8d94229b6806b 100644 --- a/tests/test_indian.py +++ b/tests/test_indian.py @@ -9,7 +9,6 @@ dataset A/dataset B) dataset. import importlib import pytest - from click.testing import CliRunner @@ -37,7 +36,9 @@ def id_function(val): ids=id_function, # just changes how pytest prints it ) def test_protocol_consistency( - database_checkers, split: str, lenghts: dict[str, int] + database_checkers, + split: str, + lenghts: dict[str, int], ): from mednet.config.data.indian.datamodule import make_split @@ -87,7 +88,8 @@ def test_database_check(): ) def test_loading(database_checkers, name: str, dataset: str): datamodule = importlib.import_module( - f".{name}", "mednet.config.data.indian" + f".{name}", + "mednet.config.data.indian", ).datamodule datamodule.model_transforms = [] # should be done before setup() @@ -113,11 +115,12 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.indian") def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( - datadir / "histograms/raw_data/histograms_indian_default.json" + datadir / "histograms/raw_data/histograms_indian_default.json", ) datamodule = importlib.import_module( - ".default", "mednet.config.data.indian" + ".default", + "mednet.config.data.indian", ).datamodule datamodule.model_transforms = [] diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 42b2b9b72f986ca8d8e38e1d19e97bd8a7b89be8..66cc634670909b6629f527f63482bb108991c6e5 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -34,7 +33,9 @@ def id_function(val): ids=id_function, # just changes how pytest prints it ) def test_protocol_consistency( - database_checkers, split: str, lenghts: dict[str, int] + database_checkers, + split: str, + lenghts: dict[str, int], ): from mednet.config.data.montgomery.datamodule import make_split @@ -84,7 +85,8 @@ def test_database_check(): ) def test_loading(database_checkers, name: str, dataset: str): datamodule = importlib.import_module( - f".{name}", "mednet.config.data.montgomery" + f".{name}", + "mednet.config.data.montgomery", ).datamodule datamodule.model_transforms = [] # should be done before setup() @@ -110,11 +112,12 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_raw_transforms_image_quality(database_checkers, datadir): reference_histogram_file = str( - datadir / "histograms/raw_data/histograms_montgomery_default.json" + datadir / "histograms/raw_data/histograms_montgomery_default.json", ) datamodule = importlib.import_module( - ".default", "mednet.config.data.montgomery" + ".default", + "mednet.config.data.montgomery", ).datamodule datamodule.model_transforms = [] @@ -137,20 +140,22 @@ def test_model_transforms_image_quality(database_checkers, datadir, model_name): if model_name == "densenet": reference_histogram_file = str( datadir - / "histograms/models/histograms_densenet-121_montgomery_default.json" + / "histograms/models/histograms_densenet-121_montgomery_default.json", ) else: reference_histogram_file = str( datadir - / f"histograms/models/histograms_{model_name}_montgomery_default.json" + / f"histograms/models/histograms_{model_name}_montgomery_default.json", ) datamodule = importlib.import_module( - ".default", "mednet.config.data.montgomery" + ".default", + "mednet.config.data.montgomery", ).datamodule model = importlib.import_module( - f".{model_name}", "mednet.config.models" + f".{model_name}", + "mednet.config.models", ).model datamodule.model_transforms = model.model_transforms diff --git a/tests/test_montgomery_shenzhen.py b/tests/test_montgomery_shenzhen.py index 8bd7093229d2e104f6e2ccde0f589d04cb5abba1..58bc9806be67bffb70c1a6cd78f18e7f39a11f7d 100644 --- a/tests/test_montgomery_shenzhen.py +++ b/tests/test_montgomery_shenzhen.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -28,33 +27,38 @@ from click.testing import CliRunner ) def test_split_consistency(name: str): montgomery = importlib.import_module( - f".{name}", "mednet.config.data.montgomery" + f".{name}", + "mednet.config.data.montgomery", ).datamodule shenzhen = importlib.import_module( - f".{name}", "mednet.config.data.shenzhen" + f".{name}", + "mednet.config.data.shenzhen", ).datamodule combined = importlib.import_module( - f".{name}", "mednet.config.data.montgomery_shenzhen" + f".{name}", + "mednet.config.data.montgomery_shenzhen", ).datamodule - MontgomeryLoader = importlib.import_module( - ".datamodule", "mednet.config.data.montgomery" + montgomery_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.montgomery", ).RawDataLoader - ShenzhenLoader = importlib.import_module( - ".datamodule", "mednet.config.data.shenzhen" + shenzhen_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.shenzhen", ).RawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] - assert isinstance(montgomery.splits[split][0][1], MontgomeryLoader) - assert isinstance(combined.splits[split][0][1], MontgomeryLoader) + assert isinstance(montgomery.splits[split][0][1], montgomery_loader) + assert isinstance(combined.splits[split][0][1], montgomery_loader) assert shenzhen.splits[split][0][0] == combined.splits[split][1][0] - assert isinstance(shenzhen.splits[split][0][1], ShenzhenLoader) - assert isinstance(combined.splits[split][1][1], ShenzhenLoader) + assert isinstance(shenzhen.splits[split][0][1], shenzhen_loader) + assert isinstance(combined.splits[split][1][1], shenzhen_loader) @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") diff --git a/tests/test_montgomery_shenzhen_indian.py b/tests/test_montgomery_shenzhen_indian.py index 7534f2995226e913222a414a7216d40bc83659b2..470ee2cfaded3164a29ee828130f5d79829f6b91 100644 --- a/tests/test_montgomery_shenzhen_indian.py +++ b/tests/test_montgomery_shenzhen_indian.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -28,45 +27,52 @@ from click.testing import CliRunner ) def test_split_consistency(name: str): montgomery = importlib.import_module( - f".{name}", "mednet.config.data.montgomery" + f".{name}", + "mednet.config.data.montgomery", ).datamodule shenzhen = importlib.import_module( - f".{name}", "mednet.config.data.shenzhen" + f".{name}", + "mednet.config.data.shenzhen", ).datamodule indian = importlib.import_module( - f".{name}", "mednet.config.data.indian" + f".{name}", + "mednet.config.data.indian", ).datamodule combined = importlib.import_module( - f".{name}", "mednet.config.data.montgomery_shenzhen_indian" + f".{name}", + "mednet.config.data.montgomery_shenzhen_indian", ).datamodule - MontgomeryLoader = importlib.import_module( - ".datamodule", "mednet.config.data.montgomery" + montgomery_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.montgomery", ).RawDataLoader - ShenzhenLoader = importlib.import_module( - ".datamodule", "mednet.config.data.shenzhen" + shenzhen_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.shenzhen", ).RawDataLoader - IndianLoader = importlib.import_module( - ".datamodule", "mednet.config.data.indian" + indian_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.indian", ).RawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] - assert isinstance(montgomery.splits[split][0][1], MontgomeryLoader) - assert isinstance(combined.splits[split][0][1], MontgomeryLoader) + assert isinstance(montgomery.splits[split][0][1], montgomery_loader) + assert isinstance(combined.splits[split][0][1], montgomery_loader) assert shenzhen.splits[split][0][0] == combined.splits[split][1][0] - assert isinstance(shenzhen.splits[split][0][1], ShenzhenLoader) - assert isinstance(combined.splits[split][1][1], ShenzhenLoader) + assert isinstance(shenzhen.splits[split][0][1], shenzhen_loader) + assert isinstance(combined.splits[split][1][1], shenzhen_loader) assert indian.splits[split][0][0] == combined.splits[split][2][0] - assert isinstance(indian.splits[split][0][1], IndianLoader) - assert isinstance(combined.splits[split][2][1], IndianLoader) + assert isinstance(indian.splits[split][0][1], indian_loader) + assert isinstance(combined.splits[split][2][1], indian_loader) @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") diff --git a/tests/test_montgomery_shenzhen_indian_padchest.py b/tests/test_montgomery_shenzhen_indian_padchest.py index b70d9c2532ab31954dd0d7f9d9b5f701a5c582ce..d56798a794b9d82fcdac5d14563a43f616e288cd 100644 --- a/tests/test_montgomery_shenzhen_indian_padchest.py +++ b/tests/test_montgomery_shenzhen_indian_padchest.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -18,59 +17,68 @@ from click.testing import CliRunner ) def test_split_consistency(name: str, padchest_name: str): montgomery = importlib.import_module( - f".{name}", "mednet.config.data.montgomery" + f".{name}", + "mednet.config.data.montgomery", ).datamodule shenzhen = importlib.import_module( - f".{name}", "mednet.config.data.shenzhen" + f".{name}", + "mednet.config.data.shenzhen", ).datamodule indian = importlib.import_module( - f".{name}", "mednet.config.data.indian" + f".{name}", + "mednet.config.data.indian", ).datamodule padchest = importlib.import_module( - f".{padchest_name}", "mednet.config.data.padchest" + f".{padchest_name}", + "mednet.config.data.padchest", ).datamodule combined = importlib.import_module( - f".{name}", "mednet.config.data.montgomery_shenzhen_indian_padchest" + f".{name}", + "mednet.config.data.montgomery_shenzhen_indian_padchest", ).datamodule - MontgomeryLoader = importlib.import_module( - ".datamodule", "mednet.config.data.montgomery" + montgomery_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.montgomery", ).RawDataLoader - ShenzhenLoader = importlib.import_module( - ".datamodule", "mednet.config.data.shenzhen" + shenzhen_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.shenzhen", ).RawDataLoader - IndianLoader = importlib.import_module( - ".datamodule", "mednet.config.data.indian" + indian_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.indian", ).RawDataLoader - PadChestLoader = importlib.import_module( - ".datamodule", "mednet.config.data.padchest" + padchest_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.padchest", ).RawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] - assert isinstance(montgomery.splits[split][0][1], MontgomeryLoader) - assert isinstance(combined.splits[split][0][1], MontgomeryLoader) + assert isinstance(montgomery.splits[split][0][1], montgomery_loader) + assert isinstance(combined.splits[split][0][1], montgomery_loader) assert shenzhen.splits[split][0][0] == combined.splits[split][1][0] - assert isinstance(shenzhen.splits[split][0][1], ShenzhenLoader) - assert isinstance(combined.splits[split][1][1], ShenzhenLoader) + assert isinstance(shenzhen.splits[split][0][1], shenzhen_loader) + assert isinstance(combined.splits[split][1][1], shenzhen_loader) assert indian.splits[split][0][0] == combined.splits[split][2][0] - assert isinstance(indian.splits[split][0][1], IndianLoader) - assert isinstance(combined.splits[split][2][1], IndianLoader) + assert isinstance(indian.splits[split][0][1], indian_loader) + assert isinstance(combined.splits[split][2][1], indian_loader) if split != "validation": # padchest has no validation assert padchest.splits[split][0][0] == combined.splits[split][3][0] - assert isinstance(padchest.splits[split][0][1], PadChestLoader) - assert isinstance(combined.splits[split][3][1], PadChestLoader) + assert isinstance(padchest.splits[split][0][1], padchest_loader) + assert isinstance(combined.splits[split][3][1], padchest_loader) @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") diff --git a/tests/test_montgomery_shenzhen_indian_tbx11k.py b/tests/test_montgomery_shenzhen_indian_tbx11k.py index e8f167f13c21057b75129c2f0506d6072e617cf1..0a4a3bffc15ca8c7fee5dd8a93b1c53c0f3dcfd0 100644 --- a/tests/test_montgomery_shenzhen_indian_tbx11k.py +++ b/tests/test_montgomery_shenzhen_indian_tbx11k.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -39,19 +38,23 @@ from click.testing import CliRunner ) def test_split_consistency(name: str, tbx11k_name: str): montgomery = importlib.import_module( - f".{name}", "mednet.config.data.montgomery" + f".{name}", + "mednet.config.data.montgomery", ).datamodule shenzhen = importlib.import_module( - f".{name}", "mednet.config.data.shenzhen" + f".{name}", + "mednet.config.data.shenzhen", ).datamodule indian = importlib.import_module( - f".{name}", "mednet.config.data.indian" + f".{name}", + "mednet.config.data.indian", ).datamodule tbx11k = importlib.import_module( - f".{tbx11k_name}", "mednet.config.data.tbx11k" + f".{tbx11k_name}", + "mednet.config.data.tbx11k", ).datamodule combined = importlib.import_module( @@ -59,38 +62,42 @@ def test_split_consistency(name: str, tbx11k_name: str): "mednet.config.data.montgomery_shenzhen_indian_tbx11k", ).datamodule - MontgomeryLoader = importlib.import_module( - ".datamodule", "mednet.config.data.montgomery" + montgomery_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.montgomery", ).RawDataLoader - ShenzhenLoader = importlib.import_module( - ".datamodule", "mednet.config.data.shenzhen" + shenzhen_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.shenzhen", ).RawDataLoader - IndianLoader = importlib.import_module( - ".datamodule", "mednet.config.data.indian" + indian_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.indian", ).RawDataLoader - TBX11kLoader = importlib.import_module( - ".datamodule", "mednet.config.data.tbx11k" + tbx11k_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.tbx11k", ).RawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] - assert isinstance(montgomery.splits[split][0][1], MontgomeryLoader) - assert isinstance(combined.splits[split][0][1], MontgomeryLoader) + assert isinstance(montgomery.splits[split][0][1], montgomery_loader) + assert isinstance(combined.splits[split][0][1], montgomery_loader) assert shenzhen.splits[split][0][0] == combined.splits[split][1][0] - assert isinstance(shenzhen.splits[split][0][1], ShenzhenLoader) - assert isinstance(combined.splits[split][1][1], ShenzhenLoader) + assert isinstance(shenzhen.splits[split][0][1], shenzhen_loader) + assert isinstance(combined.splits[split][1][1], shenzhen_loader) assert indian.splits[split][0][0] == combined.splits[split][2][0] - assert isinstance(indian.splits[split][0][1], IndianLoader) - assert isinstance(combined.splits[split][2][1], IndianLoader) + assert isinstance(indian.splits[split][0][1], indian_loader) + assert isinstance(combined.splits[split][2][1], indian_loader) assert tbx11k.splits[split][0][0] == combined.splits[split][3][0] - assert isinstance(tbx11k.splits[split][0][1], TBX11kLoader) - assert isinstance(combined.splits[split][3][1], TBX11kLoader) + assert isinstance(tbx11k.splits[split][0][1], tbx11k_loader) + assert isinstance(combined.splits[split][3][1], tbx11k_loader) @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py index 8a16fec9ff19141b35e56e4d3b0e10c6b32e806e..61c21ea7122a7dfeb2918ba82b1e5ef3b3056b0b 100644 --- a/tests/test_nih_cxr14.py +++ b/tests/test_nih_cxr14.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -25,7 +24,9 @@ def id_function(val): ids=id_function, # just changes how pytest prints it ) def test_protocol_consistency( - database_checkers, split: str, lenghts: dict[str, int] + database_checkers, + split: str, + lenghts: dict[str, int], ): from mednet.config.data.nih_cxr14.datamodule import make_split @@ -61,7 +62,8 @@ def test_database_check(): @pytest.mark.parametrize("name,dataset,num_labels", testdata) def test_loading(database_checkers, name: str, dataset: str, num_labels: int): datamodule = importlib.import_module( - f".{name}", "mednet.config.data.nih_cxr14" + f".{name}", + "mednet.config.data.nih_cxr14", ).datamodule datamodule.model_transforms = [] # should be done before setup() @@ -88,11 +90,12 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( - datadir / "histograms/raw_data/histograms_nih_cxr14_default.json" + datadir / "histograms/raw_data/histograms_nih_cxr14_default.json", ) datamodule = importlib.import_module( - ".default", "mednet.config.data.nih_cxr14" + ".default", + "mednet.config.data.nih_cxr14", ).datamodule datamodule.model_transforms = [] diff --git a/tests/test_nih_cxr14_padchest.py b/tests/test_nih_cxr14_padchest.py index d98e096d4ed74fd3906137e2f1374e84ac1e9248..a36655bbf451c649a25591d3665d4403d63675b2 100644 --- a/tests/test_nih_cxr14_padchest.py +++ b/tests/test_nih_cxr14_padchest.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -18,35 +17,40 @@ from click.testing import CliRunner ) def test_split_consistency(name: str, padchest_name: str, combined_name: str): nih_cxr14 = importlib.import_module( - f".{name}", "mednet.config.data.nih_cxr14" + f".{name}", + "mednet.config.data.nih_cxr14", ).datamodule padchest = importlib.import_module( - f".{padchest_name}", "mednet.config.data.padchest" + f".{padchest_name}", + "mednet.config.data.padchest", ).datamodule combined = importlib.import_module( - f".{combined_name}", "mednet.config.data.nih_cxr14_padchest" + f".{combined_name}", + "mednet.config.data.nih_cxr14_padchest", ).datamodule - CXR14Loader = importlib.import_module( - ".datamodule", "mednet.config.data.nih_cxr14" + cxr14_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.nih_cxr14", ).RawDataLoader - PadChestLoader = importlib.import_module( - ".datamodule", "mednet.config.data.padchest" + padchest_loader = importlib.import_module( + ".datamodule", + "mednet.config.data.padchest", ).RawDataLoader for split in ("train", "validation", "test"): assert nih_cxr14.splits[split][0][0] == combined.splits[split][0][0] - assert isinstance(nih_cxr14.splits[split][0][1], CXR14Loader) - assert isinstance(combined.splits[split][0][1], CXR14Loader) + assert isinstance(nih_cxr14.splits[split][0][1], cxr14_loader) + assert isinstance(combined.splits[split][0][1], cxr14_loader) if split != "test": # padchest has no test assert padchest.splits[split][0][0] == combined.splits[split][1][0] - assert isinstance(padchest.splits[split][0][1], PadChestLoader) - assert isinstance(combined.splits[split][1][1], PadChestLoader) + assert isinstance(padchest.splits[split][0][1], padchest_loader) + assert isinstance(combined.splits[split][1][1], padchest_loader) @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") diff --git a/tests/test_padchest.py b/tests/test_padchest.py index 262b32c0291daf926dc6940750b8e2d905bea464..1687872a350a115ecb0a956c60cb192d323ccbce 100644 --- a/tests/test_padchest.py +++ b/tests/test_padchest.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -30,7 +29,9 @@ def id_function(val): ids=id_function, # just changes how pytest prints it ) def test_protocol_consistency( - database_checkers, split: str, lenghts: dict[str, int] + database_checkers, + split: str, + lenghts: dict[str, int], ): from mednet.config.data.padchest.datamodule import make_split @@ -66,7 +67,8 @@ testdata = [ @pytest.mark.parametrize("name,dataset,num_labels", testdata) def test_loading(database_checkers, name: str, dataset: str, num_labels: int): datamodule = importlib.import_module( - f".{name}", "mednet.config.data.padchest" + f".{name}", + "mednet.config.data.padchest", ).datamodule datamodule.model_transforms = [] # should be done before setup() @@ -93,11 +95,12 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): @pytest.mark.skip_if_rc_var_not_set("datadir.padchest") def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( - datadir / "histograms/raw_data/histograms_padchest_idiap.json" + datadir / "histograms/raw_data/histograms_padchest_idiap.json", ) datamodule = importlib.import_module( - ".idiap", "mednet.config.data.padchest" + ".idiap", + "mednet.config.data.padchest", ).datamodule datamodule.model_transforms = [] diff --git a/tests/test_saliencymap_interpretability.py b/tests/test_saliencymap_interpretability.py index b4abd8dc6ebb0de45cf37228631db291e1f12d91..f2cc19d71898ed044c7c77ea8b609892bbebb738 100644 --- a/tests/test_saliencymap_interpretability.py +++ b/tests/test_saliencymap_interpretability.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import numpy as np - from mednet.config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes from mednet.engine.saliency.interpretability import ( _compute_avg_saliency_focus, @@ -90,13 +89,16 @@ def test_compute_avg_saliency_focus(): binary_mask3 = _compute_binary_mask(gt_boxes, grayscale_cams3) avg_saliency_focus = _compute_avg_saliency_focus( - grayscale_cams, binary_mask + grayscale_cams, + binary_mask, ) avg_saliency_focus2 = _compute_avg_saliency_focus( - grayscale_cams2, binary_mask2 + grayscale_cams2, + binary_mask2, ) avg_saliency_focus3 = _compute_avg_saliency_focus( - grayscale_cams3, binary_mask3 + grayscale_cams3, + binary_mask3, ) assert avg_saliency_focus == 1 @@ -111,7 +113,8 @@ def test_compute_avg_saliency_focus_no_activations(): binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) avg_saliency_focus = _compute_avg_saliency_focus( - grayscale_cams, binary_mask + grayscale_cams, + binary_mask, ) assert avg_saliency_focus == 0 @@ -124,7 +127,8 @@ def test_compute_avg_saliency_focus_zero_gt_area(): binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) avg_saliency_focus = _compute_avg_saliency_focus( - grayscale_cams, binary_mask + grayscale_cams, + binary_mask, ) assert avg_saliency_focus == 0 @@ -143,13 +147,16 @@ def test_compute_proportional_energy(): binary_mask3 = _compute_binary_mask(gt_boxes, grayscale_cams3) proportional_energy = _compute_proportional_energy( - grayscale_cams, binary_mask + grayscale_cams, + binary_mask, ) proportional_energy2 = _compute_proportional_energy( - grayscale_cams2, binary_mask2 + grayscale_cams2, + binary_mask2, ) proportional_energy3 = _compute_proportional_energy( - grayscale_cams3, binary_mask3 + grayscale_cams3, + binary_mask3, ) assert proportional_energy == 0.25 @@ -164,7 +171,8 @@ def test_compute_proportional_energy_no_activations(): binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) proportional_energy = _compute_proportional_energy( - grayscale_cams, binary_mask + grayscale_cams, + binary_mask, ) assert proportional_energy == 0 @@ -177,7 +185,8 @@ def test_compute_proportional_energy_no_gt_box(): binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) proportional_energy = _compute_proportional_energy( - grayscale_cams, binary_mask + grayscale_cams, + binary_mask, ) assert proportional_energy == 0 @@ -189,7 +198,8 @@ def test_process_sample(): gt_boxes = BoundingBoxes([gt_box_dict]) proportional_energy, avg_saliency_focus = _process_sample( - gt_boxes, grayscale_cams + gt_boxes, + grayscale_cams, ) assert proportional_energy == 0 diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py index 3c5fc66122483c2e0c3ac2b9f258ed70d17120a2..4b006c05110d7d060738aa881159ed5b6e73814e 100644 --- a/tests/test_shenzhen.py +++ b/tests/test_shenzhen.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -34,7 +33,9 @@ def id_function(val): ids=id_function, # just changes how pytest prints it ) def test_protocol_consistency( - database_checkers, split: str, lenghts: dict[str, int] + database_checkers, + split: str, + lenghts: dict[str, int], ): from mednet.config.data.shenzhen.datamodule import make_split @@ -84,7 +85,8 @@ def test_database_check(): ) def test_loading(database_checkers, name: str, dataset: str): datamodule = importlib.import_module( - f".{name}", "mednet.config.data.shenzhen" + f".{name}", + "mednet.config.data.shenzhen", ).datamodule datamodule.model_transforms = [] # should be done before setup() @@ -110,11 +112,12 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( - datadir / "histograms/raw_data/histograms_shenzhen_default.json" + datadir / "histograms/raw_data/histograms_shenzhen_default.json", ) datamodule = importlib.import_module( - ".default", "mednet.config.data.shenzhen" + ".default", + "mednet.config.data.shenzhen", ).datamodule datamodule.model_transforms = [] diff --git a/tests/test_summary.py b/tests/test_summary.py deleted file mode 100644 index 09b5398abda79804ef2fcbf6b1b191d9310dcb29..0000000000000000000000000000000000000000 --- a/tests/test_summary.py +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import unittest - -import mednet.config.models.pasa as pasa_config - -from mednet.utils.summary import summary - - -class Tester(unittest.TestCase): - """Unit test for model architectures.""" - - def test_summary_driu(self): - model = pasa_config.model - s, param = summary(model) - self.assertIsInstance(s, str) - self.assertIsInstance(param, int) diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py index c5d37bbb6899bdf6b0d0dd8bc03190436c2ce184..c9a0184a6e139d226e800cd0fe30586d59059cd9 100644 --- a/tests/test_tbpoc.py +++ b/tests/test_tbpoc.py @@ -6,7 +6,6 @@ import importlib import pytest - from click.testing import CliRunner @@ -33,7 +32,9 @@ def id_function(val): ids=id_function, # just changes how pytest prints it ) def test_protocol_consistency( - database_checkers, split: str, lenghts: dict[str, int] + database_checkers, + split: str, + lenghts: dict[str, int], ): from mednet.config.data.tbpoc.datamodule import make_split @@ -85,7 +86,8 @@ def test_database_check(): ) def test_loading(database_checkers, name: str, dataset: str): datamodule = importlib.import_module( - f".{name}", "mednet.config.data.tbpoc" + f".{name}", + "mednet.config.data.tbpoc", ).datamodule datamodule.model_transforms = [] # should be done before setup() @@ -114,11 +116,12 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( - datadir / "histograms/raw_data/histograms_tbpoc_fold_0.json" + datadir / "histograms/raw_data/histograms_tbpoc_fold_0.json", ) datamodule = importlib.import_module( - ".fold_0", "mednet.config.data.tbpoc" + ".fold_0", + "mednet.config.data.tbpoc", ).datamodule datamodule.model_transforms = [] diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index 96bc79e4994d7f26c98252255fb415f9709581b1..c7304fc67060eb8086df8c56d896d44de92f5322 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -8,12 +8,11 @@ import typing import pytest import torch - from click.testing import CliRunner def id_function(val): - if isinstance(val, (dict, tuple)): + if isinstance(val, dict | tuple): return repr(val) return repr(val) @@ -157,7 +156,7 @@ def check_loaded_batch( prefixes: typing.Sequence[str], possible_labels: typing.Sequence[int], expected_num_labels: int, - expected_image_shape: typing.Optional[tuple[int, ...]] = None, + expected_image_shape: tuple[int, ...] | None = None, ): """Check the consistency of an individual (loaded) batch. @@ -203,13 +202,15 @@ def check_loaded_batch( assert "name" in batch[1] assert all( - [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] + [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]], ) assert "bounding_boxes" in batch[1] for sample, label, bboxes in zip( - batch[0], batch[1]["label"], batch[1]["bounding_boxes"] + batch[0], + batch[1]["label"], + batch[1]["bounding_boxes"], ): # there must be a sign indicated on the image, if active TB is detected if label == 1: @@ -287,7 +288,8 @@ def test_database_check(): ) def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): datamodule = importlib.import_module( - f".{name}", "mednet.config.data.tbx11k" + f".{name}", + "mednet.config.data.tbx11k", ).datamodule datamodule.model_transforms = [] # should be done before setup() @@ -321,11 +323,12 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") def test_loaded_image_quality(database_checkers, datadir, split): reference_histogram_file = str( - datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json" + datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json", ) datamodule = importlib.import_module( - f".{split}", "mednet.config.data.tbx11k" + f".{split}", + "mednet.config.data.tbx11k", ).datamodule datamodule.model_transforms = [] diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 651c4123cbafefa4837411531c2584e74f0ea077..2d8faf1e659ca3fe3e536337006620cc2b88de9c 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -5,8 +5,7 @@ import numpy import PIL.Image -import torchvision.transforms.functional as F - +import torchvision.transforms.functional as F # noqa: N812 from mednet.data.augmentations import ElasticDeformation @@ -26,7 +25,9 @@ def test_elastic_deformation(datadir): # Compare both raw_deformed = (255 * numpy.asarray(raw_deformed)).astype(numpy.uint8)[ - 0, :, : + 0, + :, + :, ] raw_2 = numpy.asarray(raw_2)