diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py index 9fbf388d26fab33f16ae95d19e90763aa39cdcc9..f28dd23cd12c72e5fc6713e706f0e9c05158759c 100644 --- a/src/mednet/config/models/densenet.py +++ b/src/mednet/config/models/densenet.py @@ -19,6 +19,7 @@ model = Densenet( validation_loss=BCEWithLogitsLoss(), optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), - augmentation_transforms=[ElasticDeformation(p=0.8)], + augmentation_transforms=[ElasticDeformation(p=0.2)], pretrained=False, + dropout=0.1, ) diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py index 7d85d9b4bc3904f887569d266fd489b31b28a171..274a564601094a8ecb51e67c87f19f1f8197a30a 100644 --- a/src/mednet/config/models/densenet_pretrained.py +++ b/src/mednet/config/models/densenet_pretrained.py @@ -21,6 +21,7 @@ model = Densenet( validation_loss=BCEWithLogitsLoss(), optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), - augmentation_transforms=[ElasticDeformation(p=0.8)], + augmentation_transforms=[ElasticDeformation(p=0.2)], pretrained=True, + dropout=0.1, ) diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py index 24caef8681a2b908782ec93cc222bc0d6ca3117b..e7db48850d0e8d2b959b39ee93bae3b78dccfa80 100644 --- a/src/mednet/config/models/densenet_rs.py +++ b/src/mednet/config/models/densenet_rs.py @@ -20,7 +20,8 @@ model = Densenet( validation_loss=BCEWithLogitsLoss(), optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), - augmentation_transforms=[ElasticDeformation(p=0.8)], + augmentation_transforms=[ElasticDeformation(p=0.2)], pretrained=False, + dropout=0.1, num_classes=14, # number of classes in NIH CXR-14 ) diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index e29e128cd0c5a9146ff96a3205ff8428577f6906..f7d1544164cf17c6f33f123d6d4bd02435722eb5 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -52,6 +52,8 @@ class Densenet(pl.LightningModule): pretrained If set to True, loads pretrained model weights during initialization, else trains a new model. + dropout + Dropout rate after each dense layer. num_classes Number of outputs (classes) for this model. """ @@ -64,6 +66,7 @@ class Densenet(pl.LightningModule): optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], pretrained: bool = False, + dropout: float = 0.1, num_classes: int = 1, ): super().__init__() @@ -97,7 +100,7 @@ class Densenet(pl.LightningModule): logger.info(f"Loading pretrained {self.name} model weights") weights = models.DenseNet121_Weights.DEFAULT - self.model_ft = models.densenet121(weights=weights) + self.model_ft = models.densenet121(weights=weights, drop_rate=dropout) # Adapt output features self.model_ft.classifier = torch.nn.Linear(