Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
mednet
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
medai
software
mednet
Commits
0b9ae46c
Commit
0b9ae46c
authored
2 years ago
by
André Anjos
Browse files
Options
Downloads
Plain Diff
Merge branch 'checkpointing-cleanup' into 'main'
Checkpointing cleanup See merge request biosignal/software/ptbench!2
parents
0209ebe1
e360a861
Branches
Branches containing commit
Tags
Tags containing commit
1 merge request
!2
Checkpointing cleanup
Pipeline
#71477
passed
2 years ago
Stage: qa
Stage: test
Stage: doc
Stage: dist
Stage: deploy
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
src/ptbench/scripts/predict.py
+1
-1
1 addition, 1 deletion
src/ptbench/scripts/predict.py
src/ptbench/utils/checkpointer.py
+6
-10
6 additions, 10 deletions
src/ptbench/utils/checkpointer.py
tests/data/lfs
+1
-1
1 addition, 1 deletion
tests/data/lfs
with
8 additions
and
12 deletions
src/ptbench/scripts/predict.py
+
1
−
1
View file @
0b9ae46c
...
@@ -131,7 +131,7 @@ def predict(
...
@@ -131,7 +131,7 @@ def predict(
weight_fullpath
=
os
.
path
.
abspath
(
weight
)
weight_fullpath
=
os
.
path
.
abspath
(
weight
)
checkpointer
=
Checkpointer
(
model
)
checkpointer
=
Checkpointer
(
model
)
checkpointer
.
load
(
weight_fullpath
,
strict
=
False
)
checkpointer
.
load
(
weight_fullpath
)
# Logistic regressor weights
# Logistic regressor weights
if
model
.
name
==
"
logistic_regression
"
:
if
model
.
name
==
"
logistic_regression
"
:
...
...
This diff is collapsed.
Click to expand it.
src/ptbench/utils/checkpointer.py
+
6
−
10
View file @
0b9ae46c
...
@@ -51,7 +51,7 @@ class Checkpointer:
...
@@ -51,7 +51,7 @@ class Checkpointer:
with
open
(
self
.
_last_checkpoint_filename
,
"
w
"
)
as
f
:
with
open
(
self
.
_last_checkpoint_filename
,
"
w
"
)
as
f
:
f
.
write
(
name
)
f
.
write
(
name
)
def
load
(
self
,
f
=
None
,
strict
=
True
):
def
load
(
self
,
f
=
None
):
"""
Loads model, optimizer and scheduler from file.
"""
Loads model, optimizer and scheduler from file.
Parameters
Parameters
...
@@ -62,9 +62,6 @@ class Checkpointer:
...
@@ -62,9 +62,6 @@ class Checkpointer:
contains the checkpoint data to load into the model, and optionally
contains the checkpoint data to load into the model, and optionally
into the optimizer and the scheduler. If not specified, loads data
into the optimizer and the scheduler. If not specified, loads data
from current path.
from current path.
partial : :py:class:`bool`, Optional
If True, loading is not strict and only the model is loaded
"""
"""
if
f
is
None
:
if
f
is
None
:
f
=
self
.
last_checkpoint
()
f
=
self
.
last_checkpoint
()
...
@@ -79,13 +76,12 @@ class Checkpointer:
...
@@ -79,13 +76,12 @@ class Checkpointer:
checkpoint
=
torch
.
load
(
f
,
map_location
=
torch
.
device
(
"
cpu
"
))
checkpoint
=
torch
.
load
(
f
,
map_location
=
torch
.
device
(
"
cpu
"
))
# converts model entry to model parameters
# converts model entry to model parameters
self
.
model
.
load_state_dict
(
checkpoint
.
pop
(
"
model
"
)
,
strict
=
strict
)
self
.
model
.
load_state_dict
(
checkpoint
.
pop
(
"
model
"
))
if
strict
:
if
self
.
optimizer
is
not
None
:
if
self
.
optimizer
is
not
None
:
self
.
optimizer
.
load_state_dict
(
checkpoint
.
pop
(
"
optimizer
"
))
self
.
optimizer
.
load_state_dict
(
checkpoint
.
pop
(
"
optimizer
"
))
if
self
.
scheduler
is
not
None
:
if
self
.
scheduler
is
not
None
:
self
.
scheduler
.
load_state_dict
(
checkpoint
.
pop
(
"
scheduler
"
))
self
.
scheduler
.
load_state_dict
(
checkpoint
.
pop
(
"
scheduler
"
))
return
checkpoint
return
checkpoint
...
...
This diff is collapsed.
Click to expand it.
lfs
@
69185f0d
Compare
64c25ecf
...
69185f0d
Subproject commit 6
4c25ecf20b6f6ac2f250772fcb5338c1196a950
Subproject commit 6
9185f0d9ea67893722c5a840e2caa59946b3b83
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment