From 11560cc6feadb677dc9ed1f7d65fb5f06fcd0fb7 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 17 Jun 2024 10:10:18 +0200 Subject: [PATCH] [classification.config] Add model_transforms to model configs --- src/mednet/libs/classification/config/models/alexnet.py | 1 + src/mednet/libs/classification/config/models/densenet.py | 2 ++ .../libs/classification/config/models/densenet_pretrained.py | 1 + src/mednet/libs/classification/config/models/densenet_rs.py | 2 ++ src/mednet/libs/classification/config/models/pasa.py | 1 + 5 files changed, 7 insertions(+) diff --git a/src/mednet/libs/classification/config/models/alexnet.py b/src/mednet/libs/classification/config/models/alexnet.py index 1ba9e082..702f9ed0 100644 --- a/src/mednet/libs/classification/config/models/alexnet.py +++ b/src/mednet/libs/classification/config/models/alexnet.py @@ -24,4 +24,5 @@ model = Alexnet( torchvision.transforms.Resize(512, antialias=True), RGB(), ], + augmentation_transforms=[], ) diff --git a/src/mednet/libs/classification/config/models/densenet.py b/src/mednet/libs/classification/config/models/densenet.py index bac5868c..65f7e90a 100644 --- a/src/mednet/libs/classification/config/models/densenet.py +++ b/src/mednet/libs/classification/config/models/densenet.py @@ -18,4 +18,6 @@ model = Densenet( optimizer_arguments=dict(lr=0.0001), pretrained=False, dropout=0.1, + model_transforms=[], + augmentation_transforms=[], ) diff --git a/src/mednet/libs/classification/config/models/densenet_pretrained.py b/src/mednet/libs/classification/config/models/densenet_pretrained.py index 90fdd8c9..8d7103ae 100644 --- a/src/mednet/libs/classification/config/models/densenet_pretrained.py +++ b/src/mednet/libs/classification/config/models/densenet_pretrained.py @@ -31,4 +31,5 @@ model = Densenet( ), RGB(), ], + augmentation_transforms=[], ) diff --git a/src/mednet/libs/classification/config/models/densenet_rs.py b/src/mednet/libs/classification/config/models/densenet_rs.py index b3620d68..9d300875 100644 --- a/src/mednet/libs/classification/config/models/densenet_rs.py +++ b/src/mednet/libs/classification/config/models/densenet_rs.py @@ -20,4 +20,6 @@ model = Densenet( pretrained=False, dropout=0.1, num_classes=14, # number of classes in NIH CXR-14 + model_transforms=[], + augmentation_transforms=[], ) diff --git a/src/mednet/libs/classification/config/models/pasa.py b/src/mednet/libs/classification/config/models/pasa.py index 6eb67d29..637df9a4 100644 --- a/src/mednet/libs/classification/config/models/pasa.py +++ b/src/mednet/libs/classification/config/models/pasa.py @@ -29,4 +29,5 @@ model = Pasa( interpolation=torchvision.transforms.InterpolationMode.BILINEAR, ), ], + augmentation_transforms=[], ) -- GitLab