diff --git a/src/ptbench/configs/models/densenet_pretrained.py b/src/ptbench/configs/models/densenet_pretrained.py index b018a52203061b847cdae9f09b5edfa713930302..f8908fdb1e87a62df41dca0ecb75ff1fc79b1012 100644 --- a/src/ptbench/configs/models/densenet_pretrained.py +++ b/src/ptbench/configs/models/densenet_pretrained.py @@ -6,20 +6,30 @@ from torch import empty from torch.nn import BCEWithLogitsLoss +from torch.optim import Adam from ...models.densenet import Densenet -# config -optimizer_configs = {"lr": 0.01} - # optimizer -optimizer = "Adam" +optimizer = Adam +optimizer_configs = {"lr": 0.0001} # criterion criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) +from ...data.transforms import ElasticDeformation + +augmentation_transforms = [ + ElasticDeformation(p=0.8), +] + # model model = Densenet( - criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True + criterion, + criterion_valid, + optimizer, + optimizer_configs, + pretrained=True, + augmentation_transforms=augmentation_transforms, )