From cfd3773e3d3bc85f3cd7661f144f6a551eb28d4a Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 12 Jul 2023 15:22:15 +0200
Subject: [PATCH] Update model configs

---
 src/ptbench/configs/models/alexnet.py         | 27 +++++--------------
 .../configs/models/alexnet_pretrained.py      | 27 +++++--------------
 src/ptbench/configs/models/densenet.py        | 27 +++++--------------
 .../configs/models/densenet_pretrained.py     | 27 +++++--------------
 4 files changed, 24 insertions(+), 84 deletions(-)

diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py
index 2361b886..815226b5 100644
--- a/src/ptbench/configs/models/alexnet.py
+++ b/src/ptbench/configs/models/alexnet.py
@@ -4,32 +4,17 @@
 
 """AlexNet."""
 
-from torch import empty
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import SGD
 
-from ...models.alexnet import Alexnet
-
-# optimizer
-optimizer = SGD
-optimizer_configs = {"lr": 0.01, "momentum": 0.1}
-
-# criterion
-criterion = BCEWithLogitsLoss(pos_weight=empty(1))
-criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
-
 from ...data.transforms import ElasticDeformation
+from ...models.alexnet import Alexnet
 
-augmentation_transforms = [
-    ElasticDeformation(p=0.8),
-]
-
-# model
 model = Alexnet(
-    criterion,
-    criterion_valid,
-    optimizer,
-    optimizer_configs,
+    train_loss=BCEWithLogitsLoss(),
+    validation_loss=BCEWithLogitsLoss(),
+    optimizer_type=SGD,
+    optimizer_arguments=dict(lr=0.01, momentum=0.1),
+    augmentation_transforms=[ElasticDeformation(p=0.8)],
     pretrained=False,
-    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py
index 0dc7e5d6..f968df50 100644
--- a/src/ptbench/configs/models/alexnet_pretrained.py
+++ b/src/ptbench/configs/models/alexnet_pretrained.py
@@ -4,32 +4,17 @@
 
 """AlexNet."""
 
-from torch import empty
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import SGD
 
-from ...models.alexnet import Alexnet
-
-# optimizer
-optimizer = SGD
-optimizer_configs = {"lr": 0.01, "momentum": 0.1}
-
-# criterion
-criterion = BCEWithLogitsLoss(pos_weight=empty(1))
-criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
-
 from ...data.transforms import ElasticDeformation
+from ...models.alexnet import Alexnet
 
-augmentation_transforms = [
-    ElasticDeformation(p=0.8),
-]
-
-# model
 model = Alexnet(
-    criterion,
-    criterion_valid,
-    optimizer,
-    optimizer_configs,
+    train_loss=BCEWithLogitsLoss(),
+    validation_loss=BCEWithLogitsLoss(),
+    optimizer_type=SGD,
+    optimizer_arguments=dict(lr=0.01, momentum=0.1),
+    augmentation_transforms=[ElasticDeformation(p=0.8)],
     pretrained=True,
-    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py
index 5d612b2a..79f8f7da 100644
--- a/src/ptbench/configs/models/densenet.py
+++ b/src/ptbench/configs/models/densenet.py
@@ -4,32 +4,17 @@
 
 """DenseNet."""
 
-from torch import empty
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from ...models.densenet import Densenet
-
-# optimizer
-optimizer = Adam
-optimizer_configs = {"lr": 0.0001}
-
-# criterion
-criterion = BCEWithLogitsLoss(pos_weight=empty(1))
-criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
-
 from ...data.transforms import ElasticDeformation
+from ...models.densenet import Densenet
 
-augmentation_transforms = [
-    ElasticDeformation(p=0.8),
-]
-
-# model
 model = Densenet(
-    criterion,
-    criterion_valid,
-    optimizer,
-    optimizer_configs,
+    train_loss=BCEWithLogitsLoss(),
+    validation_loss=BCEWithLogitsLoss(),
+    optimizer_type=Adam,
+    optimizer_arguments=dict(lr=0.0001),
+    augmentation_transforms=[ElasticDeformation(p=0.8)],
     pretrained=False,
-    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/configs/models/densenet_pretrained.py b/src/ptbench/configs/models/densenet_pretrained.py
index f8908fdb..4bc4616c 100644
--- a/src/ptbench/configs/models/densenet_pretrained.py
+++ b/src/ptbench/configs/models/densenet_pretrained.py
@@ -4,32 +4,17 @@
 
 """DenseNet."""
 
-from torch import empty
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from ...models.densenet import Densenet
-
-# optimizer
-optimizer = Adam
-optimizer_configs = {"lr": 0.0001}
-
-# criterion
-criterion = BCEWithLogitsLoss(pos_weight=empty(1))
-criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
-
 from ...data.transforms import ElasticDeformation
+from ...models.densenet import Densenet
 
-augmentation_transforms = [
-    ElasticDeformation(p=0.8),
-]
-
-# model
 model = Densenet(
-    criterion,
-    criterion_valid,
-    optimizer,
-    optimizer_configs,
+    train_loss=BCEWithLogitsLoss(),
+    validation_loss=BCEWithLogitsLoss(),
+    optimizer_type=Adam,
+    optimizer_arguments=dict(lr=0.0001),
+    augmentation_transforms=[ElasticDeformation(p=0.8)],
     pretrained=True,
-    augmentation_transforms=augmentation_transforms,
 )
-- 
GitLab