Skip to content
Snippets Groups Projects
Commit f347ee8b authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Merge branch 'dropout' into 'main'

[densenet] adding drop_rate in model params

See merge request biosignal/software/mednet!35
parents 5d34ee75 854a730e
No related branches found
No related tags found
1 merge request: !35 "[densenet] adding drop_rate in model params"
Pipeline #86569 passed
......@@ -19,6 +19,7 @@ model = Densenet(
validation_loss=BCEWithLogitsLoss(),
optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.8)],
augmentation_transforms=[ElasticDeformation(p=0.2)],
pretrained=False,
dropout=0.1,
)
......@@ -21,6 +21,7 @@ model = Densenet(
validation_loss=BCEWithLogitsLoss(),
optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.8)],
augmentation_transforms=[ElasticDeformation(p=0.2)],
pretrained=True,
dropout=0.1,
)
......@@ -20,7 +20,8 @@ model = Densenet(
validation_loss=BCEWithLogitsLoss(),
optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.8)],
augmentation_transforms=[ElasticDeformation(p=0.2)],
pretrained=False,
dropout=0.1,
num_classes=14, # number of classes in NIH CXR-14
)
......@@ -52,6 +52,8 @@ class Densenet(pl.LightningModule):
pretrained
If set to True, loads pretrained model weights during initialization,
else trains a new model.
dropout
Dropout rate after each dense layer.
num_classes
Number of outputs (classes) for this model.
"""
......@@ -64,6 +66,7 @@ class Densenet(pl.LightningModule):
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
pretrained: bool = False,
dropout: float = 0.1,
num_classes: int = 1,
):
super().__init__()
......@@ -97,7 +100,7 @@ class Densenet(pl.LightningModule):
logger.info(f"Loading pretrained {self.name} model weights")
weights = models.DenseNet121_Weights.DEFAULT
self.model_ft = models.densenet121(weights=weights)
self.model_ft = models.densenet121(weights=weights, drop_rate=dropout)
# Adapt output features
self.model_ft.classifier = torch.nn.Linear(
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment