medai / software / mednet - Commit 73f436fd

[engine.callbacks] Refactor callbacks to delegate most work to lightning

Authored 1 year ago by André Anjos; committed 1 year ago by Daniel CARRON.
Parent: 50bc1099
No related branches or tags found.
1 merge request: !12 "Adds grad-cam support on classifiers"
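In short, the refactor removes the callback's hand-rolled accumulation of per-batch losses and delegates both the accumulation and the epoch-level averaging to Lightning's logging machinery, as the diff below shows. A minimal sketch of that delegation pattern, not taken from this commit (the class name EpochLossLogger is hypothetical, and it assumes training_step returns a dict containing "loss" and that batches are tuples whose first element is the input tensor, as in this repository):

import lightning.pytorch


class EpochLossLogger(lightning.pytorch.Callback):
    """Hypothetical callback illustrating the delegation pattern of this commit."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Hand the per-batch loss to Lightning: with on_epoch=True the framework
        # accumulates the values and logs a single per-epoch mean, so the
        # callback keeps no running sums of its own.
        pl_module.log(
            "loss/train",
            outputs["loss"].item(),
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            batch_size=batch[0].shape[0],  # lets Lightning weight the epoch mean
        )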
Showing 2 changed files with 51 additions and 86 deletions:

src/ptbench/engine/callbacks.py  (+50, -85)
src/ptbench/engine/trainer.py    (+1, -1)
src/ptbench/engine/callbacks.py  (+50, -85)
@@ -18,8 +18,12 @@ logger = logging.getLogger(__name__)
 class LoggingCallback(lightning.pytorch.Callback):
     """Callback to log various training metrics and device information.
 
-    It ensures CSVLogger logs training and evaluation metrics on the same line
-    Note that a CSVLogger only accepts numerical values, and not strings.
+    Rationale:
+
+    1. Losses are logged at the end of every batch, accumulated and handled by
+       the lightning framework
+    2. Everything else is done at the end of a training or validation epoch and
+       mostly concerns runtime metrics such as memory and cpu/gpu utilisation.
 
     Parameters
@@ -33,13 +37,6 @@ class LoggingCallback(lightning.pytorch.Callback):
     def __init__(self, resource_monitor: ResourceMonitor):
         super().__init__()
 
-        # lists of number of samples/batch and average losses
-        # - we use this later to compute overall epoch losses
-        self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
-        self._validation_epoch_loss: dict[
-            int, tuple[list[int], list[float]]
-        ] = {}
-
         # timers
         self._start_training_time = 0.0
         self._start_training_epoch_time = 0.0
@@ -101,7 +98,6 @@ class LoggingCallback(lightning.pytorch.Callback):
             The lightning module that is being trained
         """
         self._start_training_epoch_time = time.time()
-        self._training_epoch_loss = ([], [])
 
     def on_train_epoch_end(
         self,
@@ -132,17 +128,8 @@ class LoggingCallback(lightning.pytorch.Callback):
         # evaluates this training epoch total time, and log it
         epoch_time = time.time() - self._start_training_epoch_time
 
-        # Compute overall training loss considering batches and sizes
-        # We disconsider accumulate_grad_batches and assume they were all of
-        # the same size. This way, the average of averages is the overall
-        # average.
-        self._to_log["loss/train"] = torch.mean(
-            torch.tensor(self._training_epoch_loss[0])
-            * torch.tensor(self._training_epoch_loss[1])
-        ).item()
-
         self._to_log["epoch-duration-seconds/train"] = epoch_time
-        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]
+        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]  # type: ignore
 
         metrics = self._resource_monitor.data
         if metrics is not None:
@@ -155,9 +142,23 @@ class LoggingCallback(lightning.pytorch.Callback):
                 "missing."
             )
 
-        # if no validation dataloaders, complete cycle by the end of the
-        # training epoch, by logging all values to the logger
-        self.on_cycle_end(trainer, pl_module)
+        overall_cycle_time = time.time() - self._start_training_epoch_time
+        self._to_log["cycle-time-seconds/train"] = overall_cycle_time
+        self._to_log["total-execution-time-seconds"] = (
+            time.time() - self._start_training_time
+        )
+        self._to_log["eta-seconds"] = overall_cycle_time * (
+            trainer.max_epochs - trainer.current_epoch  # type: ignore
+        )
+        # the "step" is the tensorboard jargon for "epoch" or "batch",
+        # depending on how we are logging - in a more general way, it simply
+        # means the relative time step.
+        self._to_log["step"] = float(trainer.current_epoch)
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            pl_module.log_dict(self._to_log)
+            self._to_log = {}
 
     def on_train_batch_end(
         self,
@@ -198,8 +199,14 @@ class LoggingCallback(lightning.pytorch.Callback):
         batch_idx
             The relative number of the batch
         """
-        self._training_epoch_loss[0].append(batch[0].shape[0])
-        self._training_epoch_loss[1].append(outputs["loss"].item())
+        pl_module.log(
+            "loss/train",
+            outputs["loss"].item(),
+            prog_bar=True,
+            on_step=False,
+            on_epoch=True,
+            batch_size=batch[0].shape[0],
+        )
 
     def on_validation_epoch_start(
         self,
@@ -229,7 +236,6 @@ class LoggingCallback(lightning.pytorch.Callback):
             The lightning module that is being trained
         """
         self._start_validation_epoch_time = time.time()
-        self._validation_epoch_loss = {}
 
     def on_validation_epoch_end(
         self,
@@ -271,20 +277,12 @@ class LoggingCallback(lightning.pytorch.Callback):
                 "missing."
             )
 
-        # Compute overall validation losses considering batches and sizes
-        # We disconsider accumulate_grad_batches and assume they were all
-        # of the same size. This way, the average of averages is the
-        # overall average.
-        for key in sorted(self._validation_epoch_loss.keys()):
-            if key == 0:
-                name = "loss/validation"
-            else:
-                name = f"loss/validation-{key}"
-            self._to_log[name] = torch.mean(
-                torch.tensor(self._validation_epoch_loss[key][0])
-                * torch.tensor(self._validation_epoch_loss[key][1])
-            ).item()
-
+        self._to_log["step"] = float(trainer.current_epoch)
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            pl_module.log_dict(self._to_log)
+            self._to_log = {}
 
     def on_validation_batch_end(
         self,
@@ -330,50 +328,17 @@ class LoggingCallback(lightning.pytorch.Callback):
             Index of the dataloader used during validation. Use this to figure
             out which dataset was used for this validation epoch.
         """
-        size, value = self._validation_epoch_loss.setdefault(
-            dataloader_idx, ([], [])
-        )
-        size.append(batch[0].shape[0])
-        value.append(outputs.item())
-
-    def on_cycle_end(
-        self,
-        trainer: lightning.pytorch.Trainer,
-        pl_module: lightning.pytorch.LightningModule,
-    ) -> None:
-        """Called when the training/validation cycle has ended.
-
-        This function will log all relevant values to the various loggers. It
-        is supposed to be called by the end of the training cycle (consisting
-        of a training and validation step).
-
-        Parameters
-        ----------
-
-        trainer
-            The Lightning trainer object
-
-        pl_module
-            The lightning module that is being trained
-        """
-
-        # collect some final time for the whole training cycle
-        # Note: logging should happen at on_validation_end(), but
-        # apparently you can't log from there
-        overall_cycle_time = time.time() - self._start_training_epoch_time
-        self._to_log["cycle-time-seconds/train"] = overall_cycle_time
-        self._to_log["total-execution-time-seconds"] = (
-            time.time() - self._start_training_time
-        )
-        self._to_log["eta-seconds"] = overall_cycle_time * (
-            trainer.max_epochs - trainer.current_epoch  # type: ignore
-        )
-
-        # Do not log during sanity check as results are not relevant
-        if not trainer.sanity_checking:
-            for k in sorted(self._to_log.keys()):
-                pl_module.log_dict(
-                    {k: self._to_log[k], "step": float(trainer.current_epoch)}
-                )
-
-            self._to_log = {}
+        if dataloader_idx == 0:
+            key = "loss/validation"
+        else:
+            key = f"loss/validation-{dataloader_idx}"
+
+        pl_module.log(
+            key,
+            outputs.item(),
+            prog_bar=False,
+            on_step=False,
+            on_epoch=True,
+            batch_size=batch[0].shape[0],
+        )
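With the manual bookkeeping above removed, epoch-level loss aggregation is left to Lightning: when a value is logged with on_epoch=True and a batch_size hint, the framework reduces the per-batch values to a batch-size-weighted mean for the epoch (the old comments assumed equal batch sizes for this to hold). A small illustrative sketch of that reduction, with made-up numbers, approximating the framework's default mean reduction rather than quoting code from this commit:

# Hypothetical per-batch sample counts and per-batch mean losses for one epoch.
batch_sizes = [32, 32, 17]
batch_losses = [0.71, 0.64, 0.58]

# Batch-size weighted mean: sum(n_i * loss_i) / sum(n_i).
epoch_loss = sum(n * loss for n, loss in zip(batch_sizes, batch_losses)) / sum(batch_sizes)
print(f"loss/train for the epoch: {epoch_loss:.4f}")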
src/ptbench/engine/trainer.py  (+1, -1)
@@ -193,7 +193,7 @@ def run(
         save_last=True,  # will (re)create the last trained model, at every iteration
         monitor="loss/validation",
         mode="min",
-        save_on_train_epoch_end=True,  # run checks at the end of validation
+        save_on_train_epoch_end=True,
         every_n_epochs=validation_period,  # frequency at which it would check the "monitor"
         enable_version_counter=False,  # no versioning of aliased checkpoints
     )
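The keyword arguments in this hunk appear to configure Lightning's checkpointing callback. A hedged, self-contained sketch of an equivalent configuration (it assumes the surrounding call in trainer.py is lightning.pytorch.callbacks.ModelCheckpoint; dirpath and the validation_period value are hypothetical stand-ins for the real arguments):

from lightning.pytorch.callbacks import ModelCheckpoint

validation_period = 1  # hypothetical: how often (in epochs) the monitored value is refreshed

checkpoint_callback = ModelCheckpoint(
    dirpath="results",  # hypothetical output folder
    save_last=True,  # will (re)create the last trained model, at every iteration
    monitor="loss/validation",  # quantity used to rank checkpoints
    mode="min",  # a lower validation loss is better
    save_on_train_epoch_end=True,
    every_n_epochs=validation_period,  # frequency at which it would check the "monitor"
    enable_version_counter=False,  # no versioning of aliased checkpoints
)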