medai / software / mednet · Commits · 21937fb2

Commit 21937fb2, authored 10 months ago by Daniel CARRON

[model] Create base Model class
Parent: 9d570ed0
No related branches found. No related tags found.
1 merge request: !38 Replace sampler balancing by loss balancing

Changes: 1 changed file

src/mednet/models/model.py (new file, 0 → 100644): 143 additions, 0 deletions
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import typing

import lightning.pytorch as pl
import torch
import torch.nn
import torch.optim.optimizer
import torch.utils.data
import torchvision.transforms

from ..data.typing import TransformSequence
from .typing import Checkpoint

logger = logging.getLogger(__name__)


class Model(pl.LightningModule):
    """Base class for models.

    Parameters
    ----------
    train_loss
        The loss to be used during the training.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    validation_loss
        The loss to be used for validation (may be different from the training
        loss). If extra-validation sets are provided, the same loss will be
        used throughout.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.
    augmentation_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
        validation_loss: torch.nn.Module | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] = {},
        augmentation_transforms: TransformSequence = [],
        num_classes: int = 1,
    ):
        super().__init__()

        self.name = "model"
        self.num_classes = num_classes

        self.model_transforms: TransformSequence = []

        self._train_loss = train_loss
        self._validation_loss = (
            validation_loss if validation_loss is not None else train_loss
        )
        self._optimizer_type = optimizer_type
        self._optimizer_arguments = optimizer_arguments

        self._augmentation_transforms = torchvision.transforms.Compose(
            augmentation_transforms,
        )

    def forward(self, x):
        raise NotImplementedError

    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Perform actions during checkpoint saving (called by lightning).

        Called by Lightning when saving a checkpoint to give you a chance to
        store anything else you might want to save. Use on_load_checkpoint()
        to restore what additional data is saved here.

        Parameters
        ----------
        checkpoint
            The checkpoint to save.
        """
        checkpoint["normalizer"] = self.normalizer

    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Perform actions during model loading (called by lightning).

        If you saved something with on_save_checkpoint() this is your chance
        to restore this.

        Parameters
        ----------
        checkpoint
            The loaded checkpoint.
        """
        logger.info("Restoring normalizer from checkpoint.")
        self.normalizer = checkpoint["normalizer"]

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        """Initialize the input normalizer for the current model.

        Parameters
        ----------
        dataloader
            A torch Dataloader from which to compute the mean and std.
        """
        from .normalizer import make_z_normalizer

        logger.info(
            f"Uninitialised {self.name} model - "
            f"computing z-norm factors from train dataloader.",
        )
        self.normalizer = make_z_normalizer(dataloader)

    def training_step(self, batch, _):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        raise NotImplementedError

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        raise NotImplementedError

    def configure_optimizers(self):
        return self._optimizer_type(
            self.parameters(),
            **self._optimizer_arguments,
        )
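
For context, the sketch below shows how a concrete model might build on this new base class. It is an illustration only, not part of this commit: the subclass name SimpleCNN, its layers, and the assumed batch layout ({"image": ..., "target": ...}) are hypothetical, and it presumes set_normalizer() (or checkpoint loading) has populated self.normalizer before forward() runs.

# Hypothetical usage sketch -- not part of this commit. Assumes batches are
# dicts with "image" and "target" tensors; the real mednet datamodules may
# use a different layout.
import torch
import torch.nn

from mednet.models.model import Model


class SimpleCNN(Model):
    """Toy binary classifier illustrating the Model base class contract."""

    def __init__(self, **kwargs):
        super().__init__(num_classes=1, **kwargs)
        self.name = "simple-cnn"
        self.backbone = torch.nn.Sequential(
            torch.nn.Conv2d(1, 8, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d(1),
            torch.nn.Flatten(),
            torch.nn.Linear(8, self.num_classes),
        )

    def forward(self, x):
        # self.normalizer is set by set_normalizer() or restored from a
        # checkpoint by the base class; it is assumed here to be a callable
        # transform (e.g. a z-normalization).
        return self.backbone(self.normalizer(x))

    def training_step(self, batch, _):
        # Augmentations are applied only during training, before the forward
        # pass, as the base class constructor suggests.
        x = self._augmentation_transforms(batch["image"])
        logits = self(x).squeeze(1)
        # The training loss defaults to BCEWithLogitsLoss and, per the
        # docstring warning, must return a batch average.
        return self._train_loss(logits, batch["target"].float())

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        logits = self(batch["image"]).squeeze(1)
        return self._validation_loss(logits, batch["target"].float())

Under those assumptions, training would follow the usual Lightning flow: instantiate the model, call model.set_normalizer(datamodule.train_dataloader()) once, then pass both to lightning.pytorch.Trainer.fit(); configure_optimizers() already wires in optimizer_type and optimizer_arguments.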