From ff712757433dc512d1ad1f7fc610c6f0edc549e8 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 11 Apr 2023 13:16:42 +0200 Subject: [PATCH] Moved alexnet_pretrained to lightning --- .../configs/models/alexnet_pretrained.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py index f792151d..1d196be6 100644 --- a/src/ptbench/configs/models/alexnet_pretrained.py +++ b/src/ptbench/configs/models/alexnet_pretrained.py @@ -2,24 +2,23 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""AlexNet. - -Pretrained AlexNet -""" +"""AlexNet.""" +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import SGD -from ...models.alexnet import build_alexnet +from ...models.alexnet import Alexnet # config -lr = 0.001 - -# model -model = build_alexnet(pretrained=True) +optimizer_configs = {"lr": 0.001, "momentum": 0.1} # optimizer -optimizer = SGD(model.parameters(), lr=lr, momentum=0.1) - +optimizer = "SGD" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = Alexnet( + criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True +) -- GitLab