diff --git a/src/medbase/__init__.py b/src/medbase/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/mednet/data/__init__.py b/src/medbase/data/__init__.py similarity index 100% rename from src/mednet/data/__init__.py rename to src/medbase/data/__init__.py diff --git a/src/mednet/data/augmentations.py b/src/medbase/data/augmentations.py similarity index 100% rename from src/mednet/data/augmentations.py rename to src/medbase/data/augmentations.py diff --git a/src/mednet/data/datamodule.py b/src/medbase/data/datamodule.py similarity index 100% rename from src/mednet/data/datamodule.py rename to src/medbase/data/datamodule.py diff --git a/src/mednet/data/image_utils.py b/src/medbase/data/image_utils.py similarity index 100% rename from src/mednet/data/image_utils.py rename to src/medbase/data/image_utils.py diff --git a/src/mednet/data/split.py b/src/medbase/data/split.py similarity index 98% rename from src/mednet/data/split.py rename to src/medbase/data/split.py index 3bf8b8ca16d870a9987512c6897765f5ba1257fd..0bcdf300850cff67c5e596245d3ba53b59f47d9c 100644 --- a/src/mednet/data/split.py +++ b/src/medbase/data/split.py @@ -12,7 +12,7 @@ import typing import torch -from .typing import DatabaseSplit, RawDataLoader +from medbase.data.typing import DatabaseSplit, RawDataLoader logger = logging.getLogger(__name__) diff --git a/src/mednet/data/typing.py b/src/medbase/data/typing.py similarity index 100% rename from src/mednet/data/typing.py rename to src/medbase/data/typing.py diff --git a/src/medbase/engine/__init__.py b/src/medbase/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/mednet/engine/callbacks.py b/src/medbase/engine/callbacks.py similarity index 100% rename from src/mednet/engine/callbacks.py rename to src/medbase/engine/callbacks.py diff --git a/src/mednet/engine/device.py b/src/medbase/engine/device.py similarity index 100% rename from src/mednet/engine/device.py rename to src/medbase/engine/device.py diff --git a/src/mednet/engine/loggers.py b/src/medbase/engine/loggers.py similarity index 100% rename from src/mednet/engine/loggers.py rename to src/medbase/engine/loggers.py diff --git a/src/mednet/engine/trainer.py b/src/medbase/engine/trainer.py similarity index 100% rename from src/mednet/engine/trainer.py rename to src/medbase/engine/trainer.py diff --git a/src/medbase/utils/__init__.py b/src/medbase/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/mednet/utils/checkpointer.py b/src/medbase/utils/checkpointer.py similarity index 100% rename from src/mednet/utils/checkpointer.py rename to src/medbase/utils/checkpointer.py diff --git a/src/mednet/utils/resources.py b/src/medbase/utils/resources.py similarity index 99% rename from src/mednet/utils/resources.py rename to src/medbase/utils/resources.py index 5b8844a15e02a277a293770ed55e7fb3a43ec8d7..5816771157435564113721bd44f6619ffc780ebb 100644 --- a/src/mednet/utils/resources.py +++ b/src/medbase/utils/resources.py @@ -18,7 +18,7 @@ import warnings import numpy import psutil -from ..engine.device import SupportedPytorchDevice +from medbase.engine.device import SupportedPytorchDevice logger = logging.getLogger(__name__) diff --git a/src/medbase/utils/summary.py b/src/medbase/utils/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..bff705e30b557a0314dfef929535671e3dad7f81 --- /dev/null +++ b/src/medbase/utils/summary.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488 + +from functools import reduce + +import torch + +from torch.nn.modules.module import _addindent + + +# ignore this space! +def _repr(model: torch.nn.Module) -> tuple[str, int]: + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = model.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + total_params = 0 + for key, module in model._modules.items(): + mod_str, num_params = _repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + total_params += num_params + lines = extra_lines + child_lines + + for _, p in model._parameters.items(): + if hasattr(p, "dtype"): + total_params += reduce(lambda x, y: x * y, p.shape) + + main_str = model._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + main_str += f", {total_params:,} params" + return main_str, total_params + + +def summary(model: torch.nn.Module) -> tuple[str, int]: + """Count the number of parameters in each model layer. + + Parameters + ---------- + model + Model to summarize. + + Returns + ------- + tuple[int, str] + A tuple containing a multiline string representation of the network and the number of parameters. + """ + return _repr(model) diff --git a/src/mednet/utils/tensorboard.py b/src/medbase/utils/tensorboard.py similarity index 100% rename from src/mednet/utils/tensorboard.py rename to src/medbase/utils/tensorboard.py diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py index 6c674085511dec84f7cd36f26d5e1f48ed285c13..21f773b24f4695da7ac4dab8b7a3bb8a76606f1a 100644 --- a/src/mednet/config/data/hivtb/datamodule.py +++ b/src/mednet/config/data/hivtb/datamodule.py @@ -12,11 +12,12 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader +from medbase.data.typing import Sample + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/indian/datamodule.py b/src/mednet/config/data/indian/datamodule.py index fa285170aa1d2d5212621bb5932783158614c725..f6825bba1e666abef37afea034aa5b29eb4081ec 100644 --- a/src/mednet/config/data/indian/datamodule.py +++ b/src/mednet/config/data/indian/datamodule.py @@ -8,9 +8,10 @@ Database reference: [INDIAN-2013]_ import pathlib +from medbase.data.datamodule import CachingDataModule +from medbase.data.split import make_split + from ....config.data.shenzhen.datamodule import RawDataLoader -from ....data.datamodule import CachingDataModule -from ....data.split import make_split CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) """Key to search for in the configuration file for the root directory of this diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py index 98d591ae5f00dfcc5b1bcd8c5f657482a7492b0b..64b2f7ff2ec0d1b564e9439e23ede7411ce9aab0 100644 --- a/src/mednet/config/data/montgomery/datamodule.py +++ b/src/mednet/config/data/montgomery/datamodule.py @@ -12,11 +12,12 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/config/data/montgomery_shenzhen/datamodule.py index f726a43159968737de4e5b768fef852ab364742d..16dfec25d7f87606a837cdc8b79a944aa3b26b7d 100644 --- a/src/mednet/config/data/montgomery_shenzhen/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen/datamodule.py @@ -5,8 +5,9 @@ import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py index f75eb1e5b883aefde630061e35d50a8a2e8d9dba..bfeb1a338ba93bffa0a959692ae06a5e47857f18 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py @@ -7,8 +7,9 @@ databases. import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import RawDataLoader as IndianLoader diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py index 060f623115d68dccf5b4451b94faf8746fc57b00..268ba2a261d2b7a667d8e1c9068c033f8b005504 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py @@ -7,8 +7,9 @@ datasets. import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import RawDataLoader as IndianLoader diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index db2f79010073dbbbefbcbbfb1a75dae3b283bea8..932e65710cbb1e61ad407fb1c8d121607af6f354 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -7,8 +7,9 @@ datasets. import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import DataModule as IndianDataModule from ..indian.datamodule import RawDataLoader as IndianLoader diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py index 6d6b283ea01d0ddc9a1af1e91fbb5cf8479f7d70..8e72d6d414bab383452c66b5ab326cc2599ba3e8 100644 --- a/src/mednet/config/data/nih_cxr14/datamodule.py +++ b/src/mednet/config/data/nih_cxr14/datamodule.py @@ -12,10 +12,11 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py index bfd17daf1991be79a9a7041c97f315cfc2feb04c..3de000c129d19bd949cc5c00c7145364b6935200 100644 --- a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py +++ b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py @@ -5,8 +5,9 @@ import pathlib -from ....data.datamodule import ConcatDataModule -from ....data.split import make_split +from medbase.data.datamodule import ConcatDataModule +from medbase.data.split import make_split + from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader from ..padchest.datamodule import RawDataLoader as PadchestLoader diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py index 97c80a13b98fa5c7339b4c83c2f4c686a5598bc6..99545ee2d0b0503974f9d210dbd968ca74571d99 100644 --- a/src/mednet/config/data/padchest/datamodule.py +++ b/src/mednet/config/data/padchest/datamodule.py @@ -13,11 +13,12 @@ import numpy import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py index c2838ce223eb4a124e043e39cddc96ab58d5162a..ae1b463988f7d103ed963767da649732c31b37c0 100644 --- a/src/mednet/config/data/shenzhen/datamodule.py +++ b/src/mednet/config/data/shenzhen/datamodule.py @@ -12,11 +12,12 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py index 3e7dcaa239ad6355bd1281958fc53c4e8292a22c..5b4b6011fee91c1ad09f1b92d9388fa253c1b518 100644 --- a/src/mednet/config/data/tbpoc/datamodule.py +++ b/src/mednet/config/data/tbpoc/datamodule.py @@ -8,11 +8,12 @@ import pathlib import PIL.Image from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.image_utils import remove_black_borders -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.image_utils import remove_black_borders +from medbase.data.split import make_split +from medbase.data.typing import Sample +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py index 993f9e50d1cbb8bdfe0d52f9d526f8e142e356e4..c6b28ef9237c68a636708312d5147feb8838b8c2 100644 --- a/src/mednet/config/data/tbx11k/datamodule.py +++ b/src/mednet/config/data/tbx11k/datamodule.py @@ -13,10 +13,11 @@ import typing_extensions from torch.utils.data._utils.collate import default_collate_fn_map from torchvision.transforms.functional import to_tensor -from ....data.datamodule import CachingDataModule -from ....data.split import make_split -from ....data.typing import RawDataLoader as _BaseRawDataLoader -from ....data.typing import Sample +from medbase.data.datamodule import CachingDataModule +from medbase.data.split import make_split +from medbase.data.typing import RawDataLoader as _BaseRawDataLoader +from medbase.data.typing import Sample + from ....utils.rc import load_rc CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) diff --git a/src/mednet/engine/predictor.py b/src/mednet/engine/predictor.py index 1ba6f05a12bc16b4c9a48def96b70a0472a4646a..fce844a17151cd23e0271d34e162dd7b652c9a5b 100644 --- a/src/mednet/engine/predictor.py +++ b/src/mednet/engine/predictor.py @@ -7,13 +7,14 @@ import logging import lightning.pytorch import torch.utils.data +from medbase.engine.device import DeviceManager + from ..models.typing import ( BinaryPrediction, BinaryPredictionSplit, MultiClassPrediction, MultiClassPredictionSplit, ) -from .device import DeviceManager logger = logging.getLogger(__name__) diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py index eb2832bc3da7bfac94d1c62f14154540b1bc0215..7c3cb3c624e1965508c243051c602080632acde0 100644 --- a/src/mednet/engine/saliency/completeness.py +++ b/src/mednet/engine/saliency/completeness.py @@ -17,9 +17,10 @@ from pytorch_grad_cam.metrics.road import ( ) from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget -from ...data.typing import Sample +from medbase.data.typing import Sample +from medbase.engine.device import DeviceManager + from ...models.typing import SaliencyMapAlgorithm -from ..device import DeviceManager logger = logging.getLogger(__name__) diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index e1c03af73cc99279aba650eac2ec024b1ec6a475..f440eb086088686ef722b99323bdd62bae16f238 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -12,8 +12,9 @@ import torch import torch.nn import tqdm +from medbase.engine.device import DeviceManager + from ...models.typing import SaliencyMapAlgorithm -from ..device import DeviceManager logger = logging.getLogger(__name__) diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index 3e58463e62d1ceae2a54868a08dc9205b35837ca..22c415471b50adf4c684130d8b692679189181f7 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -12,7 +12,7 @@ import torch.utils.data import torchvision.models as models import torchvision.transforms -from ..data.typing import TransformSequence +from medbase.data.typing import TransformSequence from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index c54d5d9ac772b31947ef45426a9b6b7fd62632d2..3feb2adebff2611ac6848088fb3df32c81e8d7d2 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -12,7 +12,7 @@ import torch.utils.data import torchvision.models as models import torchvision.transforms -from ..data.typing import TransformSequence +from medbase.data.typing import TransformSequence from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py index 2cf5e292f81f4ee3ba048a518c21a367278af669..382a7ae5123c44fdd91a4eebe2c9d93d7c17e805 100644 --- a/src/mednet/models/loss_weights.py +++ b/src/mednet/models/loss_weights.py @@ -9,6 +9,8 @@ from collections import Counter import torch import torch.utils.data +from medbase.data.typing import DataLoader + logger = logging.getLogger(__name__) diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py index 560b1115096310fa2238dfb24f07a4f73829b8cc..e1b264bee83eec12c46f6286a2c2add78e15f69e 100644 --- a/src/mednet/models/model.py +++ b/src/mednet/models/model.py @@ -12,7 +12,7 @@ import torch.optim.optimizer import torch.utils.data import torchvision.transforms -from ..data.typing import TransformSequence +from medbase.data.typing import TransformSequence from .loss_weights import get_positive_weights from .typing import Checkpoint diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index cb7ebfea0da3d8433de869bda54f22b64d43bb0f..0e1d24f15973e374539a8eff7e370ff6f2aa602a 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -12,7 +12,7 @@ import torch.optim.optimizer import torch.utils.data import torchvision.transforms -from ..data.typing import TransformSequence +from medbase.data.typing import TransformSequence from .model import Model from .separate import separate from .transforms import Grayscale, SquareCenterPad diff --git a/src/mednet/models/separate.py b/src/mednet/models/separate.py index c79820bf5b93a1085ef23e58b627d323be4be1df..6c7afabb5769b4e5f805c9723b28a81e85facc0a 100644 --- a/src/mednet/models/separate.py +++ b/src/mednet/models/separate.py @@ -7,7 +7,8 @@ import typing import torch -from ..data.typing import Sample +from medbase.data.typing import Sample + from .typing import BinaryPrediction, MultiClassPrediction diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py index 5645685a61de040e52ecdcbf3086f1b65381a87a..ffd42affe8f148dc1b55386c2e49b0a68d4d6704 100644 --- a/src/mednet/scripts/predict.py +++ b/src/mednet/scripts/predict.py @@ -134,9 +134,9 @@ def predict( import shutil import typing - from ..engine.device import DeviceManager - from ..engine.predictor import run - from ..utils.checkpointer import get_checkpoint_to_run_inference + from medbase.engine.device import DeviceManager + from mednet.engine.predictor import run + from medbase.utils.checkpointer import get_checkpoint_to_run_inference from .utils import ( device_properties, execution_metadata, diff --git a/src/mednet/scripts/saliency/completeness.py b/src/mednet/scripts/saliency/completeness.py index e02aa3b6a835c7b2383d961c5f100d3693f331b2..28a3369517a7e6660dff371e7b8f93bfeb31acf3 100644 --- a/src/mednet/scripts/saliency/completeness.py +++ b/src/mednet/scripts/saliency/completeness.py @@ -203,9 +203,10 @@ def completeness( """ import json - from ...engine.device import DeviceManager + from medbase.engine.device import DeviceManager + from medbase.utils.checkpointer import get_checkpoint_to_run_inference + from ...engine.saliency.completeness import run - from ...utils.checkpointer import get_checkpoint_to_run_inference if device in ("cuda", "mps") and (parallel == 0 or parallel > 1): raise RuntimeError( diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py index 649ad96bf15a99126266568b596073c611cf04d7..34fb8382a4ecce911edd866d40d36dd52f160ff9 100644 --- a/src/mednet/scripts/saliency/generate.py +++ b/src/mednet/scripts/saliency/generate.py @@ -168,9 +168,11 @@ def generate( The quality of saliency information depends on the saliency map algorithm and trained model. """ - from ...engine.device import DeviceManager + + from medbase.engine.device import DeviceManager + from medbase.utils.checkpointer import get_checkpoint_to_run_inference + from ...engine.saliency.generator import run - from ...utils.checkpointer import get_checkpoint_to_run_inference logger.info(f"Output folder: {output_folder}") output_folder.mkdir(parents=True, exist_ok=True) diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py index 42ecf1b8df9fe0f3b8dde27a60f0cbc8d45ec29d..df8b1fec1c9e9b150d6932b489fdd39ada63377c 100644 --- a/src/mednet/scripts/train.py +++ b/src/mednet/scripts/train.py @@ -272,9 +272,10 @@ def train( import torch from lightning.pytorch import seed_everything - from ..engine.device import DeviceManager - from ..engine.trainer import run - from ..utils.checkpointer import get_checkpoint_to_resume_training + from medbase.engine.device import DeviceManager + from medbase.engine.trainer import run + from medbase.utils.checkpointer import get_checkpoint_to_resume_training + from .utils import ( device_properties, execution_metadata, diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py index 80e74e44c4e1a29f55028809daaf07f3e6a5198d..3938e22604fe4704bfd1d40b008acb0041ecd6af 100644 --- a/src/mednet/scripts/train_analysis.py +++ b/src/mednet/scripts/train_analysis.py @@ -231,7 +231,7 @@ def train_analysis( import matplotlib.pyplot as plt from matplotlib.backends.backend_pdf import PdfPages - from ..utils.tensorboard import scalars_to_dict + from medbase.utils.tensorboard import scalars_to_dict data = scalars_to_dict(logdir) diff --git a/tests/conftest.py b/tests/conftest.py index a2127bcb174a9796748103ef51624b7302b0ce10..89b055ec4de60cab6cacd3b4986aefd5fda91c1d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,8 +10,9 @@ import numpy import numpy.typing import pytest import torch -from mednet.data.split import JSONDatabaseSplit -from mednet.data.typing import DatabaseSplit + +from medbase.data.split import JSONDatabaseSplit +from medbase.data.typing import DatabaseSplit @pytest.fixture diff --git a/tests/test_cli.py b/tests/test_cli.py index e79d500d99cad1a2a6575f627681b901b362a037..ab6abe57cc1417e79974d5f807d32bea9ab05e63 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -204,11 +204,11 @@ def test_upload_help(): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_train_pasa_montgomery(temporary_basedir): - from mednet.scripts.train import train - from mednet.utils.checkpointer import ( + from medbase.utils.checkpointer import ( CHECKPOINT_EXTENSION, _get_checkpoint_from_alias, ) + from mednet.scripts.train import train runner = CliRunner() @@ -260,11 +260,11 @@ def test_train_pasa_montgomery(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): - from mednet.scripts.train import train - from mednet.utils.checkpointer import ( + from medbase.utils.checkpointer import ( CHECKPOINT_EXTENSION, _get_checkpoint_from_alias, ) + from mednet.scripts.train import train runner = CliRunner() @@ -337,12 +337,12 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): @pytest.mark.slow @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_predict_pasa_montgomery(temporary_basedir): - from mednet.scripts.predict import predict - from mednet.utils.checkpointer import ( +def test_predict_pasa_montgomery(temporary_basedir, datadir): + from medbase.utils.checkpointer import ( CHECKPOINT_EXTENSION, _get_checkpoint_from_alias, ) + from mednet.scripts.predict import predict runner = CliRunner() diff --git a/tests/test_database_split.py b/tests/test_database_split.py index 8f50ed997593ef4fc3bbddb3fab0140a54728054..7ebe400a305e0e0e009822f65bb0e3034915a91b 100644 --- a/tests/test_database_split.py +++ b/tests/test_database_split.py @@ -3,8 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Test code for datasets.""" -from mednet.data.split import JSONDatabaseSplit - +from medbase.data.split import JSONDatabaseSplit def test_json_loading(datadir): # tests if we can build a simple JSON loader for the Iris Flower dataset diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py index ab431eca66f4170278363e780bdb960799b425ce..d81501ad6d71106dc18eab51b7068f479a834e8c 100644 --- a/tests/test_image_utils.py +++ b/tests/test_image_utils.py @@ -5,7 +5,8 @@ import numpy import PIL.Image -from mednet.data.image_utils import remove_black_borders + +from medbase.data.image_utils import remove_black_borders def test_remove_black_borders(datadir): diff --git a/tests/test_summary.py b/tests/test_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..4e420eb29ad332655064ab177d91c985bee6d96a --- /dev/null +++ b/tests/test_summary.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import unittest + +import mednet.config.models.pasa as pasa_config + +from medbase.utils.summary import summary + + +class Tester(unittest.TestCase): + """Unit test for model architectures.""" + + def test_summary_driu(self): + model = pasa_config.model + s, param = summary(model) + self.assertIsInstance(s, str) + self.assertIsInstance(param, int) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 2d8faf1e659ca3fe3e536337006620cc2b88de9c..85f599582814b1a1119b4a5f6580b3544094bffb 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -5,8 +5,9 @@ import numpy import PIL.Image -import torchvision.transforms.functional as F # noqa: N812 -from mednet.data.augmentations import ElasticDeformation +import torchvision.transforms.functional as F # noqa: N812 + +from medbase.data.augmentations import ElasticDeformation def test_elastic_deformation(datadir):