From 11560cc6feadb677dc9ed1f7d65fb5f06fcd0fb7 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 17 Jun 2024 10:10:18 +0200
Subject: [PATCH] [classification.config] Add model_transforms to model configs

---
 src/mednet/libs/classification/config/models/alexnet.py         | 1 +
 src/mednet/libs/classification/config/models/densenet.py        | 2 ++
 .../libs/classification/config/models/densenet_pretrained.py    | 1 +
 src/mednet/libs/classification/config/models/densenet_rs.py     | 2 ++
 src/mednet/libs/classification/config/models/pasa.py            | 1 +
 5 files changed, 7 insertions(+)

diff --git a/src/mednet/libs/classification/config/models/alexnet.py b/src/mednet/libs/classification/config/models/alexnet.py
index 1ba9e082..702f9ed0 100644
--- a/src/mednet/libs/classification/config/models/alexnet.py
+++ b/src/mednet/libs/classification/config/models/alexnet.py
@@ -24,4 +24,5 @@ model = Alexnet(
         torchvision.transforms.Resize(512, antialias=True),
         RGB(),
     ],
+    augmentation_transforms=[],
 )
diff --git a/src/mednet/libs/classification/config/models/densenet.py b/src/mednet/libs/classification/config/models/densenet.py
index bac5868c..65f7e90a 100644
--- a/src/mednet/libs/classification/config/models/densenet.py
+++ b/src/mednet/libs/classification/config/models/densenet.py
@@ -18,4 +18,6 @@ model = Densenet(
     optimizer_arguments=dict(lr=0.0001),
     pretrained=False,
     dropout=0.1,
+    model_transforms=[],
+    augmentation_transforms=[],
 )
diff --git a/src/mednet/libs/classification/config/models/densenet_pretrained.py b/src/mednet/libs/classification/config/models/densenet_pretrained.py
index 90fdd8c9..8d7103ae 100644
--- a/src/mednet/libs/classification/config/models/densenet_pretrained.py
+++ b/src/mednet/libs/classification/config/models/densenet_pretrained.py
@@ -31,4 +31,5 @@ model = Densenet(
         ),
         RGB(),
     ],
+    augmentation_transforms=[],
 )
diff --git a/src/mednet/libs/classification/config/models/densenet_rs.py b/src/mednet/libs/classification/config/models/densenet_rs.py
index b3620d68..9d300875 100644
--- a/src/mednet/libs/classification/config/models/densenet_rs.py
+++ b/src/mednet/libs/classification/config/models/densenet_rs.py
@@ -20,4 +20,6 @@ model = Densenet(
     pretrained=False,
     dropout=0.1,
     num_classes=14,  # number of classes in NIH CXR-14
+    model_transforms=[],
+    augmentation_transforms=[],
 )
diff --git a/src/mednet/libs/classification/config/models/pasa.py b/src/mednet/libs/classification/config/models/pasa.py
index 6eb67d29..637df9a4 100644
--- a/src/mednet/libs/classification/config/models/pasa.py
+++ b/src/mednet/libs/classification/config/models/pasa.py
@@ -29,4 +29,5 @@ model = Pasa(
             interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
         ),
     ],
+    augmentation_transforms=[],
 )
-- 
GitLab