Skip to content
Snippets Groups Projects
Commit 1f64e682 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models.losses] Centralize all custom losses in a single place

parent fbc38f60
No related branches found
No related tags found
No related merge requests found
Showing
with 26 additions and 26 deletions
......@@ -15,12 +15,12 @@ import torch.optim
import torchvision.transforms
import torchvision.transforms.v2
import mednet.models.losses
import mednet.models.segment.driu
import mednet.models.segment.losses
import mednet.models.transforms
model = mednet.models.segment.driu.DRIU(
loss_type=mednet.models.segment.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_type=mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
optimizer_type=torch.optim.Adam,
optimizer_arguments=dict(lr=0.01),
......
......@@ -15,12 +15,12 @@ import torch.optim
import torchvision.transforms
import torchvision.transforms.v2
import mednet.models.losses
import mednet.models.segment.driu_bn
import mednet.models.segment.losses
import mednet.models.transforms
model = mednet.models.segment.driu_bn.DRIUBN(
loss_type=mednet.models.segment.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_type=mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
optimizer_type=torch.optim.Adam,
optimizer_arguments=dict(lr=0.01),
......
......@@ -15,12 +15,12 @@ import torch.optim
import torchvision.transforms
import torchvision.transforms.v2
import mednet.models.losses
import mednet.models.segment.driu_od
import mednet.models.segment.losses
import mednet.models.transforms
model = mednet.models.segment.driu_od.DRIUOD(
loss_type=mednet.models.segment.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_type=mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
optimizer_type=torch.optim.Adam,
optimizer_arguments=dict(lr=0.01),
......
......@@ -15,12 +15,12 @@ import torch.optim
import torchvision.transforms
import torchvision.transforms.v2
import mednet.models.losses
import mednet.models.segment.driu_pix
import mednet.models.segment.losses
import mednet.models.transforms
model = mednet.models.segment.driu_pix.DRIUPix(
loss_type=mednet.models.segment.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_type=mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
optimizer_type=torch.optim.Adam,
optimizer_arguments=dict(lr=0.01),
......
......@@ -11,12 +11,12 @@ import torch.optim
import torchvision.transforms
import torchvision.transforms.v2
import mednet.models.losses
import mednet.models.segment.hed
import mednet.models.segment.losses
import mednet.models.transforms
model = mednet.models.segment.hed.HED(
loss_type=mednet.models.segment.losses.MultiLayerSoftJaccardAndBCELogitsLoss,
loss_type=mednet.models.losses.MultiLayerSoftJaccardAndBCELogitsLoss,
loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
optimizer_type=torch.optim.Adam,
optimizer_arguments=dict(lr=0.01),
......
......@@ -14,7 +14,7 @@ import torch.optim.lr_scheduler
import torchvision.transforms
import torchvision.transforms.v2
import mednet.models.segment.losses
import mednet.models.losses
import mednet.models.segment.lwnet
import mednet.models.transforms
......@@ -25,7 +25,7 @@ min_lr = 1e-08 # valley
cycle = 100 # 1/2 epochs for a complete scheduling cycle
model = mednet.models.segment.lwnet.LittleWNet(
loss_type=mednet.models.segment.losses.MultiLayerBCELogitsLossWeightedPerBatch,
loss_type=mednet.models.losses.MultiLayerBCELogitsLossWeightedPerBatch,
loss_arguments=dict(),
optimizer_type=torch.optim.Adam,
optimizer_arguments=dict(lr=max_lr),
......
......@@ -19,12 +19,12 @@ import torch.optim
import torchvision.transforms
import torchvision.transforms.v2
import mednet.models.segment.losses
import mednet.models.losses
import mednet.models.segment.m2unet
import mednet.models.transforms
model = mednet.models.segment.m2unet.M2Unet(
loss_type=mednet.models.segment.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_type=mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
optimizer_type=torch.optim.Adam,
optimizer_arguments=dict(lr=0.01),
......
......@@ -17,12 +17,12 @@ import torch.optim
import torchvision.transforms
import torchvision.transforms.v2
import mednet.models.segment.losses
import mednet.models.losses
import mednet.models.segment.unet
import mednet.models.transforms
model = mednet.models.segment.unet.Unet(
loss_type=mednet.models.segment.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_type=mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss,
loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
optimizer_type=torch.optim.Adam,
optimizer_arguments=dict(lr=0.01),
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Specialized losses for semanatic segmentation."""
"""Specialized losses for semantic segmentation."""
import torch
......
......@@ -11,8 +11,8 @@ import torch.nn
import torch.utils.data
from ...data.typing import TransformSequence
from ..losses import SoftJaccardAndBCEWithLogitsLoss
from .backbones.vgg import vgg16_for_segmentation
from .losses import SoftJaccardAndBCEWithLogitsLoss
from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
from .model import Model
......
......@@ -11,8 +11,8 @@ import torch.nn
import torch.utils.data
from ...data.typing import TransformSequence
from ..losses import SoftJaccardAndBCEWithLogitsLoss
from .backbones.vgg import vgg16_for_segmentation
from .losses import SoftJaccardAndBCEWithLogitsLoss
from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
from .model import Model
......
......@@ -11,9 +11,9 @@ import torch.nn
import torch.utils.data
from ...data.typing import TransformSequence
from ..losses import SoftJaccardAndBCEWithLogitsLoss
from .backbones.vgg import vgg16_for_segmentation
from .driu import ConcatFuseBlock
from .losses import SoftJaccardAndBCEWithLogitsLoss
from .make_layers import UpsampleCropBlock
from .model import Model
......
......@@ -11,9 +11,9 @@ import torch.nn
import torch.utils.data
from ...data.typing import TransformSequence
from ..losses import SoftJaccardAndBCEWithLogitsLoss
from .backbones.vgg import vgg16_for_segmentation
from .driu import ConcatFuseBlock
from .losses import SoftJaccardAndBCEWithLogitsLoss
from .make_layers import UpsampleCropBlock
from .model import Model
......
......@@ -11,8 +11,8 @@ import torch.nn
import torch.utils.data
from ...data.typing import TransformSequence
from ..losses import MultiLayerSoftJaccardAndBCELogitsLoss
from .backbones.vgg import vgg16_for_segmentation
from .losses import MultiLayerSoftJaccardAndBCELogitsLoss
from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
from .model import Model
......
......@@ -15,7 +15,7 @@ import torch
import torch.nn
from ...data.typing import TransformSequence
from .losses import MultiLayerBCELogitsLossWeightedPerBatch
from ..losses import MultiLayerBCELogitsLossWeightedPerBatch
from .model import Model
logger = logging.getLogger(__name__)
......
......@@ -12,8 +12,8 @@ import torch.utils.data
from torchvision.models.mobilenetv2 import InvertedResidual
from ...data.typing import TransformSequence
from ..losses import SoftJaccardAndBCEWithLogitsLoss
from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
from .losses import SoftJaccardAndBCEWithLogitsLoss
from .model import Model
logger = logging.getLogger(__name__)
......
......@@ -12,8 +12,8 @@ import torch.optim.optimizer
import torch.utils.data
from ...data.typing import TransformSequence
from ..losses import MultiLayerBCELogitsLossWeightedPerBatch
from ..model import Model as BaseModel
from .losses import MultiLayerBCELogitsLossWeightedPerBatch
logger = logging.getLogger(__name__)
......
......@@ -10,8 +10,8 @@ import torch.nn
import torch.utils.data
from ...data.typing import TransformSequence
from ..losses import SoftJaccardAndBCEWithLogitsLoss
from .backbones.vgg import vgg16_for_segmentation
from .losses import SoftJaccardAndBCEWithLogitsLoss
from .make_layers import UnetBlock, conv_with_kaiming_uniform
from .model import Model
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment