Skip to content
Snippets Groups Projects
pasa.py 1.11 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""CNN for Tuberculosis Detection.

Configuration of the PASA model, the architecture proposed by F. Pasa in the
article "Efficient Deep Network Architectures for Fast Chest X-Ray
Tuberculosis Screening and Visualization", together with the optimizer, loss
criteria and data augmentation used to train it.

Reference: [PASA-2019]_
"""

from torch import empty
from torch import ones
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

from ...data.transforms import ElasticDeformation
from ...models.pasa import PASA

# optimizer: the Adam *class* (not an instance) is handed to the model
# constructor together with its keyword arguments.
optimizer = Adam
optimizer_configs = {"lr": 8e-5}

# criterion: binary cross-entropy on raw logits, with separate instances for
# training and validation.
# ``pos_weight`` is a one-element placeholder.  NOTE(review): it appears to
# be meant for overwriting elsewhere (e.g. with a class-balancing weight);
# confirm where.  It is initialised to 1.0, which applies no re-weighting and
# reduces to plain BCE.  ``torch.empty`` would leave it as uninitialised
# memory, making the loss undefined until (or unless) it is overwritten.
criterion = BCEWithLogitsLoss(pos_weight=ones(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=ones(1))

# data augmentation, passed to the model through its
# ``augmentation_transforms`` keyword argument.
# NOTE(review): ``p=0.8`` is presumably the probability of applying the
# deformation to a sample; confirm against ElasticDeformation.
augmentation_transforms = [ElasticDeformation(p=0.8)]

# model: the PASA network, wired to the loss criteria, the optimizer class
# and the optimizer keyword arguments defined earlier in this module.
model = PASA(
    criterion,
    criterion_valid,
    optimizer,
    optimizer_configs,
    augmentation_transforms=augmentation_transforms,
)