Commit 2fbbd899
authored 1 year ago by André Anjos
[ptbench.engine.trainer] Implement type hints
parent b596fe59
No related branches found
No related tags found
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
Showing 3 changed files
src/ptbench/engine/device.py    +1  −1   (1 addition, 1 deletion)
src/ptbench/engine/trainer.py   +34 −39  (34 additions, 39 deletions)
src/ptbench/scripts/train.py    +5  −11  (5 additions, 11 deletions)

with 40 additions and 51 deletions
src/ptbench/engine/device.py  +1 −1

@@ -128,7 +128,7 @@ class DeviceManager:
             f"Unexpected device type {self.device_type} lacks support"
         )

-    def lightning_accelerator(self) -> tuple[str, int | list[int] | str | None]:
+    def lightning_accelerator(self) -> tuple[str, int | list[int] | str]:
         """Returns the lightning accelerator setup.

         Returns
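The tightened return type means the devices element of the tuple is never ``None``. A minimal usage sketch of how that pair is typically unpacked into a Trainer (the ``"cuda:0"`` constructor argument is an assumption; only ``DeviceManager`` and ``lightning_accelerator()`` come from this commit):

    import lightning.pytorch

    from ptbench.engine.device import DeviceManager

    # hypothetical device string; DeviceManager(device) mirrors the call in train.py
    device_manager = DeviceManager("cuda:0")

    # always a (accelerator, devices) pair; devices can no longer be None
    accelerator, devices = device_manager.lightning_accelerator()

    trainer = lightning.pytorch.Trainer(accelerator=accelerator, devices=devices)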
src/ptbench/engine/trainer.py  +34 −39

@@ -14,14 +14,16 @@ import torch.nn

 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
 from .callbacks import LoggingCallback
 from .device import DeviceManager

 logger = logging.getLogger(__name__)


 def save_model_summary(
-    output_folder: str, model: torch.nn.Module
+    output_folder: str,
+    model: torch.nn.Module,
 ) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]:
-    """Save a little summary of the model in a txt file.
+    """Saves a little summary of the model in a txt file.

     Parameters
     ----------

@@ -32,13 +34,14 @@ def save_model_summary(
     model
         Network (e.g. driu, hed, unet)

     Returns
     -------
-    summary:
-        The model summary in a text format.
+    summary
+        The model summary in a text format
-    total_parameters:
-        The number of parameters of the model.
+    total_parameters
+        The number of parameters of the model
     """
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")

@@ -94,15 +97,15 @@ def static_information_to_csv(

 def run(
-    model,
-    datamodule,
-    checkpoint_period,
-    device_manager,
-    arguments,
-    output_folder,
-    monitoring_interval,
-    batch_chunk_count,
-    checkpoint,
+    model: lightning.pytorch.LightningModule,
+    datamodule: lightning.pytorch.LightningDataModule,
+    checkpoint_period: int,
+    device_manager: DeviceManager,
+    max_epochs: int,
+    output_folder: str,
+    monitoring_interval: int | float,
+    batch_chunk_count: int,
+    checkpoint: str,
 ):
     """Fits a CNN model using supervised learning and save it to disk.

@@ -113,48 +116,40 @@ def run(
     Parameters
     ----------

-    model : :py:class:`torch.nn.Module`
+    model
         Neural network model (e.g. pasa).

-    data_loader : :py:class:`torch.utils.data.DataLoader`
-        The pytorch Dataloader used to iterate over batches.
-
-    valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model and enable automatic checkpointing.
-        If ``None``, then do not validate it.
-
-    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model, however **does not affect** automatic
-        checkpointing. If empty, then does not log anything else. Otherwise,
-        an extra column with the loss of every dataset in this list is kept on
-        the final training log.
+    datamodule
+        The lightning datamodule to use for training **and** validation

-    checkpoint_period : int
+    checkpoint_period
         Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
         not save intermediary checkpoints.

-    device_manager : DeviceManager
-        A device, to be used for training.
+    device_manager
+        An internal device representation, to be used for training and
+        validation. This representation can be converted into a pytorch device
+        or a torch lightning accelerator setup.

-    arguments : dict
-        Start and end epochs:
+    max_epochs
+        The maximum number of epochs to train for.

-    output_folder : str
+    output_folder
         Directory in which the results will be saved.

-    monitoring_interval : int, float
+    monitoring_interval
         Interval, in seconds (or fractions), through which we should monitor
         resources during training.

-    batch_chunk_count : int
+    batch_chunk_count
         If this number is different than 1, then each batch will be divided in
         this number of chunks. Gradients will be accumulated to perform each
         mini-batch. This is particularly interesting when one has limited RAM
         on the GPU, but would like to keep training with larger batches. One
         exchanges for longer processing times in this case.
-    """
-    max_epoch = arguments["max_epoch"]
+    checkpoint
+    """

     os.makedirs(output_folder, exist_ok=True)

@@ -198,7 +193,7 @@ def run(
     trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
-        max_epochs=max_epoch,
+        max_epochs=max_epochs,
         accumulate_grad_batches=batch_chunk_count,
         logger=[csv_logger, tensorboard_logger],
         check_val_every_n_epoch=1,
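With the ``arguments`` dict gone, callers pass ``max_epochs`` directly, and ``batch_chunk_count`` is forwarded to Lightning's ``accumulate_grad_batches``. A hedged sketch of a call under the new signature, assuming ``model`` and ``datamodule`` were built elsewhere (as train.py does) and with purely illustrative values:

    from ptbench.engine.device import DeviceManager
    from ptbench.engine.trainer import run

    # hypothetical call; model/datamodule are assumed pre-built LightningModule
    # and LightningDataModule instances, values below are illustrative only
    run(
        model=model,
        datamodule=datamodule,
        checkpoint_period=1,             # save a checkpoint every epoch; 0 disables
        device_manager=DeviceManager("cpu"),
        max_epochs=10,                   # replaces the old arguments["max_epoch"]
        output_folder="results",
        monitoring_interval=5.0,         # resource-monitoring period, in seconds
        batch_chunk_count=4,             # forwarded as accumulate_grad_batches=4
        checkpoint="results/last.ckpt",  # assumed path of a checkpoint to resume from
    )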
src/ptbench/scripts/train.py  +5 −11

@@ -229,8 +229,7 @@ def train(
         procedure in case it stops abruptly.
     """

-    import torch.cuda
-    import torch.nn
+    import torch

     from lightning.pytorch import seed_everything

@@ -276,25 +275,20 @@ def train(
             "Skipping sample class/dataset ownership balancing on user request"
         )

-    arguments = {}
-    arguments["max_epoch"] = epochs
-    arguments["epoch"] = 0
+    logger.info(f"Training for at most {epochs} epochs.")

     # We only load the checkpoint to get some information about its state. The
     # actual loading of the model is done in trainer.fit()
     if checkpoint_file is not None:
         checkpoint = torch.load(checkpoint_file)
-        arguments["epoch"] = checkpoint["epoch"]
-
-    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
-    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
+        start_epoch = checkpoint["epoch"]
+        logger.info(f"Resuming from epoch {start_epoch}...")

     run(
         model=model,
         datamodule=datamodule,
         checkpoint_period=checkpoint_period,
         device_manager=DeviceManager(device),
-        arguments=arguments,
+        max_epochs=epochs,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
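As the comment in the diff notes, train.py now opens the checkpoint only to read metadata; the weights and optimizer state are restored later, inside ``trainer.fit()``. A minimal sketch of that pattern, assuming a Lightning-produced checkpoint at a hypothetical path:

    import torch

    # sketch of the new resume logic: the file is loaded here only to report
    # which epoch training will resume from, not to restore any state
    checkpoint_file = "results/last.ckpt"   # assumed path

    checkpoint = torch.load(checkpoint_file)
    start_epoch = checkpoint["epoch"]       # Lightning checkpoints store the epoch
    print(f"Resuming from epoch {start_epoch}...")

Passing the same path as ``ckpt_path`` to ``trainer.fit()`` is what actually restores the model and optimizer state.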