From e088d2050b3cfa859fd1c7cb7d8af7c91c8cbacb Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 17 Jun 2024 16:15:29 +0200
Subject: [PATCH] [mednet.scripts] Move top-level imports inside functions

---
 src/mednet/libs/classification/scripts/train.py  | 14 +++++++-------
 src/mednet/libs/segmentation/scripts/evaluate.py |  7 +++----
 src/mednet/libs/segmentation/scripts/predict.py  | 13 +++++++------
 src/mednet/libs/segmentation/scripts/train.py    | 14 +++++++-------
 4 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index 58c99cad..2fc53078 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -2,13 +2,7 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.train import (
-    get_checkpoint_file,
-    load_checkpoint,
-    reusable_options,
-    save_json_data,
-    setup_datamodule,
-)
+from mednet.libs.common.scripts.train import reusable_options
 
 logger = setup("mednet", format="%(levelname)s: %(message)s")
 
@@ -67,6 +61,12 @@ def train(
     from lightning.pytorch import seed_everything
     from mednet.libs.common.engine.device import DeviceManager
     from mednet.libs.common.engine.trainer import run
+    from mednet.libs.common.scripts.train import (
+        get_checkpoint_file,
+        load_checkpoint,
+        save_json_data,
+        setup_datamodule,
+    )
 
     seed_everything(seed)
 
diff --git a/src/mednet/libs/segmentation/scripts/evaluate.py b/src/mednet/libs/segmentation/scripts/evaluate.py
index a76b0081..9be44595 100644
--- a/src/mednet/libs/segmentation/scripts/evaluate.py
+++ b/src/mednet/libs/segmentation/scripts/evaluate.py
@@ -7,16 +7,12 @@ import pathlib
 import typing
 
 import click
-import pandas
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from tqdm import tqdm
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
-from mednet.libs.segmentation.engine.evaluator import run
-
 
 @click.command(
     entry_point_group="mednet.libs.segmentation.config",
@@ -150,10 +146,13 @@ def evaluate(
 ):  # numpydoc ignore=PR01
     """Evaluate predictions (from a model) on a segmentation task."""
 
+    import pandas
     from mednet.libs.common.scripts.utils import (
         execution_metadata,
         save_json_with_backup,
     )
+    from mednet.libs.segmentation.engine.evaluator import run
+    from tqdm import tqdm
 
     def _validate_threshold(threshold: float | str, splits: list[str]):
         """Validate the user threshold selection and returns parsed threshold.
diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py
index b91400ac..023b0162 100644
--- a/src/mednet/libs/segmentation/scripts/predict.py
+++ b/src/mednet/libs/segmentation/scripts/predict.py
@@ -6,21 +6,19 @@ import json
 import pathlib
 
 import click
-import h5py
-import PIL
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
 from mednet.libs.common.scripts.predict import reusable_options
-from tqdm import tqdm
+from PIL import Image
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
 def _save_hdf5(
-    img: PIL.Image.Image,
-    target: PIL.Image.Image,
-    mask: PIL.Image.Image,
+    img: Image,
+    target: Image,
+    mask: Image,
     hdf5_path: pathlib.Path,
 ) -> None:
     """Save prediction image, target and mask in an hdf5 file.
@@ -37,6 +35,9 @@ def _save_hdf5(
         File in which to save the data.
     """
 
+    import h5py
+    from tqdm import tqdm
+
     tqdm.write(f"Saving {hdf5_path}...")
     hdf5_path.parent.mkdir(parents=True, exist_ok=True)
     with h5py.File(hdf5_path, "w") as f:
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
index 5089ccab..308d2e30 100644
--- a/src/mednet/libs/segmentation/scripts/train.py
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -2,13 +2,7 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.train import (
-    get_checkpoint_file,
-    load_checkpoint,
-    reusable_options,
-    save_json_data,
-    setup_datamodule,
-)
+from mednet.libs.common.scripts.train import reusable_options
 
 logger = setup("mednet", format="%(levelname)s: %(message)s")
 
@@ -54,6 +48,12 @@ def train(
     from lightning.pytorch import seed_everything
     from mednet.libs.common.engine.device import DeviceManager
     from mednet.libs.common.engine.trainer import run
+    from mednet.libs.common.scripts.train import (
+        get_checkpoint_file,
+        load_checkpoint,
+        save_json_data,
+        setup_datamodule,
+    )
 
     seed_everything(seed)
 
-- 
GitLab