Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
mednet
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
medai
software
mednet
Commits
5fae86b8
Commit
5fae86b8
authored
10 months ago
by
Daniel CARRON
Browse files
Options
Downloads
Patches
Plain Diff
[model] Use base model
parent
21937fb2
No related branches found
No related tags found
1 merge request
!38
Replace sampler balancing by loss balancing
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
src/mednet/models/alexnet.py
+10
-51
10 additions, 51 deletions
src/mednet/models/alexnet.py
src/mednet/models/densenet.py
+10
-51
10 additions, 51 deletions
src/mednet/models/densenet.py
src/mednet/models/pasa.py
+10
-68
10 additions, 68 deletions
src/mednet/models/pasa.py
with
30 additions
and
170 deletions
src/mednet/models/alexnet.py
+
10
−
51
View file @
5fae86b8
...
...
@@ -5,7 +5,6 @@
import
logging
import
typing
import
lightning.pytorch
as
pl
import
torch
import
torch.nn
import
torch.optim.optimizer
...
...
@@ -14,14 +13,14 @@ import torchvision.models as models
import
torchvision.transforms
from
..data.typing
import
TransformSequence
from
.model
import
Model
from
.separate
import
separate
from
.transforms
import
RGB
,
SquareCenterPad
from
.typing
import
Checkpoint
logger
=
logging
.
getLogger
(
__name__
)
class
Alexnet
(
pl
.
LightningModule
):
class
Alexnet
(
Model
):
"""
Alexnet module.
Note: only usable with a normalized dataset
...
...
@@ -68,7 +67,14 @@ class Alexnet(pl.LightningModule):
pretrained
:
bool
=
False
,
num_classes
:
int
=
1
,
):
super
().
__init__
()
super
().
__init__
(
train_loss
,
validation_loss
,
optimizer_type
,
optimizer_arguments
,
augmentation_transforms
,
num_classes
,
)
self
.
name
=
"
alexnet
"
self
.
num_classes
=
num_classes
...
...
@@ -79,17 +85,6 @@ class Alexnet(pl.LightningModule):
RGB
(),
]
self
.
_train_loss
=
train_loss
self
.
_validation_loss
=
(
validation_loss
if
validation_loss
is
not
None
else
train_loss
)
self
.
_optimizer_type
=
optimizer_type
self
.
_optimizer_arguments
=
optimizer_arguments
self
.
_augmentation_transforms
=
torchvision
.
transforms
.
Compose
(
augmentation_transforms
,
)
self
.
pretrained
=
pretrained
# Load pretrained model
...
...
@@ -109,36 +104,6 @@ class Alexnet(pl.LightningModule):
x
=
self
.
normalizer
(
x
)
# type: ignore
return
self
.
model_ft
(
x
)
def
on_save_checkpoint
(
self
,
checkpoint
:
Checkpoint
)
->
None
:
"""
Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint
[
"
normalizer
"
]
=
self
.
normalizer
def
on_load_checkpoint
(
self
,
checkpoint
:
Checkpoint
)
->
None
:
"""
Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger
.
info
(
"
Restoring normalizer from checkpoint.
"
)
self
.
normalizer
=
checkpoint
[
"
normalizer
"
]
def
set_normalizer
(
self
,
dataloader
:
torch
.
utils
.
data
.
DataLoader
)
->
None
:
"""
Initialize the normalizer for the current model.
...
...
@@ -208,9 +173,3 @@ class Alexnet(pl.LightningModule):
outputs
=
self
(
batch
[
0
])
probabilities
=
torch
.
sigmoid
(
outputs
)
return
separate
((
probabilities
,
batch
[
1
]))
def
configure_optimizers
(
self
):
return
self
.
_optimizer_type
(
self
.
parameters
(),
**
self
.
_optimizer_arguments
,
)
This diff is collapsed.
Click to expand it.
src/mednet/models/densenet.py
+
10
−
51
View file @
5fae86b8
...
...
@@ -5,7 +5,6 @@
import
logging
import
typing
import
lightning.pytorch
as
pl
import
torch
import
torch.nn
import
torch.optim.optimizer
...
...
@@ -14,14 +13,14 @@ import torchvision.models as models
import
torchvision.transforms
from
..data.typing
import
TransformSequence
from
.model
import
Model
from
.separate
import
separate
from
.transforms
import
RGB
,
SquareCenterPad
from
.typing
import
Checkpoint
logger
=
logging
.
getLogger
(
__name__
)
class
Densenet
(
pl
.
LightningModule
):
class
Densenet
(
Model
):
"""
Densenet-121 module.
Parameters
...
...
@@ -69,7 +68,14 @@ class Densenet(pl.LightningModule):
dropout
:
float
=
0.1
,
num_classes
:
int
=
1
,
):
super
().
__init__
()
super
().
__init__
(
train_loss
,
validation_loss
,
optimizer_type
,
optimizer_arguments
,
augmentation_transforms
,
num_classes
,
)
self
.
name
=
"
densenet-121
"
self
.
num_classes
=
num_classes
...
...
@@ -80,17 +86,6 @@ class Densenet(pl.LightningModule):
RGB
(),
]
self
.
_train_loss
=
train_loss
self
.
_validation_loss
=
(
validation_loss
if
validation_loss
is
not
None
else
train_loss
)
self
.
_optimizer_type
=
optimizer_type
self
.
_optimizer_arguments
=
optimizer_arguments
self
.
_augmentation_transforms
=
torchvision
.
transforms
.
Compose
(
augmentation_transforms
,
)
self
.
pretrained
=
pretrained
# Load pretrained model
...
...
@@ -112,36 +107,6 @@ class Densenet(pl.LightningModule):
x
=
self
.
normalizer
(
x
)
# type: ignore
return
self
.
model_ft
(
x
)
def
on_save_checkpoint
(
self
,
checkpoint
:
Checkpoint
)
->
None
:
"""
Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint
[
"
normalizer
"
]
=
self
.
normalizer
def
on_load_checkpoint
(
self
,
checkpoint
:
Checkpoint
)
->
None
:
"""
Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger
.
info
(
"
Restoring normalizer from checkpoint.
"
)
self
.
normalizer
=
checkpoint
[
"
normalizer
"
]
def
set_normalizer
(
self
,
dataloader
:
torch
.
utils
.
data
.
DataLoader
)
->
None
:
"""
Initialize the normalizer for the current model.
...
...
@@ -205,9 +170,3 @@ class Densenet(pl.LightningModule):
outputs
=
self
(
batch
[
0
])
probabilities
=
torch
.
sigmoid
(
outputs
)
return
separate
((
probabilities
,
batch
[
1
]))
def
configure_optimizers
(
self
):
return
self
.
_optimizer_type
(
self
.
parameters
(),
**
self
.
_optimizer_arguments
,
)
This diff is collapsed.
Click to expand it.
src/mednet/models/pasa.py
+
10
−
68
View file @
5fae86b8
...
...
@@ -5,7 +5,6 @@
import
logging
import
typing
import
lightning.pytorch
as
pl
import
torch
import
torch.nn
import
torch.nn.functional
as
F
# noqa: N812
...
...
@@ -14,14 +13,14 @@ import torch.utils.data
import
torchvision.transforms
from
..data.typing
import
TransformSequence
from
.model
import
Model
from
.separate
import
separate
from
.transforms
import
Grayscale
,
SquareCenterPad
from
.typing
import
Checkpoint
logger
=
logging
.
getLogger
(
__name__
)
class
Pasa
(
pl
.
LightningModule
):
class
Pasa
(
Model
):
"""
Implementation of CNN by Pasa and others.
Simple CNN for classification based on paper by [PASA-2019]_.
...
...
@@ -67,7 +66,14 @@ class Pasa(pl.LightningModule):
augmentation_transforms
:
TransformSequence
=
[],
num_classes
:
int
=
1
,
):
super
().
__init__
()
super
().
__init__
(
train_loss
,
validation_loss
,
optimizer_type
,
optimizer_arguments
,
augmentation_transforms
,
num_classes
,
)
self
.
name
=
"
pasa
"
self
.
num_classes
=
num_classes
...
...
@@ -82,17 +88,6 @@ class Pasa(pl.LightningModule):
),
]
self
.
_train_loss
=
train_loss
self
.
_validation_loss
=
(
validation_loss
if
validation_loss
is
not
None
else
train_loss
)
self
.
_optimizer_type
=
optimizer_type
self
.
_optimizer_arguments
=
optimizer_arguments
self
.
_augmentation_transforms
=
torchvision
.
transforms
.
Compose
(
augmentation_transforms
,
)
# First convolution block
self
.
fc1
=
torch
.
nn
.
Conv2d
(
1
,
4
,
(
3
,
3
),
(
2
,
2
),
(
1
,
1
))
self
.
fc2
=
torch
.
nn
.
Conv2d
(
4
,
16
,
(
3
,
3
),
(
2
,
2
),
(
1
,
1
))
...
...
@@ -213,53 +208,6 @@ class Pasa(pl.LightningModule):
# x = F.log_softmax(x, dim=1) # 0 is batch size
def
on_save_checkpoint
(
self
,
checkpoint
:
Checkpoint
)
->
None
:
"""
Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint
[
"
normalizer
"
]
=
self
.
normalizer
def
on_load_checkpoint
(
self
,
checkpoint
:
Checkpoint
)
->
None
:
"""
Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger
.
info
(
"
Restoring normalizer from checkpoint.
"
)
self
.
normalizer
=
checkpoint
[
"
normalizer
"
]
def
set_normalizer
(
self
,
dataloader
:
torch
.
utils
.
data
.
DataLoader
)
->
None
:
"""
Initialize the input normalizer for the current model.
Parameters
----------
dataloader
A torch Dataloader from which to compute the mean and std.
"""
from
.normalizer
import
make_z_normalizer
logger
.
info
(
f
"
Uninitialised
{
self
.
name
}
model -
"
f
"
computing z-norm factors from train dataloader.
"
,
)
self
.
normalizer
=
make_z_normalizer
(
dataloader
)
def
training_step
(
self
,
batch
,
_
):
images
=
batch
[
0
]
labels
=
batch
[
1
][
"
label
"
]
...
...
@@ -292,9 +240,3 @@ class Pasa(pl.LightningModule):
outputs
=
self
(
batch
[
0
])
probabilities
=
torch
.
sigmoid
(
outputs
)
return
separate
((
probabilities
,
batch
[
1
]))
def
configure_optimizers
(
self
):
return
self
.
_optimizer_type
(
self
.
parameters
(),
**
self
.
_optimizer_arguments
,
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment