Skip to content
Snippets Groups Projects
Commit e088d205 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[mednet.scripts] Move top-level imports inside functions

parent bc097bf5
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -2,13 +2,7 @@ import click ...@@ -2,13 +2,7 @@ import click
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from mednet.libs.common.scripts.click import ConfigCommand from mednet.libs.common.scripts.click import ConfigCommand
from mednet.libs.common.scripts.train import ( from mednet.libs.common.scripts.train import reusable_options
get_checkpoint_file,
load_checkpoint,
reusable_options,
save_json_data,
setup_datamodule,
)
logger = setup("mednet", format="%(levelname)s: %(message)s") logger = setup("mednet", format="%(levelname)s: %(message)s")
...@@ -67,6 +61,12 @@ def train( ...@@ -67,6 +61,12 @@ def train(
from lightning.pytorch import seed_everything from lightning.pytorch import seed_everything
from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.engine.device import DeviceManager
from mednet.libs.common.engine.trainer import run from mednet.libs.common.engine.trainer import run
from mednet.libs.common.scripts.train import (
get_checkpoint_file,
load_checkpoint,
save_json_data,
setup_datamodule,
)
seed_everything(seed) seed_everything(seed)
......
...@@ -7,16 +7,12 @@ import pathlib ...@@ -7,16 +7,12 @@ import pathlib
import typing import typing
import click import click
import pandas
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from mednet.libs.common.scripts.click import ConfigCommand from mednet.libs.common.scripts.click import ConfigCommand
from tqdm import tqdm
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
from mednet.libs.segmentation.engine.evaluator import run
@click.command( @click.command(
entry_point_group="mednet.libs.segmentation.config", entry_point_group="mednet.libs.segmentation.config",
...@@ -150,10 +146,13 @@ def evaluate( ...@@ -150,10 +146,13 @@ def evaluate(
): # numpydoc ignore=PR01 ): # numpydoc ignore=PR01
"""Evaluate predictions (from a model) on a segmentation task.""" """Evaluate predictions (from a model) on a segmentation task."""
import pandas
from mednet.libs.common.scripts.utils import ( from mednet.libs.common.scripts.utils import (
execution_metadata, execution_metadata,
save_json_with_backup, save_json_with_backup,
) )
from mednet.libs.segmentation.engine.evaluator import run
from tqdm import tqdm
def _validate_threshold(threshold: float | str, splits: list[str]): def _validate_threshold(threshold: float | str, splits: list[str]):
"""Validate the user threshold selection and returns parsed threshold. """Validate the user threshold selection and returns parsed threshold.
......
...@@ -6,21 +6,19 @@ import json ...@@ -6,21 +6,19 @@ import json
import pathlib import pathlib
import click import click
import h5py
import PIL
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from mednet.libs.common.scripts.click import ConfigCommand from mednet.libs.common.scripts.click import ConfigCommand
from mednet.libs.common.scripts.predict import reusable_options from mednet.libs.common.scripts.predict import reusable_options
from tqdm import tqdm from PIL import Image
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _save_hdf5( def _save_hdf5(
img: PIL.Image.Image, img: Image,
target: PIL.Image.Image, target: Image,
mask: PIL.Image.Image, mask: Image,
hdf5_path: pathlib.Path, hdf5_path: pathlib.Path,
) -> None: ) -> None:
"""Save prediction image, target and mask in an hdf5 file. """Save prediction image, target and mask in an hdf5 file.
...@@ -37,6 +35,9 @@ def _save_hdf5( ...@@ -37,6 +35,9 @@ def _save_hdf5(
File in which to save the data. File in which to save the data.
""" """
import h5py
from tqdm import tqdm
tqdm.write(f"Saving {hdf5_path}...") tqdm.write(f"Saving {hdf5_path}...")
hdf5_path.parent.mkdir(parents=True, exist_ok=True) hdf5_path.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(hdf5_path, "w") as f: with h5py.File(hdf5_path, "w") as f:
......
...@@ -2,13 +2,7 @@ import click ...@@ -2,13 +2,7 @@ import click
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from mednet.libs.common.scripts.click import ConfigCommand from mednet.libs.common.scripts.click import ConfigCommand
from mednet.libs.common.scripts.train import ( from mednet.libs.common.scripts.train import reusable_options
get_checkpoint_file,
load_checkpoint,
reusable_options,
save_json_data,
setup_datamodule,
)
logger = setup("mednet", format="%(levelname)s: %(message)s") logger = setup("mednet", format="%(levelname)s: %(message)s")
...@@ -54,6 +48,12 @@ def train( ...@@ -54,6 +48,12 @@ def train(
from lightning.pytorch import seed_everything from lightning.pytorch import seed_everything
from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.engine.device import DeviceManager
from mednet.libs.common.engine.trainer import run from mednet.libs.common.engine.trainer import run
from mednet.libs.common.scripts.train import (
get_checkpoint_file,
load_checkpoint,
save_json_data,
setup_datamodule,
)
seed_everything(seed) seed_everything(seed)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment