Skip to content
Snippets Groups Projects
Commit e016b365 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[libs.common.models.model] Fix typing errors

parent b257cb10
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -49,7 +49,7 @@ class Model(pl.LightningModule): ...@@ -49,7 +49,7 @@ class Model(pl.LightningModule):
def __init__( def __init__(
self, self,
loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {}, loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {}, optimizer_arguments: dict[str, typing.Any] = {},
......
...@@ -19,13 +19,13 @@ class WeightedBCELogitsLoss(torch.nn.Module): ...@@ -19,13 +19,13 @@ class WeightedBCELogitsLoss(torch.nn.Module):
super().__init__() super().__init__()
def forward( def forward(
self, sample: torch.Tensor, target: torch.Tensor, mask: torch.Tensor self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass. """Forward pass.
Parameters Parameters
---------- ----------
sample tensor
Value produced by the model to be evaluated, with the shape ``[n, c, Value produced by the model to be evaluated, with the shape ``[n, c,
h, w]``. h, w]``.
target target
...@@ -46,7 +46,7 @@ class WeightedBCELogitsLoss(torch.nn.Module): ...@@ -46,7 +46,7 @@ class WeightedBCELogitsLoss(torch.nn.Module):
num_neg = valid.sum() - num_pos num_neg = valid.sum() - num_pos
pos_weight = num_neg / num_pos pos_weight = num_neg / num_pos
return torch.nn.functional.binary_cross_entropy_with_logits( return torch.nn.functional.binary_cross_entropy_with_logits(
sample[valid], tensor[valid],
target[valid], target[valid],
reduction="mean", reduction="mean",
pos_weight=pos_weight, pos_weight=pos_weight,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment