From e016b365c9204c6e02d681d48e9a5bf5858c1652 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 26 Jun 2024 10:39:29 +0200 Subject: [PATCH] [libs.common.models.model] Fix typing errors --- src/mednet/libs/common/models/model.py | 2 +- src/mednet/libs/segmentation/models/losses.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py index ae4e658f..fb4c6e18 100644 --- a/src/mednet/libs/common/models/model.py +++ b/src/mednet/libs/common/models/model.py @@ -49,7 +49,7 @@ class Model(pl.LightningModule): def __init__( self, - loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss, loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, diff --git a/src/mednet/libs/segmentation/models/losses.py b/src/mednet/libs/segmentation/models/losses.py index 89945687..2a344dd4 100644 --- a/src/mednet/libs/segmentation/models/losses.py +++ b/src/mednet/libs/segmentation/models/losses.py @@ -19,13 +19,13 @@ class WeightedBCELogitsLoss(torch.nn.Module): super().__init__() def forward( - self, sample: torch.Tensor, target: torch.Tensor, mask: torch.Tensor + self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor ) -> torch.Tensor: """Forward pass. Parameters ---------- - sample + tensor Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. target @@ -46,7 +46,7 @@ class WeightedBCELogitsLoss(torch.nn.Module): num_neg = valid.sum() - num_pos pos_weight = num_neg / num_pos return torch.nn.functional.binary_cross_entropy_with_logits( - sample[valid], + tensor[valid], target[valid], reduction="mean", pos_weight=pos_weight, -- GitLab