Skip to content
Snippets Groups Projects
Commit 1c5061b9 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Simplify toolchain to use ruff instead of black/isort/docformatter; Update all...

Simplify toolchain to use ruff instead of black/isort/docformatter; Update all files to match setup rules
parent ceb8edc2
Branches
Tags
No related merge requests found
Pipeline #85018 failed
...@@ -7,22 +7,17 @@ ...@@ -7,22 +7,17 @@
# See https://pre-commit.com/hooks.html for more hooks # See https://pre-commit.com/hooks.html for more hooks
repos: repos:
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2 rev: v0.3.3
hooks: hooks:
- id: ruff - id: ruff
args: [ --fix ] args: [ --fix ]
- id: ruff-format - id: ruff-format
- repo: https://github.com/pycqa/docformatter
rev: v1.7.5
hooks:
- id: docformatter
args: [ --wrap-summaries=0 ]
- repo: https://github.com/numpy/numpydoc - repo: https://github.com/numpy/numpydoc
rev: v1.6.0 rev: v1.6.0
hooks: hooks:
- id: numpydoc-validation - id: numpydoc-validation
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0 rev: v1.9.0
hooks: hooks:
- id: mypy - id: mypy
args: [ --install-types, --non-interactive, --no-strict-optional, --ignore-missing-imports ] args: [ --install-types, --non-interactive, --no-strict-optional, --ignore-missing-imports ]
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
from functools import reduce
import torch
from torch.nn.modules.module import _addindent
# ignore this space!
def _repr(model: torch.nn.Module) -> tuple[str, int]:
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = model.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split("\n")
child_lines = []
total_params = 0
for key, module in model._modules.items():
mod_str, num_params = _repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append("(" + key + "): " + mod_str)
total_params += num_params
lines = extra_lines + child_lines
for _, p in model._parameters.items():
if hasattr(p, "dtype"):
total_params += reduce(lambda x, y: x * y, p.shape)
main_str = model._get_name() + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
main_str += f", {total_params:,} params"
return main_str, total_params
def summary(model: torch.nn.Module) -> tuple[str, int]:
"""Count the number of parameters in each model layer.
Parameters
----------
model
Model to summarize.
Returns
-------
tuple[int, str]
A tuple containing a multiline string representation of the network and the number of parameters.
"""
return _repr(model)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment