Skip to content
Snippets Groups Projects
Commit 901c950a authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[common] Move reusable files to new common package

parent 2631da1d
No related branches found
No related tags found
1 merge request!46Create common library
Showing
with 78 additions and 14 deletions
File moved
File moved
File moved
...@@ -12,7 +12,7 @@ import typing ...@@ -12,7 +12,7 @@ import typing
import torch import torch
from .typing import DatabaseSplit, RawDataLoader from medbase.data.typing import DatabaseSplit, RawDataLoader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
File moved
File moved
File moved
File moved
File moved
...@@ -18,7 +18,7 @@ import warnings ...@@ -18,7 +18,7 @@ import warnings
import numpy import numpy
import psutil import psutil
from ..engine.device import SupportedPytorchDevice from medbase.engine.device import SupportedPytorchDevice
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
from functools import reduce
import torch
from torch.nn.modules.module import _addindent
# ignore this space!
def _repr(model: torch.nn.Module) -> tuple[str, int]:
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = model.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split("\n")
child_lines = []
total_params = 0
for key, module in model._modules.items():
mod_str, num_params = _repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append("(" + key + "): " + mod_str)
total_params += num_params
lines = extra_lines + child_lines
for _, p in model._parameters.items():
if hasattr(p, "dtype"):
total_params += reduce(lambda x, y: x * y, p.shape)
main_str = model._get_name() + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
main_str += f", {total_params:,} params"
return main_str, total_params
def summary(model: torch.nn.Module) -> tuple[str, int]:
"""Count the number of parameters in each model layer.
Parameters
----------
model
Model to summarize.
Returns
-------
tuple[int, str]
A tuple containing a multiline string representation of the network and the number of parameters.
"""
return _repr(model)
File moved
...@@ -12,11 +12,12 @@ import pathlib ...@@ -12,11 +12,12 @@ import pathlib
import PIL.Image import PIL.Image
from torchvision.transforms.functional import to_tensor from torchvision.transforms.functional import to_tensor
from ....data.datamodule import CachingDataModule from medbase.data.datamodule import CachingDataModule
from ....data.image_utils import remove_black_borders from medbase.data.image_utils import remove_black_borders
from ....data.split import make_split from medbase.data.split import make_split
from ....data.typing import RawDataLoader as _BaseRawDataLoader from medbase.data.typing import RawDataLoader as _BaseRawDataLoader
from ....data.typing import Sample from medbase.data.typing import Sample
from ....utils.rc import load_rc from ....utils.rc import load_rc
CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
......
...@@ -8,9 +8,10 @@ Database reference: [INDIAN-2013]_ ...@@ -8,9 +8,10 @@ Database reference: [INDIAN-2013]_
import pathlib import pathlib
from medbase.data.datamodule import CachingDataModule
from medbase.data.split import make_split
from ....config.data.shenzhen.datamodule import RawDataLoader from ....config.data.shenzhen.datamodule import RawDataLoader
from ....data.datamodule import CachingDataModule
from ....data.split import make_split
CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
"""Key to search for in the configuration file for the root directory of this """Key to search for in the configuration file for the root directory of this
......
...@@ -12,11 +12,12 @@ import pathlib ...@@ -12,11 +12,12 @@ import pathlib
import PIL.Image import PIL.Image
from torchvision.transforms.functional import to_tensor from torchvision.transforms.functional import to_tensor
from ....data.datamodule import CachingDataModule from medbase.data.datamodule import CachingDataModule
from ....data.image_utils import remove_black_borders from medbase.data.image_utils import remove_black_borders
from ....data.split import make_split from medbase.data.split import make_split
from ....data.typing import RawDataLoader as _BaseRawDataLoader from medbase.data.typing import Sample
from ....data.typing import Sample from medbase.data.typing import RawDataLoader as _BaseRawDataLoader
from ....utils.rc import load_rc from ....utils.rc import load_rc
CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment