diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4e949b9a46695591b2d327ec8ff7aa4414565611..079c8eae0595ff9c516a89d56b8461993a17d68b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,22 +7,17 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.2 + rev: v0.3.3 hooks: - id: ruff args: [ --fix ] - id: ruff-format - - repo: https://github.com/pycqa/docformatter - rev: v1.7.5 - hooks: - - id: docformatter - args: [ --wrap-summaries=0 ] - repo: https://github.com/numpy/numpydoc rev: v1.6.0 hooks: - id: numpydoc-validation - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.9.0 hooks: - id: mypy args: [ --install-types, --non-interactive, --no-strict-optional, --ignore-missing-imports ] diff --git a/src/mednet/config/data/tbx11k/make_splits_from_database.py b/helpers/tbx11k_make_splits.py similarity index 100% rename from src/mednet/config/data/tbx11k/make_splits_from_database.py rename to helpers/tbx11k_make_splits.py diff --git a/src/mednet/utils/summary.py b/src/mednet/utils/summary.py deleted file mode 100644 index bff705e30b557a0314dfef929535671e3dad7f81..0000000000000000000000000000000000000000 --- a/src/mednet/utils/summary.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488 - -from functools import reduce - -import torch - -from torch.nn.modules.module import _addindent - - -# ignore this space! -def _repr(model: torch.nn.Module) -> tuple[str, int]: - # We treat the extra repr like the sub-module, one item per line - extra_lines = [] - extra_repr = model.extra_repr() - # empty string will be split into list [''] - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - total_params = 0 - for key, module in model._modules.items(): - mod_str, num_params = _repr(module) - mod_str = _addindent(mod_str, 2) - child_lines.append("(" + key + "): " + mod_str) - total_params += num_params - lines = extra_lines + child_lines - - for _, p in model._parameters.items(): - if hasattr(p, "dtype"): - total_params += reduce(lambda x, y: x * y, p.shape) - - main_str = model._get_name() + "(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - main_str += f", {total_params:,} params" - return main_str, total_params - - -def summary(model: torch.nn.Module) -> tuple[str, int]: - """Count the number of parameters in each model layer. - - Parameters - ---------- - model - Model to summarize. - - Returns - ------- - tuple[int, str] - A tuple containing a multiline string representation of the network and the number of parameters. - """ - return _repr(model)