diff --git a/src/mednet/libs/classification/config/models/alexnet.py b/src/mednet/libs/classification/config/models/alexnet.py index 1ba9e0826f13f2be7069a33a76aac02da1e93fc7..702f9ed0de79fbdd18b54ca8ea507fbda6cadf63 100644 --- a/src/mednet/libs/classification/config/models/alexnet.py +++ b/src/mednet/libs/classification/config/models/alexnet.py @@ -24,4 +24,5 @@ model = Alexnet( torchvision.transforms.Resize(512, antialias=True), RGB(), ], + augmentation_transforms=[], ) diff --git a/src/mednet/libs/classification/config/models/densenet.py b/src/mednet/libs/classification/config/models/densenet.py index bac5868cae40e28472af53e7b21617546f012c33..65f7e90a0f06e41cf3efe4b6e0de035fed64fc1a 100644 --- a/src/mednet/libs/classification/config/models/densenet.py +++ b/src/mednet/libs/classification/config/models/densenet.py @@ -18,4 +18,6 @@ model = Densenet( optimizer_arguments=dict(lr=0.0001), pretrained=False, dropout=0.1, + model_transforms=[], + augmentation_transforms=[], ) diff --git a/src/mednet/libs/classification/config/models/densenet_pretrained.py b/src/mednet/libs/classification/config/models/densenet_pretrained.py index 90fdd8c9baf72837837116bc9b3444f6a12fbbc8..8d7103aeeb92275e828a341faf3346ee68e69bf4 100644 --- a/src/mednet/libs/classification/config/models/densenet_pretrained.py +++ b/src/mednet/libs/classification/config/models/densenet_pretrained.py @@ -31,4 +31,5 @@ model = Densenet( ), RGB(), ], + augmentation_transforms=[], ) diff --git a/src/mednet/libs/classification/config/models/densenet_rs.py b/src/mednet/libs/classification/config/models/densenet_rs.py index b3620d68d298ab921ac214389ffb1bc7434d4967..9d3008755da6f9ce02be29cdf91adfdbb28cdfc2 100644 --- a/src/mednet/libs/classification/config/models/densenet_rs.py +++ b/src/mednet/libs/classification/config/models/densenet_rs.py @@ -20,4 +20,6 @@ model = Densenet( pretrained=False, dropout=0.1, num_classes=14, # number of classes in NIH CXR-14 + model_transforms=[], + augmentation_transforms=[], ) diff --git a/src/mednet/libs/classification/config/models/pasa.py b/src/mednet/libs/classification/config/models/pasa.py index 6eb67d29aed82f23455050dc0b2214afb02a4b21..637df9a48a954d2c913d369cd495d4b349311fe2 100644 --- a/src/mednet/libs/classification/config/models/pasa.py +++ b/src/mednet/libs/classification/config/models/pasa.py @@ -29,4 +29,5 @@ model = Pasa( interpolation=torchvision.transforms.InterpolationMode.BILINEAR, ), ], + augmentation_transforms=[], )