diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py index 58c99cad6c1477e336544ebd9bc25fc197c3aa3a..2fc53078e32af07e20d03301c6351ad43aac5731 100644 --- a/src/mednet/libs/classification/scripts/train.py +++ b/src/mednet/libs/classification/scripts/train.py @@ -2,13 +2,7 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand -from mednet.libs.common.scripts.train import ( - get_checkpoint_file, - load_checkpoint, - reusable_options, - save_json_data, - setup_datamodule, -) +from mednet.libs.common.scripts.train import reusable_options logger = setup("mednet", format="%(levelname)s: %(message)s") @@ -67,6 +61,12 @@ def train( from lightning.pytorch import seed_everything from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.engine.trainer import run + from mednet.libs.common.scripts.train import ( + get_checkpoint_file, + load_checkpoint, + save_json_data, + setup_datamodule, + ) seed_everything(seed) diff --git a/src/mednet/libs/segmentation/scripts/evaluate.py b/src/mednet/libs/segmentation/scripts/evaluate.py index a76b00816980646bd5a0b47bc05a04eeaf0d1b11..9be445953d558e76d16687a512459b537de60dc6 100644 --- a/src/mednet/libs/segmentation/scripts/evaluate.py +++ b/src/mednet/libs/segmentation/scripts/evaluate.py @@ -7,16 +7,12 @@ import pathlib import typing import click -import pandas from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand -from tqdm import tqdm logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") -from mednet.libs.segmentation.engine.evaluator import run - @click.command( entry_point_group="mednet.libs.segmentation.config", @@ -150,10 +146,13 @@ def evaluate( ): # numpydoc ignore=PR01 """Evaluate predictions (from a model) on a segmentation task.""" + import pandas from mednet.libs.common.scripts.utils import ( execution_metadata, save_json_with_backup, ) + from mednet.libs.segmentation.engine.evaluator import run + from tqdm import tqdm def _validate_threshold(threshold: float | str, splits: list[str]): """Validate the user threshold selection and returns parsed threshold. diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py index b91400ac93dc98e43ce897c4efcc7e2f2c47bc93..023b016258561cb30b698df0eb1fef89754383db 100644 --- a/src/mednet/libs/segmentation/scripts/predict.py +++ b/src/mednet/libs/segmentation/scripts/predict.py @@ -6,21 +6,19 @@ import json import pathlib import click -import h5py -import PIL from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand from mednet.libs.common.scripts.predict import reusable_options -from tqdm import tqdm +from PIL import Image logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def _save_hdf5( - img: PIL.Image.Image, - target: PIL.Image.Image, - mask: PIL.Image.Image, + img: Image, + target: Image, + mask: Image, hdf5_path: pathlib.Path, ) -> None: """Save prediction image, target and mask in an hdf5 file. @@ -37,6 +35,9 @@ def _save_hdf5( File in which to save the data. """ + import h5py + from tqdm import tqdm + tqdm.write(f"Saving {hdf5_path}...") hdf5_path.parent.mkdir(parents=True, exist_ok=True) with h5py.File(hdf5_path, "w") as f: diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py index 5089ccab1d865ae8465e26686c7eec290a80c26f..308d2e300c608b0d11c9b60ca36244a4bad02e2c 100644 --- a/src/mednet/libs/segmentation/scripts/train.py +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -2,13 +2,7 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand -from mednet.libs.common.scripts.train import ( - get_checkpoint_file, - load_checkpoint, - reusable_options, - save_json_data, - setup_datamodule, -) +from mednet.libs.common.scripts.train import reusable_options logger = setup("mednet", format="%(levelname)s: %(message)s") @@ -54,6 +48,12 @@ def train( from lightning.pytorch import seed_everything from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.engine.trainer import run + from mednet.libs.common.scripts.train import ( + get_checkpoint_file, + load_checkpoint, + save_json_data, + setup_datamodule, + ) seed_everything(seed)