From e088d2050b3cfa859fd1c7cb7d8af7c91c8cbacb Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 17 Jun 2024 16:15:29 +0200 Subject: [PATCH] [mednet.scripts] Move top-level imports inside functions --- src/mednet/libs/classification/scripts/train.py | 14 +++++++------- src/mednet/libs/segmentation/scripts/evaluate.py | 7 +++---- src/mednet/libs/segmentation/scripts/predict.py | 13 +++++++------ src/mednet/libs/segmentation/scripts/train.py | 14 +++++++------- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py index 58c99cad..2fc53078 100644 --- a/src/mednet/libs/classification/scripts/train.py +++ b/src/mednet/libs/classification/scripts/train.py @@ -2,13 +2,7 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand -from mednet.libs.common.scripts.train import ( - get_checkpoint_file, - load_checkpoint, - reusable_options, - save_json_data, - setup_datamodule, -) +from mednet.libs.common.scripts.train import reusable_options logger = setup("mednet", format="%(levelname)s: %(message)s") @@ -67,6 +61,12 @@ def train( from lightning.pytorch import seed_everything from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.engine.trainer import run + from mednet.libs.common.scripts.train import ( + get_checkpoint_file, + load_checkpoint, + save_json_data, + setup_datamodule, + ) seed_everything(seed) diff --git a/src/mednet/libs/segmentation/scripts/evaluate.py b/src/mednet/libs/segmentation/scripts/evaluate.py index a76b0081..9be44595 100644 --- a/src/mednet/libs/segmentation/scripts/evaluate.py +++ b/src/mednet/libs/segmentation/scripts/evaluate.py @@ -7,16 +7,12 @@ import pathlib import typing import click -import pandas from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand -from tqdm import tqdm logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") -from mednet.libs.segmentation.engine.evaluator import run - @click.command( entry_point_group="mednet.libs.segmentation.config", @@ -150,10 +146,13 @@ def evaluate( ): # numpydoc ignore=PR01 """Evaluate predictions (from a model) on a segmentation task.""" + import pandas from mednet.libs.common.scripts.utils import ( execution_metadata, save_json_with_backup, ) + from mednet.libs.segmentation.engine.evaluator import run + from tqdm import tqdm def _validate_threshold(threshold: float | str, splits: list[str]): """Validate the user threshold selection and returns parsed threshold. diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py index b91400ac..023b0162 100644 --- a/src/mednet/libs/segmentation/scripts/predict.py +++ b/src/mednet/libs/segmentation/scripts/predict.py @@ -6,21 +6,19 @@ import json import pathlib import click -import h5py -import PIL from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand from mednet.libs.common.scripts.predict import reusable_options -from tqdm import tqdm +from PIL import Image logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def _save_hdf5( - img: PIL.Image.Image, - target: PIL.Image.Image, - mask: PIL.Image.Image, + img: Image, + target: Image, + mask: Image, hdf5_path: pathlib.Path, ) -> None: """Save prediction image, target and mask in an hdf5 file. @@ -37,6 +35,9 @@ def _save_hdf5( File in which to save the data. """ + import h5py + from tqdm import tqdm + tqdm.write(f"Saving {hdf5_path}...") hdf5_path.parent.mkdir(parents=True, exist_ok=True) with h5py.File(hdf5_path, "w") as f: diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py index 5089ccab..308d2e30 100644 --- a/src/mednet/libs/segmentation/scripts/train.py +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -2,13 +2,7 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand -from mednet.libs.common.scripts.train import ( - get_checkpoint_file, - load_checkpoint, - reusable_options, - save_json_data, - setup_datamodule, -) +from mednet.libs.common.scripts.train import reusable_options logger = setup("mednet", format="%(levelname)s: %(message)s") @@ -54,6 +48,12 @@ def train( from lightning.pytorch import seed_everything from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.engine.trainer import run + from mednet.libs.common.scripts.train import ( + get_checkpoint_file, + load_checkpoint, + save_json_data, + setup_datamodule, + ) seed_everything(seed) -- GitLab