From 501be8c7f4631f7183a2555adefc6794080eb48c Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 11 Apr 2024 15:19:56 +0200
Subject: [PATCH] [logging] Hardcode logger names

Since the core module and mednet-specific code have been split in
separate modules,getLogger(__name__) was returning different results.
The logger names have been hardcoded for now to fix issues during
testing but a better solution needs to be found.
---
 src/medbase/data/augmentations.py              | 2 +-
 src/medbase/data/datamodule.py                 | 2 +-
 src/medbase/data/split.py                      | 2 +-
 src/medbase/engine/callbacks.py                | 2 +-
 src/medbase/engine/device.py                   | 2 +-
 src/medbase/engine/trainer.py                  | 2 +-
 src/medbase/utils/checkpointer.py              | 2 +-
 src/medbase/utils/resources.py                 | 2 +-
 src/mednet/engine/evaluator.py                 | 2 +-
 src/mednet/engine/predictor.py                 | 3 +--
 src/mednet/engine/saliency/completeness.py     | 7 +++----
 src/mednet/engine/saliency/generator.py        | 3 +--
 src/mednet/engine/saliency/interpretability.py | 2 +-
 src/mednet/engine/saliency/viewer.py           | 2 +-
 src/mednet/models/alexnet.py                   | 3 +--
 src/mednet/models/densenet.py                  | 3 +--
 src/mednet/models/loss_weights.py              | 3 +--
 src/mednet/models/pasa.py                      | 3 +--
 src/mednet/scripts/train.py                    | 4 ++--
 src/mednet/scripts/utils.py                    | 2 +-
 20 files changed, 23 insertions(+), 30 deletions(-)

diff --git a/src/medbase/data/augmentations.py b/src/medbase/data/augmentations.py
index fcf8887d..ff3f438e 100644
--- a/src/medbase/data/augmentations.py
+++ b/src/medbase/data/augmentations.py
@@ -22,7 +22,7 @@ import numpy.typing
 import torch
 from scipy.ndimage import gaussian_filter, map_coordinates
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet.medbase")
 
 
 def _elastic_deformation_on_image(
diff --git a/src/medbase/data/datamodule.py b/src/medbase/data/datamodule.py
index ff627d03..1d6c14eb 100644
--- a/src/medbase/data/datamodule.py
+++ b/src/medbase/data/datamodule.py
@@ -27,7 +27,7 @@ from .typing import (
     TransformSequence,
 )
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet.medbase")
 
 
 def _sample_size_bytes(s: Sample) -> int:
diff --git a/src/medbase/data/split.py b/src/medbase/data/split.py
index 0bcdf300..3b412d96 100644
--- a/src/medbase/data/split.py
+++ b/src/medbase/data/split.py
@@ -14,7 +14,7 @@ import torch
 
 from medbase.data.typing import DatabaseSplit, RawDataLoader
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet.medbase")
 
 
 class JSONDatabaseSplit(DatabaseSplit):
diff --git a/src/medbase/engine/callbacks.py b/src/medbase/engine/callbacks.py
index 3d108a8d..bd4755eb 100644
--- a/src/medbase/engine/callbacks.py
+++ b/src/medbase/engine/callbacks.py
@@ -13,7 +13,7 @@ import torch
 
 from ..utils.resources import ResourceMonitor, aggregate
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet.medbase")
 
 
 class LoggingCallback(lightning.pytorch.Callback):
diff --git a/src/medbase/engine/device.py b/src/medbase/engine/device.py
index 4375cbb3..851b0539 100644
--- a/src/medbase/engine/device.py
+++ b/src/medbase/engine/device.py
@@ -9,7 +9,7 @@ import typing
 import torch
 import torch.backends
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet.medbase")
 
 
 SupportedPytorchDevice: typing.TypeAlias = typing.Literal[
diff --git a/src/medbase/engine/trainer.py b/src/medbase/engine/trainer.py
index 5d3f2a0d..79b509fe 100644
--- a/src/medbase/engine/trainer.py
+++ b/src/medbase/engine/trainer.py
@@ -15,7 +15,7 @@ from ..utils.resources import ResourceMonitor
 from .callbacks import LoggingCallback
 from .device import DeviceManager
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet.medbase")
 
 
 def run(
diff --git a/src/medbase/utils/checkpointer.py b/src/medbase/utils/checkpointer.py
index 19450a13..44e61da8 100644
--- a/src/medbase/utils/checkpointer.py
+++ b/src/medbase/utils/checkpointer.py
@@ -7,7 +7,7 @@ import pathlib
 import re
 import typing
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet.medbase")
 
 
 CHECKPOINT_ALIASES = {
diff --git a/src/medbase/utils/resources.py b/src/medbase/utils/resources.py
index 58167711..645bf76a 100644
--- a/src/medbase/utils/resources.py
+++ b/src/medbase/utils/resources.py
@@ -20,7 +20,7 @@ import psutil
 
 from medbase.engine.device import SupportedPytorchDevice
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 _nvidia_smi = shutil.which("nvidia-smi")
 """Location of the nvidia-smi program, if one exists."""
diff --git a/src/mednet/engine/evaluator.py b/src/mednet/engine/evaluator.py
index acbf79b3..598fb3b8 100644
--- a/src/mednet/engine/evaluator.py
+++ b/src/mednet/engine/evaluator.py
@@ -20,7 +20,7 @@ from matplotlib import pyplot as plt
 
 from ..models.typing import BinaryPrediction
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float:
diff --git a/src/mednet/engine/predictor.py b/src/mednet/engine/predictor.py
index fce844a1..5a5b3545 100644
--- a/src/mednet/engine/predictor.py
+++ b/src/mednet/engine/predictor.py
@@ -6,7 +6,6 @@ import logging
 
 import lightning.pytorch
 import torch.utils.data
-
 from medbase.engine.device import DeviceManager
 
 from ..models.typing import (
@@ -16,7 +15,7 @@ from ..models.typing import (
     MultiClassPredictionSplit,
 )
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 def run(
diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py
index 7c3cb3c6..84327ad4 100644
--- a/src/mednet/engine/saliency/completeness.py
+++ b/src/mednet/engine/saliency/completeness.py
@@ -11,18 +11,17 @@ import lightning.pytorch
 import numpy as np
 import torch
 import tqdm
+from medbase.data.typing import Sample
+from medbase.engine.device import DeviceManager
 from pytorch_grad_cam.metrics.road import (
     ROADLeastRelevantFirstAverage,
     ROADMostRelevantFirstAverage,
 )
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 
-from medbase.data.typing import Sample
-from medbase.engine.device import DeviceManager
-
 from ...models.typing import SaliencyMapAlgorithm
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 class SigmoidClassifierOutputTarget(torch.nn.Module):
diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py
index f440eb08..ddad6ce7 100644
--- a/src/mednet/engine/saliency/generator.py
+++ b/src/mednet/engine/saliency/generator.py
@@ -11,12 +11,11 @@ import numpy
 import torch
 import torch.nn
 import tqdm
-
 from medbase.engine.device import DeviceManager
 
 from ...models.typing import SaliencyMapAlgorithm
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 def _create_saliency_map_callable(
diff --git a/src/mednet/engine/saliency/interpretability.py b/src/mednet/engine/saliency/interpretability.py
index 54c72a42..621aa9d6 100644
--- a/src/mednet/engine/saliency/interpretability.py
+++ b/src/mednet/engine/saliency/interpretability.py
@@ -16,7 +16,7 @@ from tqdm import tqdm
 
 from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 SaliencyMap: typing.TypeAlias = (
     typing.Sequence[typing.Sequence[float]] | numpy.typing.NDArray[numpy.double]
diff --git a/src/mednet/engine/saliency/viewer.py b/src/mednet/engine/saliency/viewer.py
index d969aa91..588ba70c 100644
--- a/src/mednet/engine/saliency/viewer.py
+++ b/src/mednet/engine/saliency/viewer.py
@@ -17,7 +17,7 @@ from tqdm import tqdm
 
 from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 def _overlay_saliency_map(
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index 22c41547..5c686de3 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -11,13 +11,12 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
-
 from medbase.data.typing import TransformSequence
 from .model import Model
 from .separate import separate
 from .transforms import RGB, SquareCenterPad
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 class Alexnet(Model):
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index 3feb2ade..e9433b99 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -11,13 +11,12 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
-
 from medbase.data.typing import TransformSequence
 from .model import Model
 from .separate import separate
 from .transforms import RGB, SquareCenterPad
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 class Densenet(Model):
diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py
index 382a7ae5..2264a44d 100644
--- a/src/mednet/models/loss_weights.py
+++ b/src/mednet/models/loss_weights.py
@@ -8,10 +8,9 @@ from collections import Counter
 
 import torch
 import torch.utils.data
-
 from medbase.data.typing import DataLoader
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 def compute_binary_weights(targets):
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 0e1d24f1..f6e83d83 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -11,13 +11,12 @@ import torch.nn.functional as F  # noqa: N812
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
-
 from medbase.data.typing import TransformSequence
 from .model import Model
 from .separate import separate
 from .transforms import Grayscale, SquareCenterPad
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 class Pasa(Model):
diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py
index df8b1fec..2fd63221 100644
--- a/src/mednet/scripts/train.py
+++ b/src/mednet/scripts/train.py
@@ -12,7 +12,8 @@ from clapper.logging import setup
 
 from .click import ConfigCommand
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+# logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+logger = setup("mednet", format="%(levelname)s: %(message)s")
 
 
 def reusable_options(f):
@@ -271,7 +272,6 @@ def train(
 
     import torch
     from lightning.pytorch import seed_everything
-
     from medbase.engine.device import DeviceManager
     from medbase.engine.trainer import run
     from medbase.utils.checkpointer import get_checkpoint_to_resume_training
diff --git a/src/mednet/scripts/utils.py b/src/mednet/scripts/utils.py
index 4c894bd7..164f46b4 100644
--- a/src/mednet/scripts/utils.py
+++ b/src/mednet/scripts/utils.py
@@ -14,7 +14,7 @@ import lightning.pytorch.callbacks
 import torch.nn
 from medbase.engine.device import SupportedPytorchDevice
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 def model_summary(
-- 
GitLab