diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py index f792151dc4b489c0ef27526db2bcfcccd5852be3..1d196be6f79ea5c70987c1d1a66eaf32e8e7ca4c 100644 --- a/src/ptbench/configs/models/alexnet_pretrained.py +++ b/src/ptbench/configs/models/alexnet_pretrained.py @@ -2,24 +2,23 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""AlexNet. - -Pretrained AlexNet -""" +"""AlexNet.""" +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import SGD -from ...models.alexnet import build_alexnet +from ...models.alexnet import Alexnet # config -lr = 0.001 - -# model -model = build_alexnet(pretrained=True) +optimizer_configs = {"lr": 0.001, "momentum": 0.1} # optimizer -optimizer = SGD(model.parameters(), lr=lr, momentum=0.1) - +optimizer = "SGD" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = Alexnet( + criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True +)