# densenet_pretrained.py
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""DenseNet."""

from torch import empty, ones
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

from ...data.transforms import ElasticDeformation
from ...models.densenet import Densenet

# Optimizer: the class itself (not an instance) plus its keyword arguments are
# handed to the model below. NOTE(review): presumably Densenet instantiates it
# over its own parameters using ``optimizer_configs`` -- confirm in Densenet.
optimizer = Adam
optimizer_configs = {"lr": 0.0001}

# Criteria for training and validation. ``pos_weight`` is a one-element
# placeholder; NOTE(review): presumably replaced by the training pipeline with
# dataset-derived positive-class weights -- confirm. It is initialised with
# ones (plain, unweighted BCE) rather than ``torch.empty``, whose contents are
# uninitialised memory, so the loss stays well-defined even if the placeholder
# is never overwritten.
criterion = BCEWithLogitsLoss(pos_weight=ones(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=ones(1))

# Data augmentation passed to the model. NOTE(review): ``p=0.8`` is presumably
# the probability of applying the elastic deformation to a sample -- confirm in
# ElasticDeformation.
augmentation_transforms = [
    ElasticDeformation(p=0.8),
]

# Model: DenseNet with ``pretrained=True``. NOTE(review): presumably this loads
# pre-trained backbone weights -- confirm in Densenet.
model = Densenet(
    criterion,
    criterion_valid,
    optimizer,
    optimizer_configs,
    pretrained=True,
    augmentation_transforms=augmentation_transforms,
)