From ff712757433dc512d1ad1f7fc610c6f0edc549e8 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Apr 2023 13:16:42 +0200
Subject: [PATCH] Moved alexnet_pretrained to lightning

---
 .../configs/models/alexnet_pretrained.py      | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py
index f792151d..1d196be6 100644
--- a/src/ptbench/configs/models/alexnet_pretrained.py
+++ b/src/ptbench/configs/models/alexnet_pretrained.py
@@ -2,24 +2,23 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""AlexNet.
-
-Pretrained AlexNet
-"""
+"""AlexNet."""
 
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import SGD
 
-from ...models.alexnet import build_alexnet
+from ...models.alexnet import Alexnet
 
 # config
-lr = 0.001
-
-# model
-model = build_alexnet(pretrained=True)
+optimizer_configs = {"lr": 0.001, "momentum": 0.1}
 
 # optimizer
-optimizer = SGD(model.parameters(), lr=lr, momentum=0.1)
-
+optimizer = "SGD"
 # criterion
-criterion = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = Alexnet(
+    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
+)
-- 
GitLab