From 901c950a752fb8d953b445c2a34cba04202fedb1 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 27 Feb 2024 14:47:51 +0100 Subject: [PATCH] [common] Move reusable files to new common package --- src/medbase/__init__.py | 0 src/{mednet => medbase}/data/__init__.py | 0 src/{mednet => medbase}/data/augmentations.py | 0 src/{mednet => medbase}/data/datamodule.py | 0 src/{mednet => medbase}/data/image_utils.py | 0 src/{mednet => medbase}/data/split.py | 2 +- src/{mednet => medbase}/data/typing.py | 0 src/medbase/engine/__init__.py | 0 src/{mednet => medbase}/engine/callbacks.py | 0 src/{mednet => medbase}/engine/device.py | 0 src/{mednet => medbase}/engine/loggers.py | 0 src/{mednet => medbase}/engine/trainer.py | 0 src/medbase/utils/__init__.py | 0 src/{mednet => medbase}/utils/checkpointer.py | 0 src/{mednet => medbase}/utils/resources.py | 2 +- src/medbase/utils/summary.py | 61 +++++++++++++++++++ src/{mednet => medbase}/utils/tensorboard.py | 0 src/mednet/config/data/hivtb/datamodule.py | 11 ++-- src/mednet/config/data/indian/datamodule.py | 5 +- .../config/data/montgomery/datamodule.py | 11 ++-- .../data/montgomery_shenzhen/datamodule.py | 5 +- .../montgomery_shenzhen_indian/datamodule.py | 5 +- .../datamodule.py | 5 +- .../datamodule.py | 5 +- .../config/data/nih_cxr14/datamodule.py | 9 +-- .../data/nih_cxr14_padchest/datamodule.py | 5 +- src/mednet/config/data/padchest/datamodule.py | 11 ++-- src/mednet/config/data/shenzhen/datamodule.py | 11 ++-- src/mednet/config/data/tbpoc/datamodule.py | 11 ++-- src/mednet/config/data/tbx11k/datamodule.py | 9 +-- src/mednet/engine/predictor.py | 3 +- src/mednet/engine/saliency/completeness.py | 5 +- src/mednet/engine/saliency/generator.py | 3 +- src/mednet/models/alexnet.py | 2 +- src/mednet/models/densenet.py | 2 +- src/mednet/models/loss_weights.py | 2 + src/mednet/models/model.py | 2 +- src/mednet/models/pasa.py | 2 +- src/mednet/models/separate.py | 3 +- src/mednet/scripts/predict.py | 6 +- src/mednet/scripts/saliency/completeness.py | 5 +- src/mednet/scripts/saliency/generate.py | 6 +- src/mednet/scripts/train.py | 7 ++- src/mednet/scripts/train_analysis.py | 2 +- tests/conftest.py | 5 +- tests/test_cli.py | 14 ++--- tests/test_database_split.py | 3 +- tests/test_image_utils.py | 3 +- tests/test_summary.py | 19 ++++++ tests/test_transforms.py | 5 +- 50 files changed, 186 insertions(+), 81 deletions(-) create mode 100644 src/medbase/__init__.py rename src/{mednet => medbase}/data/__init__.py (100%) rename src/{mednet => medbase}/data/augmentations.py (100%) rename src/{mednet => medbase}/data/datamodule.py (100%) rename src/{mednet => medbase}/data/image_utils.py (100%) rename src/{mednet => medbase}/data/split.py (98%) rename src/{mednet => medbase}/data/typing.py (100%) create mode 100644 src/medbase/engine/__init__.py rename src/{mednet => medbase}/engine/callbacks.py (100%) rename src/{mednet => medbase}/engine/device.py (100%) rename src/{mednet => medbase}/engine/loggers.py (100%) rename src/{mednet => medbase}/engine/trainer.py (100%) create mode 100644 src/medbase/utils/__init__.py rename src/{mednet => medbase}/utils/checkpointer.py (100%) rename src/{mednet => medbase}/utils/resources.py (99%) create mode 100644 src/medbase/utils/summary.py rename src/{mednet => medbase}/utils/tensorboard.py (100%) create mode 100644 tests/test_summary.py diff --git a/src/medbase/__init__.py b/src/medbase/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mednet/data/__init__.py b/src/medbase/data/__init__.py similarity index 100% rename from src/mednet/data/__init__.py rename to src/medbase/data/__init__.py diff --git a/src/mednet/data/augmentations.py b/src/medbase/data/augmentations.py similarity index 100% rename from src/mednet/data/augmentations.py rename to src/medbase/data/augmentations.py diff --git a/src/mednet/data/datamodule.py b/src/medbase/data/datamodule.py similarity index 100% rename from src/mednet/data/datamodule.py rename to src/medbase/data/datamodule.py diff --git a/src/mednet/data/image_utils.py b/src/medbase/data/image_utils.py similarity index 100% rename from src/mednet/data/image_utils.py rename to src/medbase/data/image_utils.py diff --git a/src/mednet/data/split.py b/src/medbase/data/split.py similarity index 98% rename from src/mednet/data/split.py rename to src/medbase/data/split.py index 3bf8b8ca..0bcdf300 100644 --- a/src/mednet/data/split.py +++ b/src/medbase/data/split.py @@ -12,7 +12,7 @@ import typing import torch -from .typing import DatabaseSplit, RawDataLoader +from medbase.data.typing import DatabaseSplit, RawDataLoader logger = logging.getLogger(__name__) diff --git a/src/mednet/data/typing.py b/src/medbase/data/typing.py similarity index 100% rename from src/mednet/data/typing.py rename to src/medbase/data/typing.py diff --git a/src/medbase/engine/__init__.py b/src/medbase/engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mednet/engine/callbacks.py b/src/medbase/engine/callbacks.py similarity index 100% rename from src/mednet/engine/callbacks.py rename to src/medbase/engine/callbacks.py diff --git a/src/mednet/engine/device.py b/src/medbase/engine/device.py similarity index 100% rename from src/mednet/engine/device.py rename to src/medbase/engine/device.py diff --git a/src/mednet/engine/loggers.py b/src/medbase/engine/loggers.py similarity index 100% rename from src/mednet/engine/loggers.py rename to src/medbase/engine/loggers.py diff --git a/src/mednet/engine/trainer.py b/src/medbase/engine/trainer.py similarity index 100% rename from src/mednet/engine/trainer.py rename to src/medbase/engine/trainer.py diff --git a/src/medbase/utils/__init__.py b/src/medbase/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mednet/utils/checkpointer.py b/src/medbase/utils/checkpointer.py similarity index 100% rename from src/mednet/utils/checkpointer.py rename to src/medbase/utils/checkpointer.py diff --git a/src/mednet/utils/resources.py b/src/medbase/utils/resources.py similarity index 99% rename from src/mednet/utils/resources.py rename to src/medbase/utils/resources.py index 5b8844a1..58167711 100644 --- a/src/mednet/utils/resources.py +++ b/src/medbase/utils/resources.py @@ -18,7 +18,7 @@ import warnings import numpy import psutil -from ..engine.device import SupportedPytorchDevice +from medbase.engine.device import SupportedPytorchDevice logger = logging.getLogger(__name__) diff --git a/src/medbase/utils/summary.py b/src/medbase/utils/summary.py new file mode 100644 index 00000000..bff705e3 --- /dev/null +++ b/src/medbase/utils/summary.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488 + +from functools import reduce + +import torch + +from torch.nn.modules.module import _addindent + + +# ignore this space! +def _repr(model: torch.nn.Module) -> tuple[str, int]: + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = model.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + total_params = 0 + for key, module in model._modules.items(): + mod_str, num_params = _repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + total_params += num_params + lines = extra_lines + child_lines + + for _, p in model._parameters.items(): + if hasattr(p, "dtype"): + total_params += reduce(lambda x, y: x * y, p.shape) + + main_str = model._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + main_str += f", {total_params:,} params" + return main_str, total_params + + +def summary(model: torch.nn.Module) -> tuple[str, int]: + """Count the number of parameters in each model layer. + + Parameters + ---------- + model + Model to summarize. + + Returns + ------- + tuple[int, str] + A tuple containing a multiline string representation of the network and the number of parameters. + """ + return _repr(model) diff --git a/src/mednet/utils/tensorboard.py b/src/medbase/utils/tensorboard.py similarity index 100% rename from src/mednet/utils/tensorboard.py rename to src/medbase/utils/tensorboard.py diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py index 6c674085..21f773b2 100644 --- a/src/mednet/config/data/hivtb/datamodule.py +++ b/src/mednet/config/data/hivtb/datamodule.py @@ -12,11 +12,12 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader +from medbase.data.typing import Sample + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/indian/datamodule.py b/src/mednet/config/data/indian/datamodule.py index fa285170..f6825bba 100644 --- a/src/mednet/config/data/indian/datamodule.py +++ b/src/mednet/config/data/indian/datamodule.py @@ -8,9 +8,10 @@ Database reference: [INDIAN-2013]_ import pathlib +from medbase.data.datamodule import CachingDataModule +from medbase.data.split import make_split + from ....config.data.shenzhen.datamodule import RawDataLoader -from ....data.datamodule import CachingDataModule -from ....data.split import make_split CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) """Key to search for in the configuration file for the root directory of this diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py index 98d591ae..64b2f7ff 100644 --- a/src/mednet/config/data/montgomery/datamodule.py +++ b/src/mednet/config/data/montgomery/datamodule.py @@ -12,11 +12,12 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/config/data/montgomery_shenzhen/datamodule.py index f726a431..16dfec25 100644 --- a/src/mednet/config/data/montgomery_shenzhen/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen/datamodule.py @@ -5,8 +5,9 @@ import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py index f75eb1e5..bfeb1a33 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py @@ -7,8 +7,9 @@ databases. import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import RawDataLoader as IndianLoader diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py index 060f6231..268ba2a2 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py @@ -7,8 +7,9 @@ datasets. import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import RawDataLoader as IndianLoader diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index db2f7901..932e6571 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -7,8 +7,9 @@ datasets. import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import RawDataLoader as IndianLoader diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py index 6d6b283e..8e72d6d4 100644 --- a/src/mednet/config/data/nih_cxr14/datamodule.py +++ b/src/mednet/config/data/nih_cxr14/datamodule.py @@ -12,10 +12,11 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py index bfd17daf..3de000c1 100644 --- a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py +++ b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py @@ -5,8 +5,9 @@ import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader from ..padchest.datamodule import RawDataLoader as PadchestLoader diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py index 97c80a13..99545ee2 100644 --- a/src/mednet/config/data/padchest/datamodule.py +++ b/src/mednet/config/data/padchest/datamodule.py @@ -13,11 +13,12 @@ import numpy import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py index c2838ce2..ae1b4639 100644 --- a/src/mednet/config/data/shenzhen/datamodule.py +++ b/src/mednet/config/data/shenzhen/datamodule.py @@ -12,11 +12,12 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py index 3e7dcaa2..5b4b6011 100644 --- a/src/mednet/config/data/tbpoc/datamodule.py +++ b/src/mednet/config/data/tbpoc/datamodule.py @@ -8,11 +8,12 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py index 993f9e50..c6b28ef9 100644 --- a/src/mednet/config/data/tbx11k/datamodule.py +++ b/src/mednet/config/data/tbx11k/datamodule.py @@ -13,10 +13,11 @@ import typing_extensions from torch.utils.data._utils.collate import default_collate_fn_map from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.split import make_split +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader +from medbase.data.typing import Sample + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/engine/predictor.py b/src/mednet/engine/predictor.py index 1ba6f05a..fce844a1 100644 --- a/src/mednet/engine/predictor.py +++ b/src/mednet/engine/predictor.py @@ -7,13 +7,14 @@ import logging import lightning.pytorch import torch.utils.data +from medbase.engine.device import DeviceManager + from ..models.typing import ( BinaryPrediction, BinaryPredictionSplit, MultiClassPrediction, MultiClassPredictionSplit, ) -from .device import DeviceManager logger = logging.getLogger(__name__) diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py index eb2832bc..7c3cb3c6 100644 --- a/src/mednet/engine/saliency/completeness.py +++ b/src/mednet/engine/saliency/completeness.py @@ -17,9 +17,10 @@ from pytorch_grad_cam.metrics.road import ( ) from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget -from ...data.typing import Sample +from medbase.data.typing import Sample +from medbase.engine.device import DeviceManager + from ...models.typing import SaliencyMapAlgorithm -from ..device import DeviceManager logger = logging.getLogger(__name__) diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index e1c03af7..f440eb08 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -12,8 +12,9 @@ import torch import torch.nn import tqdm +from medbase.engine.device import DeviceManager + from ...models.typing import SaliencyMapAlgorithm -from ..device import DeviceManager logger = logging.getLogger(__name__) diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index 3e58463e..22c41547 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -12,7 +12,7 @@ import torch.utils.data import torchvision.models as models import torchvision.transforms -from ..data.typing import TransformSequence +from medbase.data.typing import TransformSequence from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index c54d5d9a..3feb2ade 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -12,7 +12,7 @@ import torch.utils.data import torchvision.models as models import torchvision.transforms -from ..data.typing import TransformSequence +from medbase.data.typing import TransformSequence from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py index 2cf5e292..382a7ae5 100644 --- a/src/mednet/models/loss_weights.py +++ b/src/mednet/models/loss_weights.py @@ -9,6 +9,8 @@ from collections import Counter import torch import torch.utils.data +from medbase.data.typing import DataLoader + logger = logging.getLogger(__name__) diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py index 560b1115..e1b264be 100644 --- a/src/mednet/models/model.py +++ b/src/mednet/models/model.py @@ -12,7 +12,7 @@ import torch.optim.optimizer import torch.utils.data import torchvision.transforms -from ..data.typing import TransformSequence +from medbase.data.typing import TransformSequence from .loss_weights import get_positive_weights from .typing import Checkpoint diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index cb7ebfea..0e1d24f1 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -12,7 +12,7 @@ import torch.optim.optimizer import torch.utils.data import torchvision.transforms -from ..data.typing import TransformSequence +from medbase.data.typing import TransformSequence from .model import Model from .separate import separate from .transforms import Grayscale, SquareCenterPad diff --git a/src/mednet/models/separate.py b/src/mednet/models/separate.py index c79820bf..6c7afabb 100644 --- a/src/mednet/models/separate.py +++ b/src/mednet/models/separate.py @@ -7,7 +7,8 @@ import typing import torch -from ..data.typing import Sample +from medbase.data.typing import Sample + from .typing import BinaryPrediction, MultiClassPrediction diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py index 5645685a..ffd42aff 100644 --- a/src/mednet/scripts/predict.py +++ b/src/mednet/scripts/predict.py @@ -134,9 +134,9 @@ def predict( import shutil import typing - from ..engine.device import DeviceManager - from ..engine.predictor import run - from ..utils.checkpointer import get_checkpoint_to_run_inference + from medbase.engine.device import DeviceManager + from mednet.engine.predictor import run + from medbase.utils.checkpointer import get_checkpoint_to_run_inference from .utils import ( device_properties, execution_metadata, diff --git a/src/mednet/scripts/saliency/completeness.py b/src/mednet/scripts/saliency/completeness.py index e02aa3b6..28a33695 100644 --- a/src/mednet/scripts/saliency/completeness.py +++ b/src/mednet/scripts/saliency/completeness.py @@ -203,9 +203,10 @@ def completeness( """ import json - from ...engine.device import DeviceManager + from medbase.engine.device import DeviceManager + from medbase.utils.checkpointer import get_checkpoint_to_run_inference + from ...engine.saliency.completeness import run - from ...utils.checkpointer import get_checkpoint_to_run_inference if device in ("cuda", "mps") and (parallel == 0 or parallel > 1): raise RuntimeError( diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py index 649ad96b..34fb8382 100644 --- a/src/mednet/scripts/saliency/generate.py +++ b/src/mednet/scripts/saliency/generate.py @@ -168,9 +168,11 @@ def generate( The quality of saliency information depends on the saliency map algorithm and trained model. """ - from ...engine.device import DeviceManager + + from medbase.engine.device import DeviceManager + from medbase.utils.checkpointer import get_checkpoint_to_run_inference + from ...engine.saliency.generator import run - from ...utils.checkpointer import get_checkpoint_to_run_inference logger.info(f"Output folder: {output_folder}") output_folder.mkdir(parents=True, exist_ok=True) diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py index 42ecf1b8..df8b1fec 100644 --- a/src/mednet/scripts/train.py +++ b/src/mednet/scripts/train.py @@ -272,9 +272,10 @@ def train( import torch from lightning.pytorch import seed_everything - from ..engine.device import DeviceManager - from ..engine.trainer import run - from ..utils.checkpointer import get_checkpoint_to_resume_training + from medbase.engine.device import DeviceManager + from medbase.engine.trainer import run + from medbase.utils.checkpointer import get_checkpoint_to_resume_training + from .utils import ( device_properties, execution_metadata, diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py index 80e74e44..3938e226 100644 --- a/src/mednet/scripts/train_analysis.py +++ b/src/mednet/scripts/train_analysis.py @@ -231,7 +231,7 @@ def train_analysis( import matplotlib.pyplot as plt from matplotlib.backends.backend_pdf import PdfPages - from ..utils.tensorboard import scalars_to_dict + from medbase.utils.tensorboard import scalars_to_dict data = scalars_to_dict(logdir) diff --git a/tests/conftest.py b/tests/conftest.py index a2127bcb..89b055ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,8 +10,9 @@ import numpy import numpy.typing import pytest import torch -from mednet.data.split import JSONDatabaseSplit -from mednet.data.typing import DatabaseSplit + +from medbase.data.split import JSONDatabaseSplit +from medbase.data.typing import DatabaseSplit @pytest.fixture diff --git a/tests/test_cli.py b/tests/test_cli.py index e79d500d..ab6abe57 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -204,11 +204,11 @@ def test_upload_help(): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_train_pasa_montgomery(temporary_basedir): - from mednet.scripts.train import train - from mednet.utils.checkpointer import ( + from medbase.utils.checkpointer import ( CHECKPOINT_EXTENSION, _get_checkpoint_from_alias, ) + from mednet.scripts.train import train runner = CliRunner() @@ -260,11 +260,11 @@ def test_train_pasa_montgomery(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): - from mednet.scripts.train import train - from mednet.utils.checkpointer import ( + from medbase.utils.checkpointer import ( CHECKPOINT_EXTENSION, _get_checkpoint_from_alias, ) + from mednet.scripts.train import train runner = CliRunner() @@ -337,12 +337,12 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_predict_pasa_montgomery(temporary_basedir): - from mednet.scripts.predict import predict - from mednet.utils.checkpointer import ( +def test_predict_pasa_montgomery(temporary_basedir, datadir): + from medbase.utils.checkpointer import ( CHECKPOINT_EXTENSION, _get_checkpoint_from_alias, ) + from mednet.scripts.predict import predict runner = CliRunner() diff --git a/tests/test_database_split.py b/tests/test_database_split.py index 8f50ed99..7ebe400a 100644 --- a/tests/test_database_split.py +++ b/tests/test_database_split.py @@ -3,8 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Test code for datasets.""" -from mednet.data.split import JSONDatabaseSplit - +from medbase.data.split import JSONDatabaseSplit def test_json_loading(datadir): # tests if we can build a simple JSON loader for the Iris Flower dataset diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py index ab431eca..d81501ad 100644 --- a/tests/test_image_utils.py +++ b/tests/test_image_utils.py @@ -5,7 +5,8 @@ import numpy import PIL.Image -from mednet.data.image_utils import remove_black_borders + +from medbase.data.image_utils import remove_black_borders def test_remove_black_borders(datadir): diff --git a/tests/test_summary.py b/tests/test_summary.py new file mode 100644 index 00000000..4e420eb2 --- /dev/null +++ b/tests/test_summary.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import unittest + +import mednet.config.models.pasa as pasa_config + +from medbase.utils.summary import summary + + +class Tester(unittest.TestCase): + """Unit test for model architectures.""" + + def test_summary_driu(self): + model = pasa_config.model + s, param = summary(model) + self.assertIsInstance(s, str) + self.assertIsInstance(param, int) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 2d8faf1e..85f59958 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -5,8 +5,9 @@ import numpy import PIL.Image -import torchvision.transforms.functional as F # noqa: N812 -from mednet.data.augmentations import ElasticDeformation +import torchvision.transforms.functional as F # noqa: N812 + +from medbase.data.augmentations import ElasticDeformation def test_elastic_deformation(datadir): -- GitLab