Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
mednet
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
medai
software
mednet
Commits
2e79aae0
Commit
2e79aae0
authored
9 months ago
by
Gokhan OZBULAK
Browse files
Options
Downloads
Patches
Plain Diff
Change flag for batch accumulation.
#25
parent
cf742317
No related branches found
No related tags found
1 merge request
!40
Lightning acc
Pipeline
#87608
passed
9 months ago
Stage: qa
Stage: doc
Stage: dist
Stage: test
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
src/mednet/engine/trainer.py
+9
-8
9 additions, 8 deletions
src/mednet/engine/trainer.py
src/mednet/scripts/experiment.py
+2
-2
2 additions, 2 deletions
src/mednet/scripts/experiment.py
src/mednet/scripts/train.py
+14
-13
14 additions, 13 deletions
src/mednet/scripts/train.py
with
25 additions
and
23 deletions
src/mednet/engine/trainer.py
+
9
−
8
View file @
2e79aae0
...
...
@@ -26,7 +26,7 @@ def run(
max_epochs
:
int
,
output_folder
:
pathlib
.
Path
,
monitoring_interval
:
int
|
float
,
batch_chunk_count
:
int
,
accumulate_grad_batches
:
int
,
checkpoint
:
pathlib
.
Path
|
None
,
):
"""
Fit a CNN model using supervised learning and save it to disk.
...
...
@@ -60,12 +60,13 @@ def run(
monitoring_interval
Interval, in seconds (or fractions), through which we should monitor
resources during training.
batch_chunk_count
If this number is different than 1, then each batch will be divided in
this number of chunks. Gradients will be accumulated to perform each
mini-batch. This is particularly interesting when one has limited RAM
on the GPU, but would like to keep training with larger batches. One
exchanges for longer processing times in this case.
accumulate_grad_batches
Number of accumulations for backward propagation to accumulate gradients
over k batches before stepping the optimizer. The default of 1 forces
the whole batch to be processed at once. Otherwise the batch is multiplied
by accumulate-grad-batches pieces, and gradients are accumulated to complete
each step. This is especially interesting when one is training on GPUs with
a limited amount of onboard RAM.
checkpoint
Path to an optional checkpoint file to load.
"""
...
...
@@ -118,7 +119,7 @@ def run(
accelerator
=
accelerator
,
devices
=
devices
,
max_epochs
=
max_epochs
,
accumulate_grad_batches
=
batch_chunk_count
,
accumulate_grad_batches
=
accumulate_grad_batches
,
logger
=
tensorboard_logger
,
check_val_every_n_epoch
=
validation_period
,
log_every_n_steps
=
len
(
datamodule
.
train_dataloader
()),
...
...
This diff is collapsed.
Click to expand it.
src/mednet/scripts/experiment.py
+
2
−
2
View file @
2e79aae0
...
...
@@ -40,7 +40,7 @@ def experiment(
output_folder
,
epochs
,
batch_size
,
batch_chunk_count
,
accumulate_grad_batches
,
drop_incomplete_batch
,
datamodule
,
validation_period
,
...
...
@@ -79,7 +79,7 @@ def experiment(
output_folder
=
train_output_folder
,
epochs
=
epochs
,
batch_size
=
batch_size
,
batch_chunk_count
=
batch_chunk_count
,
accumulate_grad_batches
=
accumulate_grad_batches
,
drop_incomplete_batch
=
drop_incomplete_batch
,
datamodule
=
datamodule
,
validation_period
=
validation_period
,
...
...
This diff is collapsed.
Click to expand it.
src/mednet/scripts/train.py
+
14
−
13
View file @
2e79aae0
...
...
@@ -79,18 +79,19 @@ def reusable_options(f):
cls
=
ResourceOption
,
)
@click.option
(
"
--batch-chunk-count
"
,
"
-c
"
,
help
=
"
Number of chunks in every batch (this parameter affects
"
"
memory requirements for the network). The number of samples
"
"
loaded for every iteration will be batch-size*batch-chunk-count.
"
"
This parameter is used to reduce the number of samples loaded in each
"
"
iteration, in order to reduce the memory usage in exchange for
"
"
processing time (more iterations). This is especially interesting
"
"
--accumulate-grad-batches
"
,
"
-a
"
,
help
=
"
Number of accumulations for backward propagation to accumulate
"
"
gradients over k batches before stepping the optimizer. This
"
"
parameter, used in conjunction with the batch-size, may be used to
"
"
reduce the number of samples loaded in each iteration, to affect memory
"
"
usage in exchange for processing time (more iterations). This is
"
"
especially interesting when one is training on GPUs with a limited amount
"
"
of onboard RAM. processing time (more iterations). This is especially interesting
"
"
when one is training on GPUs with limited RAM. The default of 1 forces
"
"
the whole batch to be processed at once. Otherwise the batch is
"
"
multiplied by
batch-chunk-count
pieces, and gradients are accumulated
"
"
to complete each
batch
.
"
,
"
multiplied by
accumulate-grad-batches
pieces, and gradients are accumulated
"
"
to complete each
step
.
"
,
required
=
True
,
show_default
=
True
,
default
=
1
,
...
...
@@ -235,7 +236,7 @@ def train(
output_folder
,
epochs
,
batch_size
,
batch_chunk_count
,
accumulate_grad_batches
,
drop_incomplete_batch
,
datamodule
,
validation_period
,
...
...
@@ -340,7 +341,7 @@ def train(
split_name
=
datamodule
.
split_name
,
epochs
=
epochs
,
batch_size
=
batch_size
,
batch_chunk_count
=
batch_chunk_count
,
accumulate_grad_batches
=
accumulate_grad_batches
,
drop_incomplete_batch
=
drop_incomplete_batch
,
validation_period
=
validation_period
,
cache_samples
=
cache_samples
,
...
...
@@ -363,6 +364,6 @@ def train(
max_epochs
=
epochs
,
output_folder
=
output_folder
,
monitoring_interval
=
monitoring_interval
,
batch_chunk_count
=
batch_chunk_count
,
accumulate_grad_batches
=
accumulate_grad_batches
,
checkpoint
=
checkpoint_file
,
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment