Skip to content
Snippets Groups Projects
Commit e016b365 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[libs.common.models.model] Fix typing errors

parent b257cb10
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -49,7 +49,7 @@ class Model(pl.LightningModule):
def __init__(
self,
loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
......
......@@ -19,13 +19,13 @@ class WeightedBCELogitsLoss(torch.nn.Module):
super().__init__()
def forward(
self, sample: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Forward pass.
Parameters
----------
sample
tensor
Value produced by the model to be evaluated, with the shape ``[n, c,
h, w]``.
target
......@@ -46,7 +46,7 @@ class WeightedBCELogitsLoss(torch.nn.Module):
num_neg = valid.sum() - num_pos
pos_weight = num_neg / num_pos
return torch.nn.functional.binary_cross_entropy_with_logits(
sample[valid],
tensor[valid],
target[valid],
reduction="mean",
pos_weight=pos_weight,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment