Skip to content
Snippets Groups Projects
Commit 901c950a authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[common] Move reusable files to new common package

parent 2631da1d
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -203,9 +203,10 @@ def completeness(
"""
import json
from ...engine.device import DeviceManager
from medbase.engine.device import DeviceManager
from medbase.utils.checkpointer import get_checkpoint_to_run_inference
from ...engine.saliency.completeness import run
from ...utils.checkpointer import get_checkpoint_to_run_inference
if device in ("cuda", "mps") and (parallel == 0 or parallel > 1):
raise RuntimeError(
......
......@@ -168,9 +168,11 @@ def generate(
The quality of saliency information depends on the saliency map
algorithm and trained model.
"""
from ...engine.device import DeviceManager
from medbase.engine.device import DeviceManager
from medbase.utils.checkpointer import get_checkpoint_to_run_inference
from ...engine.saliency.generator import run
from ...utils.checkpointer import get_checkpoint_to_run_inference
logger.info(f"Output folder: {output_folder}")
output_folder.mkdir(parents=True, exist_ok=True)
......
......@@ -272,9 +272,10 @@ def train(
import torch
from lightning.pytorch import seed_everything
from ..engine.device import DeviceManager
from ..engine.trainer import run
from ..utils.checkpointer import get_checkpoint_to_resume_training
from medbase.engine.device import DeviceManager
from medbase.engine.trainer import run
from medbase.utils.checkpointer import get_checkpoint_to_resume_training
from .utils import (
device_properties,
execution_metadata,
......
......@@ -231,7 +231,7 @@ def train_analysis(
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from ..utils.tensorboard import scalars_to_dict
from medbase.utils.tensorboard import scalars_to_dict
data = scalars_to_dict(logdir)
......
......@@ -10,8 +10,9 @@ import numpy
import numpy.typing
import pytest
import torch
from mednet.data.split import JSONDatabaseSplit
from mednet.data.typing import DatabaseSplit
from medbase.data.split import JSONDatabaseSplit
from medbase.data.typing import DatabaseSplit
@pytest.fixture
......
......@@ -204,11 +204,11 @@ def test_upload_help():
@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery(temporary_basedir):
from mednet.scripts.train import train
from mednet.utils.checkpointer import (
from medbase.utils.checkpointer import (
CHECKPOINT_EXTENSION,
_get_checkpoint_from_alias,
)
from mednet.scripts.train import train
runner = CliRunner()
......@@ -260,11 +260,11 @@ def test_train_pasa_montgomery(temporary_basedir):
@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
from mednet.scripts.train import train
from mednet.utils.checkpointer import (
from medbase.utils.checkpointer import (
CHECKPOINT_EXTENSION,
_get_checkpoint_from_alias,
)
from mednet.scripts.train import train
runner = CliRunner()
......@@ -337,12 +337,12 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_predict_pasa_montgomery(temporary_basedir):
from mednet.scripts.predict import predict
from mednet.utils.checkpointer import (
def test_predict_pasa_montgomery(temporary_basedir, datadir):
from medbase.utils.checkpointer import (
CHECKPOINT_EXTENSION,
_get_checkpoint_from_alias,
)
from mednet.scripts.predict import predict
runner = CliRunner()
......
......@@ -3,8 +3,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""Test code for datasets."""
from mednet.data.split import JSONDatabaseSplit
from medbase.data.split import JSONDatabaseSplit
def test_json_loading(datadir):
# tests if we can build a simple JSON loader for the Iris Flower dataset
......
......@@ -5,7 +5,8 @@
import numpy
import PIL.Image
from mednet.data.image_utils import remove_black_borders
from medbase.data.image_utils import remove_black_borders
def test_remove_black_borders(datadir):
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import unittest
import mednet.config.models.pasa as pasa_config
from medbase.utils.summary import summary
class Tester(unittest.TestCase):
"""Unit test for model architectures."""
def test_summary_driu(self):
model = pasa_config.model
s, param = summary(model)
self.assertIsInstance(s, str)
self.assertIsInstance(param, int)
......@@ -5,8 +5,9 @@
import numpy
import PIL.Image
import torchvision.transforms.functional as F # noqa: N812
from mednet.data.augmentations import ElasticDeformation
import torchvision.transforms.functional as F # noqa: N812
from medbase.data.augmentations import ElasticDeformation
def test_elastic_deformation(datadir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment