medai / software / mednet · Commits

Commit 4ccba39c, authored 1 year ago by Daniel CARRON

Functional densenet model
Parent: 2baa8e0b
Merge request: !7 (Reviewed DataModule design+docs+types)
Showing 3 changed files, with 124 additions and 110 deletions:

* src/ptbench/configs/models/densenet.py (+15, -5)
* src/ptbench/data/shenzhen/rgb.py (+30, -71)
* src/ptbench/models/densenet.py (+79, -34)
src/ptbench/configs/models/densenet.py (+15, -5)

```diff
@@ -6,20 +6,30 @@
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
 from ...models.densenet import Densenet
 
-# config
-optimizer_configs = {"lr": 0.0001}
-
 # optimizer
-optimizer = "Adam"
+optimizer = Adam
+optimizer_configs = {"lr": 0.0001}
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+from ...data.transforms import ElasticDeformation
+
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Densenet(
     criterion,
     criterion_valid,
     optimizer,
     optimizer_configs,
-    pretrained=False
+    pretrained=False,
+    augmentation_transforms=augmentation_transforms,
 )
```
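For context, the optimizer is now configured as the `Adam` class rather than the string `"Adam"`: the model previously resolved the string with `getattr(torch.optim, ...)` inside `configure_optimizers()` (see the last file below), while the new code calls the configured class directly. A minimal sketch of the difference, with stand-in parameters:

```python
# Illustration only, not part of the commit.
import torch

params = [torch.nn.Parameter(torch.zeros(3))]

# Before: a string name, resolved dynamically at optimizer-construction time:
optimizer = getattr(torch.optim, "Adam")(params, lr=0.0001)

# After: the class itself is part of the config and is simply called:
optimizer = torch.optim.Adam(params, lr=0.0001)
```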
src/ptbench/data/shenzhen/rgb.py (+30, -71)

```diff
@@ -2,81 +2,40 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen dataset for TB detection (cross validation fold 0, RGB)
+"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
 
 * Split reference: first 80% of TB and healthy CXR for "train", rest for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
-"""
 
-from clapper.logging import setup
+See :py:mod:`ptbench.data.shenzhen` for dataset details.
 
-from ....data import return_subsets
-from ....data.base_datamodule import BaseDataModule
-from ....data.dataset import JSONProtocol
-from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
+This configuration:
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms`
+* augmentations: elastic deformation (probability = 80%)
+* output image resolution: 512x512 pixels
+"""
 
+import importlib.resources
 
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        multiproc_kwargs=None,
-        data_transforms=[],
-        model_transforms=[],
-        train_transforms=[],
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-        self.cache_samples = cache_samples
-        self.has_setup_fit = False
-        self.data_transforms = data_transforms
-        self.model_transforms = model_transforms
-        self.train_transforms = train_transforms
-
-        """
-        [
-            transforms.ToPILImage(),
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]
-        """
+from torchvision import transforms
 
-    def setup(self, stage: str):
-        if self.cache_samples:
-            logger.info("Argument cache_samples set to True. Samples will be loaded in memory.")
-            samples_loader = _cached_loader
-        else:
-            logger.info("Argument cache_samples set to False. Samples will be loaded at runtime.")
-            samples_loader = _delayed_loader
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .raw_data_loader import raw_data_loader
 
-        self.json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-            loader=samples_loader,
-            post_transforms=self.post_transforms,
-        )
-
-        if not self.has_setup_fit and stage == "fit":
-            (
-                self.train_dataset,
-                self.validation_dataset,
-                self.extra_validation_datasets,
-            ) = return_subsets(self.json_protocol, "default", stage)
-            self.has_setup_fit = True
-
-
-datamodule = DefaultModule()
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "default.json.bz2"
+        )
+    ),
+    raw_data_loader=raw_data_loader,
+    cache_samples=False,
+    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+    model_transforms=[
+        transforms.ToPILImage(),
+        transforms.Lambda(lambda x: x.convert("RGB")),
+        transforms.ToTensor(),
+    ],
+    # batch_size = 1,
+    # batch_chunk_count = 1,
+    # drop_incomplete_batch = False,
+    # parallel = -1,
+)
```
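The least obvious line here is the split-file lookup: `__name__.rsplit(".", 1)[0]` strips the module's own name to get its containing package, and `importlib.resources.files()` then locates `default.json.bz2` shipped inside that package. The same idiom in isolation (actually running it requires ptbench to be installed):

```python
import importlib.resources

# Inside src/ptbench/data/shenzhen/rgb.py, __name__ is "ptbench.data.shenzhen.rgb";
# rsplit(".", 1)[0] drops the final component, leaving the containing package.
package = "ptbench.data.shenzhen.rgb".rsplit(".", 1)[0]  # "ptbench.data.shenzhen"

# files() resolves resources relative to that package, wherever it is installed:
split_path = importlib.resources.files(package).joinpath("default.json.bz2")
```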
src/ptbench/models/densenet.py (+79, -34)

```diff
@@ -8,6 +8,7 @@ import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 import torchvision.models as models
+import torchvision.transforms
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -20,25 +21,37 @@ class Densenet(pl.LightningModule):
     def __init__(
         self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
+        criterion=None,
+        criterion_valid=None,
+        optimizer=None,
+        optimizer_configs=None,
         pretrained=False,
         nb_channels=3,
+        augmentation_transforms=[],
     ):
         super().__init__()
 
         # Saves all hyper parameters declared on __init__ into ``self.hparams``.
         # You can access those by their name, like `self.hparams.optimizer`
         self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
         self.name = "Densenet"
 
+        self.augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.optimizer = optimizer
+        self.optimizer_configs = optimizer_configs
+
         self.normalizer = None
         self.pretrained = pretrained
 
         # Load pretrained model
-        weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
+        if not pretrained:
+            weights = None
+        else:
+            logger.info("Loading pretrained model weights")
+            weights = models.DenseNet121_Weights.DEFAULT
 
         self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
```
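Since every constructor argument is now optional, the model can be built first and have its losses attached later. A hedged usage sketch (the import path follows the config file above; the wiring is illustrative, not prescribed by the commit):

```python
# Hypothetical usage, assuming ptbench is installed:
from ptbench.models.densenet import Densenet

model = Densenet(pretrained=False)       # no criterion or optimizer yet
# Later, e.g. before fitting, class-balanced BCE losses can be attached:
# model.set_bce_loss_weights(datamodule)  # see the new method below
```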
```diff
@@ -52,17 +65,24 @@
         return x
 
-    def set_normalizer(self, dataloader):
-        """TODO: Write this function to set the Normalizer
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initializes the normalizer for the current model.
 
         This function is NOOP if ``pretrained = True`` (normalizer set to
         imagenet weights, during construction).
 
         Parameters
         ----------
 
         dataloader: :py:class:`torch.utils.data.DataLoader`
             A torch Dataloader from which to compute the mean and std.
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
+            from .normalizer import make_imagenet_normalizer
+
             logger.warning(
                 "ImageNet pre-trained densenet model - NOT "
                 "computing z-norm factors from training data. "
                 "Using preset factors from torchvision."
             )
```
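`make_z_normalizer`, called in the non-pretrained branch just after this hunk, is not shown in this commit. Fundamentally, such a helper needs per-channel mean and standard deviation estimates from the training dataloader; a sketch of that computation, assuming `(images, metadata)` batches as used elsewhere in this file:

```python
import torch

def compute_mean_std(dataloader):
    """Estimates per-channel mean/std over (B, C, H, W) image batches."""
    n, mean, sq_mean = 0, 0.0, 0.0
    for batch in dataloader:
        images = batch[0]
        b = images.shape[0]
        # Running, size-weighted averages of E[x] and E[x^2] per channel:
        mean = (mean * n + images.mean(dim=(0, 2, 3)) * b) / (n + b)
        sq_mean = (sq_mean * n + images.pow(2).mean(dim=(0, 2, 3)) * b) / (n + b)
        n += b
    std = (sq_mean - mean.pow(2)).sqrt()  # Var[x] = E[x^2] - E[x]^2
    return mean, std
```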
```diff
@@ -76,9 +96,38 @@
             )
             self.normalizer = make_z_normalizer(dataloader)
 
+    def set_bce_loss_weights(self, datamodule):
+        """Reweights loss weights if BCEWithLogitsLoss is used.
+
+        Parameters
+        ----------
+
+        datamodule:
+            A datamodule implementing train_dataloader() and val_dataloader()
+        """
+        from ..data.dataset import _get_positive_weights
+
+        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+            train_positive_weights = _get_positive_weights(
+                datamodule.train_dataloader()
+            )
+            self.criterion = torch.nn.BCEWithLogitsLoss(
+                pos_weight=train_positive_weights
+            )
+
+        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+            validation_positive_weights = _get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
+                pos_weight=validation_positive_weights
+            )
+
     def training_step(self, batch, batch_idx):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
```
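`_get_positive_weights` comes from elsewhere in the package and is not shown here. For `BCEWithLogitsLoss`, the standard class-balancing weight is the ratio of negative to positive samples per label column; a self-contained illustration of that computation, assuming this is what the helper returns:

```python
import torch

labels = torch.tensor([[1.0], [0.0], [0.0], [1.0], [0.0]])  # 2 positive, 3 negative

positives = labels.sum(dim=0)            # tensor([2.])
negatives = labels.shape[0] - positives  # tensor([3.])
pos_weight = negatives / positives       # tensor([1.5000])

# Each positive sample now contributes 1.5x to the loss, offsetting the imbalance:
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```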
```diff
@@ -86,17 +135,20 @@
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(images)
+        augmented_images = [
+            self.augmentation_transforms(img).to(self.device) for img in images
+        ]
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
+        outputs = self(augmented_images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-        training_loss = self.hparams.criterion(outputs, labels.float())
+        training_loss = self.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
```
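Augmentation thus moves from the data pipeline into `training_step`: each sample is transformed individually, and the results are stitched back into a batch with `torch.cat(...).view(...)`. The same idiom in isolation, with a stand-in torchvision transform:

```python
import torch
from torchvision import transforms

augment = transforms.Compose([transforms.RandomHorizontalFlip(p=1.0)])
images = torch.rand(4, 3, 8, 8)  # a batch of four RGB images

# Iterating over a 4-D tensor yields its (C, H, W) slices; each is transformed
# on its own, then the list is recombined into the original batch shape:
augmented = [augment(img) for img in images]
batch = torch.cat(augmented, 0).view(images.shape)
assert batch.shape == images.shape
```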
```diff
@@ -106,11 +158,7 @@
         # data forwarding on the existing network
         outputs = self(images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(self.device)
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
+        validation_loss = self.criterion_valid(outputs, labels.float())
 
         if dataloader_idx == 0:
             return {"validation_loss": validation_loss}
```
```diff
@@ -118,8 +166,9 @@
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
```
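With this change, all three step methods unpack the same batch layout: `batch[0]` holds the image tensor and `batch[1]` a metadata dictionary. Inferred from the unpacking code (the concrete values below are made up for illustration):

```python
import torch

# Assumed batch structure after this commit:
batch = (
    torch.rand(1, 3, 512, 512),                              # batch[0]: images
    {"label": torch.tensor([1]), "name": ["sample-0001"]},   # batch[1]: metadata
)

images = batch[0]
labels = batch[1]["label"]
names = batch[1]["name"]
```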
```diff
@@ -129,12 +178,8 @@
         if isinstance(outputs, list):
             outputs = outputs[-1]
 
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return names[0], torch.flatten(probabilities), torch.flatten(labels)
 
     def configure_optimizers(self):
-        # Dynamically instantiates the optimizer given the configs
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
-        )
+        optimizer = self.optimizer(
+            self.parameters(), **self.optimizer_configs
+        )
 
         return optimizer
```