diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py index ae4e658f37dc3d1638aebd9632974534bbf74352..fb4c6e18a8ebdfa5d6a016ff44839844f8e4484c 100644 --- a/src/mednet/libs/common/models/model.py +++ b/src/mednet/libs/common/models/model.py @@ -49,7 +49,7 @@ class Model(pl.LightningModule): def __init__( self, - loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss, loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, diff --git a/src/mednet/libs/segmentation/models/losses.py b/src/mednet/libs/segmentation/models/losses.py index 89945687ae73f11ed5df30a2e8d2600c6fdb937e..2a344dd461fb4294bab6cb83f0edfa60789486b3 100644 --- a/src/mednet/libs/segmentation/models/losses.py +++ b/src/mednet/libs/segmentation/models/losses.py @@ -19,13 +19,13 @@ class WeightedBCELogitsLoss(torch.nn.Module): super().__init__() def forward( - self, sample: torch.Tensor, target: torch.Tensor, mask: torch.Tensor + self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor ) -> torch.Tensor: """Forward pass. Parameters ---------- - sample + tensor Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. target @@ -46,7 +46,7 @@ class WeightedBCELogitsLoss(torch.nn.Module): num_neg = valid.sum() - num_pos pos_weight = num_neg / num_pos return torch.nn.functional.binary_cross_entropy_with_logits( - sample[valid], + tensor[valid], target[valid], reduction="mean", pos_weight=pos_weight,