From 901c950a752fb8d953b445c2a34cba04202fedb1 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 27 Feb 2024 14:47:51 +0100
Subject: [PATCH] [common] Move reusable files to new common package

---
 src/medbase/__init__.py                       |  0
 src/{mednet => medbase}/data/__init__.py      |  0
 src/{mednet => medbase}/data/augmentations.py |  0
 src/{mednet => medbase}/data/datamodule.py    |  0
 src/{mednet => medbase}/data/image_utils.py   |  0
 src/{mednet => medbase}/data/split.py         |  2 +-
 src/{mednet => medbase}/data/typing.py        |  0
 src/medbase/engine/__init__.py                |  0
 src/{mednet => medbase}/engine/callbacks.py   |  0
 src/{mednet => medbase}/engine/device.py      |  0
 src/{mednet => medbase}/engine/loggers.py     |  0
 src/{mednet => medbase}/engine/trainer.py     |  0
 src/medbase/utils/__init__.py                 |  0
 src/{mednet => medbase}/utils/checkpointer.py |  0
 src/{mednet => medbase}/utils/resources.py    |  2 +-
 src/medbase/utils/summary.py                  | 61 +++++++++++++++++++
 src/{mednet => medbase}/utils/tensorboard.py  |  0
 src/mednet/config/data/hivtb/datamodule.py    | 11 ++--
 src/mednet/config/data/indian/datamodule.py   |  5 +-
 .../config/data/montgomery/datamodule.py      | 11 ++--
 .../data/montgomery_shenzhen/datamodule.py    |  5 +-
 .../montgomery_shenzhen_indian/datamodule.py  |  5 +-
 .../datamodule.py                             |  5 +-
 .../datamodule.py                             |  5 +-
 .../config/data/nih_cxr14/datamodule.py       |  9 +--
 .../data/nih_cxr14_padchest/datamodule.py     |  5 +-
 src/mednet/config/data/padchest/datamodule.py | 11 ++--
 src/mednet/config/data/shenzhen/datamodule.py | 11 ++--
 src/mednet/config/data/tbpoc/datamodule.py    | 11 ++--
 src/mednet/config/data/tbx11k/datamodule.py   |  9 +--
 src/mednet/engine/predictor.py                |  3 +-
 src/mednet/engine/saliency/completeness.py    |  5 +-
 src/mednet/engine/saliency/generator.py       |  3 +-
 src/mednet/models/alexnet.py                  |  2 +-
 src/mednet/models/densenet.py                 |  2 +-
 src/mednet/models/loss_weights.py             |  2 +
 src/mednet/models/model.py                    |  2 +-
 src/mednet/models/pasa.py                     |  2 +-
 src/mednet/models/separate.py                 |  3 +-
 src/mednet/scripts/predict.py                 |  6 +-
 src/mednet/scripts/saliency/completeness.py   |  5 +-
 src/mednet/scripts/saliency/generate.py       |  6 +-
 src/mednet/scripts/train.py                   |  7 ++-
 src/mednet/scripts/train_analysis.py          |  2 +-
 tests/conftest.py                             |  5 +-
 tests/test_cli.py                             | 14 ++---
 tests/test_database_split.py                  |  3 +-
 tests/test_image_utils.py                     |  3 +-
 tests/test_summary.py                         | 19 ++++++
 tests/test_transforms.py                      |  5 +-
 50 files changed, 186 insertions(+), 81 deletions(-)
 create mode 100644 src/medbase/__init__.py
 rename src/{mednet => medbase}/data/__init__.py (100%)
 rename src/{mednet => medbase}/data/augmentations.py (100%)
 rename src/{mednet => medbase}/data/datamodule.py (100%)
 rename src/{mednet => medbase}/data/image_utils.py (100%)
 rename src/{mednet => medbase}/data/split.py (98%)
 rename src/{mednet => medbase}/data/typing.py (100%)
 create mode 100644 src/medbase/engine/__init__.py
 rename src/{mednet => medbase}/engine/callbacks.py (100%)
 rename src/{mednet => medbase}/engine/device.py (100%)
 rename src/{mednet => medbase}/engine/loggers.py (100%)
 rename src/{mednet => medbase}/engine/trainer.py (100%)
 create mode 100644 src/medbase/utils/__init__.py
 rename src/{mednet => medbase}/utils/checkpointer.py (100%)
 rename src/{mednet => medbase}/utils/resources.py (99%)
 create mode 100644 src/medbase/utils/summary.py
 rename src/{mednet => medbase}/utils/tensorboard.py (100%)
 create mode 100644 tests/test_summary.py

diff --git a/src/medbase/__init__.py b/src/medbase/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/mednet/data/__init__.py b/src/medbase/data/__init__.py
similarity index 100%
rename from src/mednet/data/__init__.py
rename to src/medbase/data/__init__.py
diff --git a/src/mednet/data/augmentations.py b/src/medbase/data/augmentations.py
similarity index 100%
rename from src/mednet/data/augmentations.py
rename to src/medbase/data/augmentations.py
diff --git a/src/mednet/data/datamodule.py b/src/medbase/data/datamodule.py
similarity index 100%
rename from src/mednet/data/datamodule.py
rename to src/medbase/data/datamodule.py
diff --git a/src/mednet/data/image_utils.py b/src/medbase/data/image_utils.py
similarity index 100%
rename from src/mednet/data/image_utils.py
rename to src/medbase/data/image_utils.py
diff --git a/src/mednet/data/split.py b/src/medbase/data/split.py
similarity index 98%
rename from src/mednet/data/split.py
rename to src/medbase/data/split.py
index 3bf8b8ca..0bcdf300 100644
--- a/src/mednet/data/split.py
+++ b/src/medbase/data/split.py
@@ -12,7 +12,7 @@ import typing
 
 import torch
 
-from .typing import DatabaseSplit, RawDataLoader
+from medbase.data.typing import DatabaseSplit, RawDataLoader
 
 logger = logging.getLogger(__name__)
 
diff --git a/src/mednet/data/typing.py b/src/medbase/data/typing.py
similarity index 100%
rename from src/mednet/data/typing.py
rename to src/medbase/data/typing.py
diff --git a/src/medbase/engine/__init__.py b/src/medbase/engine/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/mednet/engine/callbacks.py b/src/medbase/engine/callbacks.py
similarity index 100%
rename from src/mednet/engine/callbacks.py
rename to src/medbase/engine/callbacks.py
diff --git a/src/mednet/engine/device.py b/src/medbase/engine/device.py
similarity index 100%
rename from src/mednet/engine/device.py
rename to src/medbase/engine/device.py
diff --git a/src/mednet/engine/loggers.py b/src/medbase/engine/loggers.py
similarity index 100%
rename from src/mednet/engine/loggers.py
rename to src/medbase/engine/loggers.py
diff --git a/src/mednet/engine/trainer.py b/src/medbase/engine/trainer.py
similarity index 100%
rename from src/mednet/engine/trainer.py
rename to src/medbase/engine/trainer.py
diff --git a/src/medbase/utils/__init__.py b/src/medbase/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/mednet/utils/checkpointer.py b/src/medbase/utils/checkpointer.py
similarity index 100%
rename from src/mednet/utils/checkpointer.py
rename to src/medbase/utils/checkpointer.py
diff --git a/src/mednet/utils/resources.py b/src/medbase/utils/resources.py
similarity index 99%
rename from src/mednet/utils/resources.py
rename to src/medbase/utils/resources.py
index 5b8844a1..58167711 100644
--- a/src/mednet/utils/resources.py
+++ b/src/medbase/utils/resources.py
@@ -18,7 +18,7 @@ import warnings
 import numpy
 import psutil
 
-from ..engine.device import SupportedPytorchDevice
+from medbase.engine.device import SupportedPytorchDevice
 
 logger = logging.getLogger(__name__)
 
diff --git a/src/medbase/utils/summary.py b/src/medbase/utils/summary.py
new file mode 100644
index 00000000..bff705e3
--- /dev/null
+++ b/src/medbase/utils/summary.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
+
+from functools import reduce
+
+import torch
+
+from torch.nn.modules.module import _addindent
+
+
+# ignore this space!
+def _repr(model: torch.nn.Module) -> tuple[str, int]:
+    # We treat the extra repr like the sub-module, one item per line
+    extra_lines = []
+    extra_repr = model.extra_repr()
+    # empty string will be split into list ['']
+    if extra_repr:
+        extra_lines = extra_repr.split("\n")
+    child_lines = []
+    total_params = 0
+    for key, module in model._modules.items():
+        mod_str, num_params = _repr(module)
+        mod_str = _addindent(mod_str, 2)
+        child_lines.append("(" + key + "): " + mod_str)
+        total_params += num_params
+    lines = extra_lines + child_lines
+
+    for _, p in model._parameters.items():
+        if hasattr(p, "dtype"):
+            total_params += reduce(lambda x, y: x * y, p.shape)
+
+    main_str = model._get_name() + "("
+    if lines:
+        # simple one-liner info, which most builtin Modules will use
+        if len(extra_lines) == 1 and not child_lines:
+            main_str += extra_lines[0]
+        else:
+            main_str += "\n  " + "\n  ".join(lines) + "\n"
+
+    main_str += ")"
+    main_str += f", {total_params:,} params"
+    return main_str, total_params
+
+
+def summary(model: torch.nn.Module) -> tuple[str, int]:
+    """Count the number of parameters in each model layer.
+
+    Parameters
+    ----------
+    model
+        Model to summarize.
+
+    Returns
+    -------
+    tuple[int, str]
+        A tuple containing a multiline string representation of the network and the number of parameters.
+    """
+    return _repr(model)
diff --git a/src/mednet/utils/tensorboard.py b/src/medbase/utils/tensorboard.py
similarity index 100%
rename from src/mednet/utils/tensorboard.py
rename to src/medbase/utils/tensorboard.py
diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py
index 6c674085..21f773b2 100644
--- a/src/mednet/config/data/hivtb/datamodule.py
+++ b/src/mednet/config/data/hivtb/datamodule.py
@@ -12,11 +12,12 @@ import pathlib
 import PIL.Image
 from torchvision.transforms.functional import to_tensor
 
-from ....data.datamodule import CachingDataModule
-from ....data.image_utils import remove_black_borders
-from ....data.split import make_split
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
-from ....data.typing import Sample
+from medbase.data.datamodule import CachingDataModule
+from medbase.data.image_utils import remove_black_borders
+from medbase.data.split import make_split
+from medbase.data.typing import RawDataLoader as _BaseRawDataLoader
+from medbase.data.typing import Sample
+
 from ....utils.rc import load_rc
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
diff --git a/src/mednet/config/data/indian/datamodule.py b/src/mednet/config/data/indian/datamodule.py
index fa285170..f6825bba 100644
--- a/src/mednet/config/data/indian/datamodule.py
+++ b/src/mednet/config/data/indian/datamodule.py
@@ -8,9 +8,10 @@ Database reference: [INDIAN-2013]_
 
 import pathlib
 
+from medbase.data.datamodule import CachingDataModule
+from medbase.data.split import make_split
+
 from ....config.data.shenzhen.datamodule import RawDataLoader
-from ....data.datamodule import CachingDataModule
-from ....data.split import make_split
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 """Key to search for in the configuration file for the root directory of this
diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py
index 98d591ae..64b2f7ff 100644
--- a/src/mednet/config/data/montgomery/datamodule.py
+++ b/src/mednet/config/data/montgomery/datamodule.py
@@ -12,11 +12,12 @@ import pathlib
 import PIL.Image
 from torchvision.transforms.functional import to_tensor
 
-from ....data.datamodule import CachingDataModule
-from ....data.image_utils import remove_black_borders
-from ....data.split import make_split
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
-from ....data.typing import Sample
+from medbase.data.datamodule import CachingDataModule
+from medbase.data.image_utils import remove_black_borders
+from medbase.data.split import make_split
+from medbase.data.typing import Sample
+from medbase.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from ....utils.rc import load_rc
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
diff --git a/src/mednet/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/config/data/montgomery_shenzhen/datamodule.py
index f726a431..16dfec25 100644
--- a/src/mednet/config/data/montgomery_shenzhen/datamodule.py
+++ b/src/mednet/config/data/montgomery_shenzhen/datamodule.py
@@ -5,8 +5,9 @@
 
 import pathlib
 
-from ....data.datamodule import ConcatDataModule
-from ....data.split import make_split
+from medbase.data.datamodule import ConcatDataModule
+from medbase.data.split import make_split
+
 from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
 from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader
 
diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py
index f75eb1e5..bfeb1a33 100644
--- a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py
+++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py
@@ -7,8 +7,9 @@ databases.
 
 import pathlib
 
-from ....data.datamodule import ConcatDataModule
-from ....data.split import make_split
+from medbase.data.datamodule import ConcatDataModule
+from medbase.data.split import make_split
+
 from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR
 from ..indian.datamodule import DataModule as IndianDataModule
 from ..indian.datamodule import RawDataLoader as IndianLoader
diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py
index 060f6231..268ba2a2 100644
--- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py
+++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py
@@ -7,8 +7,9 @@ datasets.
 
 import pathlib
 
-from ....data.datamodule import ConcatDataModule
-from ....data.split import make_split
+from medbase.data.datamodule import ConcatDataModule
+from medbase.data.split import make_split
+
 from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR
 from ..indian.datamodule import DataModule as IndianDataModule
 from ..indian.datamodule import RawDataLoader as IndianLoader
diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
index db2f7901..932e6571 100644
--- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
+++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
@@ -7,8 +7,9 @@ datasets.
 
 import pathlib
 
-from ....data.datamodule import ConcatDataModule
-from ....data.split import make_split
+from medbase.data.datamodule import ConcatDataModule
+from medbase.data.split import make_split
+
 from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR
 from ..indian.datamodule import DataModule as IndianDataModule
 from ..indian.datamodule import RawDataLoader as IndianLoader
diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py
index 6d6b283e..8e72d6d4 100644
--- a/src/mednet/config/data/nih_cxr14/datamodule.py
+++ b/src/mednet/config/data/nih_cxr14/datamodule.py
@@ -12,10 +12,11 @@ import pathlib
 import PIL.Image
 from torchvision.transforms.functional import to_tensor
 
-from ....data.datamodule import CachingDataModule
-from ....data.split import make_split
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
-from ....data.typing import Sample
+from medbase.data.datamodule import CachingDataModule
+from medbase.data.split import make_split
+from medbase.data.typing import Sample
+from medbase.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from ....utils.rc import load_rc
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
diff --git a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py
index bfd17daf..3de000c1 100644
--- a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py
+++ b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py
@@ -5,8 +5,9 @@
 
 import pathlib
 
-from ....data.datamodule import ConcatDataModule
-from ....data.split import make_split
+from medbase.data.datamodule import ConcatDataModule
+from medbase.data.split import make_split
+
 from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader
 from ..padchest.datamodule import RawDataLoader as PadchestLoader
 
diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py
index 97c80a13..99545ee2 100644
--- a/src/mednet/config/data/padchest/datamodule.py
+++ b/src/mednet/config/data/padchest/datamodule.py
@@ -13,11 +13,12 @@ import numpy
 import PIL.Image
 from torchvision.transforms.functional import to_tensor
 
-from ....data.datamodule import CachingDataModule
-from ....data.image_utils import remove_black_borders
-from ....data.split import make_split
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
-from ....data.typing import Sample
+from medbase.data.datamodule import CachingDataModule
+from medbase.data.image_utils import remove_black_borders
+from medbase.data.split import make_split
+from medbase.data.typing import Sample
+from medbase.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from ....utils.rc import load_rc
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py
index c2838ce2..ae1b4639 100644
--- a/src/mednet/config/data/shenzhen/datamodule.py
+++ b/src/mednet/config/data/shenzhen/datamodule.py
@@ -12,11 +12,12 @@ import pathlib
 import PIL.Image
 from torchvision.transforms.functional import to_tensor
 
-from ....data.datamodule import CachingDataModule
-from ....data.image_utils import remove_black_borders
-from ....data.split import make_split
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
-from ....data.typing import Sample
+from medbase.data.datamodule import CachingDataModule
+from medbase.data.image_utils import remove_black_borders
+from medbase.data.split import make_split
+from medbase.data.typing import Sample
+from medbase.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from ....utils.rc import load_rc
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py
index 3e7dcaa2..5b4b6011 100644
--- a/src/mednet/config/data/tbpoc/datamodule.py
+++ b/src/mednet/config/data/tbpoc/datamodule.py
@@ -8,11 +8,12 @@ import pathlib
 import PIL.Image
 from torchvision.transforms.functional import to_tensor
 
-from ....data.datamodule import CachingDataModule
-from ....data.image_utils import remove_black_borders
-from ....data.split import make_split
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
-from ....data.typing import Sample
+from medbase.data.datamodule import CachingDataModule
+from medbase.data.image_utils import remove_black_borders
+from medbase.data.split import make_split
+from medbase.data.typing import Sample
+from medbase.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from ....utils.rc import load_rc
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py
index 993f9e50..c6b28ef9 100644
--- a/src/mednet/config/data/tbx11k/datamodule.py
+++ b/src/mednet/config/data/tbx11k/datamodule.py
@@ -13,10 +13,11 @@ import typing_extensions
 from torch.utils.data._utils.collate import default_collate_fn_map
 from torchvision.transforms.functional import to_tensor
 
-from ....data.datamodule import CachingDataModule
-from ....data.split import make_split
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
-from ....data.typing import Sample
+from medbase.data.datamodule import CachingDataModule
+from medbase.data.split import make_split
+from medbase.data.typing import RawDataLoader as _BaseRawDataLoader
+from medbase.data.typing import Sample
+
 from ....utils.rc import load_rc
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
diff --git a/src/mednet/engine/predictor.py b/src/mednet/engine/predictor.py
index 1ba6f05a..fce844a1 100644
--- a/src/mednet/engine/predictor.py
+++ b/src/mednet/engine/predictor.py
@@ -7,13 +7,14 @@ import logging
 import lightning.pytorch
 import torch.utils.data
 
+from medbase.engine.device import DeviceManager
+
 from ..models.typing import (
     BinaryPrediction,
     BinaryPredictionSplit,
     MultiClassPrediction,
     MultiClassPredictionSplit,
 )
-from .device import DeviceManager
 
 logger = logging.getLogger(__name__)
 
diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py
index eb2832bc..7c3cb3c6 100644
--- a/src/mednet/engine/saliency/completeness.py
+++ b/src/mednet/engine/saliency/completeness.py
@@ -17,9 +17,10 @@ from pytorch_grad_cam.metrics.road import (
 )
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 
-from ...data.typing import Sample
+from medbase.data.typing import Sample
+from medbase.engine.device import DeviceManager
+
 from ...models.typing import SaliencyMapAlgorithm
-from ..device import DeviceManager
 
 logger = logging.getLogger(__name__)
 
diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py
index e1c03af7..f440eb08 100644
--- a/src/mednet/engine/saliency/generator.py
+++ b/src/mednet/engine/saliency/generator.py
@@ -12,8 +12,9 @@ import torch
 import torch.nn
 import tqdm
 
+from medbase.engine.device import DeviceManager
+
 from ...models.typing import SaliencyMapAlgorithm
-from ..device import DeviceManager
 
 logger = logging.getLogger(__name__)
 
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index 3e58463e..22c41547 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -12,7 +12,7 @@ import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import TransformSequence
+from medbase.data.typing import TransformSequence
 from .model import Model
 from .separate import separate
 from .transforms import RGB, SquareCenterPad
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index c54d5d9a..3feb2ade 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -12,7 +12,7 @@ import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import TransformSequence
+from medbase.data.typing import TransformSequence
 from .model import Model
 from .separate import separate
 from .transforms import RGB, SquareCenterPad
diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py
index 2cf5e292..382a7ae5 100644
--- a/src/mednet/models/loss_weights.py
+++ b/src/mednet/models/loss_weights.py
@@ -9,6 +9,8 @@ from collections import Counter
 import torch
 import torch.utils.data
 
+from medbase.data.typing import DataLoader
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 560b1115..e1b264be 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -12,7 +12,7 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
-from ..data.typing import TransformSequence
+from medbase.data.typing import TransformSequence
 from .loss_weights import get_positive_weights
 from .typing import Checkpoint
 
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index cb7ebfea..0e1d24f1 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -12,7 +12,7 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
-from ..data.typing import TransformSequence
+from medbase.data.typing import TransformSequence
 from .model import Model
 from .separate import separate
 from .transforms import Grayscale, SquareCenterPad
diff --git a/src/mednet/models/separate.py b/src/mednet/models/separate.py
index c79820bf..6c7afabb 100644
--- a/src/mednet/models/separate.py
+++ b/src/mednet/models/separate.py
@@ -7,7 +7,8 @@ import typing
 
 import torch
 
-from ..data.typing import Sample
+from medbase.data.typing import Sample
+
 from .typing import BinaryPrediction, MultiClassPrediction
 
 
diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py
index 5645685a..ffd42aff 100644
--- a/src/mednet/scripts/predict.py
+++ b/src/mednet/scripts/predict.py
@@ -134,9 +134,9 @@ def predict(
     import shutil
     import typing
 
-    from ..engine.device import DeviceManager
-    from ..engine.predictor import run
-    from ..utils.checkpointer import get_checkpoint_to_run_inference
+    from medbase.engine.device import DeviceManager
+    from mednet.engine.predictor import run
+    from medbase.utils.checkpointer import get_checkpoint_to_run_inference
     from .utils import (
         device_properties,
         execution_metadata,
diff --git a/src/mednet/scripts/saliency/completeness.py b/src/mednet/scripts/saliency/completeness.py
index e02aa3b6..28a33695 100644
--- a/src/mednet/scripts/saliency/completeness.py
+++ b/src/mednet/scripts/saliency/completeness.py
@@ -203,9 +203,10 @@ def completeness(
     """
     import json
 
-    from ...engine.device import DeviceManager
+    from medbase.engine.device import DeviceManager
+    from medbase.utils.checkpointer import get_checkpoint_to_run_inference
+
     from ...engine.saliency.completeness import run
-    from ...utils.checkpointer import get_checkpoint_to_run_inference
 
     if device in ("cuda", "mps") and (parallel == 0 or parallel > 1):
         raise RuntimeError(
diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py
index 649ad96b..34fb8382 100644
--- a/src/mednet/scripts/saliency/generate.py
+++ b/src/mednet/scripts/saliency/generate.py
@@ -168,9 +168,11 @@ def generate(
     The quality of saliency information depends on the saliency map
     algorithm and trained model.
     """
-    from ...engine.device import DeviceManager
+
+    from medbase.engine.device import DeviceManager
+    from medbase.utils.checkpointer import get_checkpoint_to_run_inference
+
     from ...engine.saliency.generator import run
-    from ...utils.checkpointer import get_checkpoint_to_run_inference
 
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py
index 42ecf1b8..df8b1fec 100644
--- a/src/mednet/scripts/train.py
+++ b/src/mednet/scripts/train.py
@@ -272,9 +272,10 @@ def train(
     import torch
     from lightning.pytorch import seed_everything
 
-    from ..engine.device import DeviceManager
-    from ..engine.trainer import run
-    from ..utils.checkpointer import get_checkpoint_to_resume_training
+    from medbase.engine.device import DeviceManager
+    from medbase.engine.trainer import run
+    from medbase.utils.checkpointer import get_checkpoint_to_resume_training
+
     from .utils import (
         device_properties,
         execution_metadata,
diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py
index 80e74e44..3938e226 100644
--- a/src/mednet/scripts/train_analysis.py
+++ b/src/mednet/scripts/train_analysis.py
@@ -231,7 +231,7 @@ def train_analysis(
     import matplotlib.pyplot as plt
     from matplotlib.backends.backend_pdf import PdfPages
 
-    from ..utils.tensorboard import scalars_to_dict
+    from medbase.utils.tensorboard import scalars_to_dict
 
     data = scalars_to_dict(logdir)
 
diff --git a/tests/conftest.py b/tests/conftest.py
index a2127bcb..89b055ec 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,8 +10,9 @@ import numpy
 import numpy.typing
 import pytest
 import torch
-from mednet.data.split import JSONDatabaseSplit
-from mednet.data.typing import DatabaseSplit
+
+from medbase.data.split import JSONDatabaseSplit
+from medbase.data.typing import DatabaseSplit
 
 
 @pytest.fixture
diff --git a/tests/test_cli.py b/tests/test_cli.py
index e79d500d..ab6abe57 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -204,11 +204,11 @@ def test_upload_help():
 @pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery(temporary_basedir):
-    from mednet.scripts.train import train
-    from mednet.utils.checkpointer import (
+    from medbase.utils.checkpointer import (
         CHECKPOINT_EXTENSION,
         _get_checkpoint_from_alias,
     )
+    from mednet.scripts.train import train
 
     runner = CliRunner()
 
@@ -260,11 +260,11 @@ def test_train_pasa_montgomery(temporary_basedir):
 @pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
-    from mednet.scripts.train import train
-    from mednet.utils.checkpointer import (
+    from medbase.utils.checkpointer import (
         CHECKPOINT_EXTENSION,
         _get_checkpoint_from_alias,
     )
+    from mednet.scripts.train import train
 
     runner = CliRunner()
 
@@ -337,12 +337,12 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
 
 @pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
-def test_predict_pasa_montgomery(temporary_basedir):
-    from mednet.scripts.predict import predict
-    from mednet.utils.checkpointer import (
+def test_predict_pasa_montgomery(temporary_basedir, datadir):
+    from medbase.utils.checkpointer import (
         CHECKPOINT_EXTENSION,
         _get_checkpoint_from_alias,
     )
+    from mednet.scripts.predict import predict
 
     runner = CliRunner()
 
diff --git a/tests/test_database_split.py b/tests/test_database_split.py
index 8f50ed99..7ebe400a 100644
--- a/tests/test_database_split.py
+++ b/tests/test_database_split.py
@@ -3,8 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Test code for datasets."""
 
-from mednet.data.split import JSONDatabaseSplit
-
+from medbase.data.split import JSONDatabaseSplit
 
 def test_json_loading(datadir):
     # tests if we can build a simple JSON loader for the Iris Flower dataset
diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py
index ab431eca..d81501ad 100644
--- a/tests/test_image_utils.py
+++ b/tests/test_image_utils.py
@@ -5,7 +5,8 @@
 
 import numpy
 import PIL.Image
-from mednet.data.image_utils import remove_black_borders
+
+from medbase.data.image_utils import remove_black_borders
 
 
 def test_remove_black_borders(datadir):
diff --git a/tests/test_summary.py b/tests/test_summary.py
new file mode 100644
index 00000000..4e420eb2
--- /dev/null
+++ b/tests/test_summary.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import unittest
+
+import mednet.config.models.pasa as pasa_config
+
+from medbase.utils.summary import summary
+
+
+class Tester(unittest.TestCase):
+    """Unit test for model architectures."""
+
+    def test_summary_driu(self):
+        model = pasa_config.model
+        s, param = summary(model)
+        self.assertIsInstance(s, str)
+        self.assertIsInstance(param, int)
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 2d8faf1e..85f59958 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -5,8 +5,9 @@
 
 import numpy
 import PIL.Image
-import torchvision.transforms.functional as F  # noqa: N812
-from mednet.data.augmentations import ElasticDeformation
+import torchvision.transforms.functional as F # noqa: N812
+
+from medbase.data.augmentations import ElasticDeformation
 
 
 def test_elastic_deformation(datadir):
-- 
GitLab