From 501be8c7f4631f7183a2555adefc6794080eb48c Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Thu, 11 Apr 2024 15:19:56 +0200 Subject: [PATCH] [logging] Hardcode logger names Since the core module and mednet-specific code have been split in separate modules,getLogger(__name__) was returning different results. The logger names have been hardcoded for now to fix issues during testing but a better solution needs to be found. --- src/medbase/data/augmentations.py | 2 +- src/medbase/data/datamodule.py | 2 +- src/medbase/data/split.py | 2 +- src/medbase/engine/callbacks.py | 2 +- src/medbase/engine/device.py | 2 +- src/medbase/engine/trainer.py | 2 +- src/medbase/utils/checkpointer.py | 2 +- src/medbase/utils/resources.py | 2 +- src/mednet/engine/evaluator.py | 2 +- src/mednet/engine/predictor.py | 3 +-- src/mednet/engine/saliency/completeness.py | 7 +++---- src/mednet/engine/saliency/generator.py | 3 +-- src/mednet/engine/saliency/interpretability.py | 2 +- src/mednet/engine/saliency/viewer.py | 2 +- src/mednet/models/alexnet.py | 3 +-- src/mednet/models/densenet.py | 3 +-- src/mednet/models/loss_weights.py | 3 +-- src/mednet/models/pasa.py | 3 +-- src/mednet/scripts/train.py | 4 ++-- src/mednet/scripts/utils.py | 2 +- 20 files changed, 23 insertions(+), 30 deletions(-) diff --git a/src/medbase/data/augmentations.py b/src/medbase/data/augmentations.py index fcf8887d..ff3f438e 100644 --- a/src/medbase/data/augmentations.py +++ b/src/medbase/data/augmentations.py @@ -22,7 +22,7 @@ import numpy.typing import torch from scipy.ndimage import gaussian_filter, map_coordinates -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet.medbase") def _elastic_deformation_on_image( diff --git a/src/medbase/data/datamodule.py b/src/medbase/data/datamodule.py index ff627d03..1d6c14eb 100644 --- a/src/medbase/data/datamodule.py +++ b/src/medbase/data/datamodule.py @@ -27,7 +27,7 @@ from .typing import ( TransformSequence, ) -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet.medbase") def _sample_size_bytes(s: Sample) -> int: diff --git a/src/medbase/data/split.py b/src/medbase/data/split.py index 0bcdf300..3b412d96 100644 --- a/src/medbase/data/split.py +++ b/src/medbase/data/split.py @@ -14,7 +14,7 @@ import torch from medbase.data.typing import DatabaseSplit, RawDataLoader -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet.medbase") class JSONDatabaseSplit(DatabaseSplit): diff --git a/src/medbase/engine/callbacks.py b/src/medbase/engine/callbacks.py index 3d108a8d..bd4755eb 100644 --- a/src/medbase/engine/callbacks.py +++ b/src/medbase/engine/callbacks.py @@ -13,7 +13,7 @@ import torch from ..utils.resources import ResourceMonitor, aggregate -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet.medbase") class LoggingCallback(lightning.pytorch.Callback): diff --git a/src/medbase/engine/device.py b/src/medbase/engine/device.py index 4375cbb3..851b0539 100644 --- a/src/medbase/engine/device.py +++ b/src/medbase/engine/device.py @@ -9,7 +9,7 @@ import typing import torch import torch.backends -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet.medbase") SupportedPytorchDevice: typing.TypeAlias = typing.Literal[ diff --git a/src/medbase/engine/trainer.py b/src/medbase/engine/trainer.py index 5d3f2a0d..79b509fe 100644 --- a/src/medbase/engine/trainer.py +++ b/src/medbase/engine/trainer.py @@ -15,7 +15,7 @@ from ..utils.resources import ResourceMonitor from .callbacks import LoggingCallback from .device import DeviceManager -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet.medbase") def run( diff --git a/src/medbase/utils/checkpointer.py b/src/medbase/utils/checkpointer.py index 19450a13..44e61da8 100644 --- a/src/medbase/utils/checkpointer.py +++ b/src/medbase/utils/checkpointer.py @@ -7,7 +7,7 @@ import pathlib import re import typing -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet.medbase") CHECKPOINT_ALIASES = { diff --git a/src/medbase/utils/resources.py b/src/medbase/utils/resources.py index 58167711..645bf76a 100644 --- a/src/medbase/utils/resources.py +++ b/src/medbase/utils/resources.py @@ -20,7 +20,7 @@ import psutil from medbase.engine.device import SupportedPytorchDevice -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") _nvidia_smi = shutil.which("nvidia-smi") """Location of the nvidia-smi program, if one exists.""" diff --git a/src/mednet/engine/evaluator.py b/src/mednet/engine/evaluator.py index acbf79b3..598fb3b8 100644 --- a/src/mednet/engine/evaluator.py +++ b/src/mednet/engine/evaluator.py @@ -20,7 +20,7 @@ from matplotlib import pyplot as plt from ..models.typing import BinaryPrediction -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float: diff --git a/src/mednet/engine/predictor.py b/src/mednet/engine/predictor.py index fce844a1..5a5b3545 100644 --- a/src/mednet/engine/predictor.py +++ b/src/mednet/engine/predictor.py @@ -6,7 +6,6 @@ import logging import lightning.pytorch import torch.utils.data - from medbase.engine.device import DeviceManager from ..models.typing import ( @@ -16,7 +15,7 @@ from ..models.typing import ( MultiClassPredictionSplit, ) -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") def run( diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py index 7c3cb3c6..84327ad4 100644 --- a/src/mednet/engine/saliency/completeness.py +++ b/src/mednet/engine/saliency/completeness.py @@ -11,18 +11,17 @@ import lightning.pytorch import numpy as np import torch import tqdm +from medbase.data.typing import Sample +from medbase.engine.device import DeviceManager from pytorch_grad_cam.metrics.road import ( ROADLeastRelevantFirstAverage, ROADMostRelevantFirstAverage, ) from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget -from medbase.data.typing import Sample -from medbase.engine.device import DeviceManager - from ...models.typing import SaliencyMapAlgorithm -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") class SigmoidClassifierOutputTarget(torch.nn.Module): diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index f440eb08..ddad6ce7 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -11,12 +11,11 @@ import numpy import torch import torch.nn import tqdm - from medbase.engine.device import DeviceManager from ...models.typing import SaliencyMapAlgorithm -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") def _create_saliency_map_callable( diff --git a/src/mednet/engine/saliency/interpretability.py b/src/mednet/engine/saliency/interpretability.py index 54c72a42..621aa9d6 100644 --- a/src/mednet/engine/saliency/interpretability.py +++ b/src/mednet/engine/saliency/interpretability.py @@ -16,7 +16,7 @@ from tqdm import tqdm from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") SaliencyMap: typing.TypeAlias = ( typing.Sequence[typing.Sequence[float]] | numpy.typing.NDArray[numpy.double] diff --git a/src/mednet/engine/saliency/viewer.py b/src/mednet/engine/saliency/viewer.py index d969aa91..588ba70c 100644 --- a/src/mednet/engine/saliency/viewer.py +++ b/src/mednet/engine/saliency/viewer.py @@ -17,7 +17,7 @@ from tqdm import tqdm from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") def _overlay_saliency_map( diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index 22c41547..5c686de3 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -11,13 +11,12 @@ import torch.optim.optimizer import torch.utils.data import torchvision.models as models import torchvision.transforms - from medbase.data.typing import TransformSequence from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") class Alexnet(Model): diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index 3feb2ade..e9433b99 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -11,13 +11,12 @@ import torch.optim.optimizer import torch.utils.data import torchvision.models as models import torchvision.transforms - from medbase.data.typing import TransformSequence from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") class Densenet(Model): diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py index 382a7ae5..2264a44d 100644 --- a/src/mednet/models/loss_weights.py +++ b/src/mednet/models/loss_weights.py @@ -8,10 +8,9 @@ from collections import Counter import torch import torch.utils.data - from medbase.data.typing import DataLoader -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") def compute_binary_weights(targets): diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index 0e1d24f1..f6e83d83 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -11,13 +11,12 @@ import torch.nn.functional as F # noqa: N812 import torch.optim.optimizer import torch.utils.data import torchvision.transforms - from medbase.data.typing import TransformSequence from .model import Model from .separate import separate from .transforms import Grayscale, SquareCenterPad -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") class Pasa(Model): diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py index df8b1fec..2fd63221 100644 --- a/src/mednet/scripts/train.py +++ b/src/mednet/scripts/train.py @@ -12,7 +12,8 @@ from clapper.logging import setup from .click import ConfigCommand -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +# logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +logger = setup("mednet", format="%(levelname)s: %(message)s") def reusable_options(f): @@ -271,7 +272,6 @@ def train( import torch from lightning.pytorch import seed_everything - from medbase.engine.device import DeviceManager from medbase.engine.trainer import run from medbase.utils.checkpointer import get_checkpoint_to_resume_training diff --git a/src/mednet/scripts/utils.py b/src/mednet/scripts/utils.py index 4c894bd7..164f46b4 100644 --- a/src/mednet/scripts/utils.py +++ b/src/mednet/scripts/utils.py @@ -14,7 +14,7 @@ import lightning.pytorch.callbacks import torch.nn from medbase.engine.device import SupportedPytorchDevice -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") def model_summary( -- GitLab