Project: medai / software / mednet

Commit 1b1382bb
"Evaluation script saves more plots, combines results"
Authored 1 year ago by Daniel CARRON; committed 1 year ago by André Anjos.
Parent commit: fc3551d9
Merge request !6: "Making use of LightningDataModule and simplification of data loading"
Showing 2 changed files, with 102 additions and 53 deletions:

  src/ptbench/engine/evaluator.py: +32 −46
  src/ptbench/scripts/evaluate.py: +70 −7
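In short, run() no longer receives a dataset object or an output folder and no longer writes files itself: it returns the prediction data and the scores figure for one subset, and the evaluate command collects those per subset and writes the combined outputs (scores.pdf, plots.pdf, table.txt). A rough sketch of the new calling contract, with the subset names and predictions path invented for illustration:

from ptbench.engine.evaluator import run

# Collect per-subset results the way the new evaluate command does;
# "results/predictions" is a hypothetical predictions folder.
results = {}
for subset in ("train", "validation", "test"):
    pred_data, fig_score, maxf1_threshold, post_eer_threshold = run(
        name=subset,
        predictions_folder="results/predictions",
        steps=1000,
    )
    results[subset] = (pred_data, fig_score, maxf1_threshold, post_eer_threshold)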
src/ptbench/engine/evaluator.py (+32 −46)
@@ -186,10 +186,8 @@ def sample_measures_for_threshold(
 def run(
-    dataset,
     name: str,
     predictions_folder: str,
-    output_folder: Optional[str | None] = None,
     f1_thresh: Optional[float] = None,
     eer_thresh: Optional[float] = None,
     steps: Optional[int] = 1000,
@@ -199,9 +197,6 @@ def run(
     Parameters
     ---------
-    dataset : py:class:`torch.utils.data.Dataset`
-        a dataset to iterate on
-
     name:
         the local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
@@ -210,9 +205,6 @@ def run(
         folder where predictions for the dataset images has been previously
         stored

-    output_folder:
-        folder where to store results.
-
     f1_thresh:
         This number should come from
         the training set or a separate validation set. Using a test set value
@@ -238,9 +230,7 @@ def run(
     post_eer_threshold : float
         Threshold achieving Equal Error Rate for this dataset
     """
-    predictions_path = os.path.join(predictions_folder, f"{name}.csv")
+    predictions_path = os.path.join(
+        predictions_folder, f"predictions_{name}", "predictions.csv"
+    )

     if not os.path.exists(predictions_path):
         predictions_path = predictions_folder
@@ -298,12 +288,12 @@ def run(
     )

     data_df = data_df.set_index("index")

+    """
     # Save evaluation csv
     if output_folder is not None:
         fullpath = os.path.join(output_folder, f"{name}.csv")
         logger.info(f"Saving {fullpath}...")
         os.makedirs(os.path.dirname(fullpath), exist_ok=True)
         data_df.to_csv(fullpath)
+    """

     # Find max F1 score
     f1_scores = numpy.asarray(data_df["f1_score"])
@@ -328,42 +318,38 @@ def run(
         f"threshold {post_eer_threshold:.3f} (chosen *a posteriori*)"
     )

-    # Save score table
-    if output_folder is not None:
-        fig, axes = plt.subplots(1)
-        fig.tight_layout(pad=3.0)
-
-        # Names and bounds
-        axes.set_xlabel("Score")
-        axes.set_ylabel("Normalized counts")
-        axes.set_xlim(0.0, 1.0)
-
-        neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
-            pred_data["likelihood"]
-        )
-        pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
-            pred_data["likelihood"]
-        )
-
-        axes.hist(
-            [neg_gt["likelihood"], pos_gt["likelihood"]],
-            weights=[neg_weights, pos_weights],
-            bins=100,
-            color=["tab:blue", "tab:orange"],
-            label=["Negatives", "Positives"],
-        )
-        axes.legend(prop={"size": 10}, loc="upper center")
-        axes.set_title(f"Score table for {name} subset")
-
-        # we should see some of axes 1 axes
-        axes.spines["right"].set_visible(False)
-        axes.spines["top"].set_visible(False)
-        axes.spines["left"].set_position(("data", -0.015))
-
-        fullpath = os.path.join(output_folder, f"{name}_score_table.pdf")
-        fig.savefig(fullpath)
+    # Generate scores fig
+    fig_score, axes = plt.subplots(1)
+    fig_score.tight_layout(pad=3.0)
+
+    # Names and bounds
+    axes.set_xlabel("Score")
+    axes.set_ylabel("Normalized counts")
+    axes.set_xlim(0.0, 1.0)
+
+    neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
+        pred_data["likelihood"]
+    )
+    pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
+        pred_data["likelihood"]
+    )
+
+    axes.hist(
+        [neg_gt["likelihood"], pos_gt["likelihood"]],
+        weights=[neg_weights, pos_weights],
+        bins=100,
+        color=["tab:blue", "tab:orange"],
+        label=["Negatives", "Positives"],
+    )
+    axes.legend(prop={"size": 10}, loc="upper center")
+    axes.set_title(f"Score table for {name} subset")
+
+    # we should see some of axes 1 axes
+    axes.spines["right"].set_visible(False)
+    axes.spines["top"].set_visible(False)
+    axes.spines["left"].set_position(("data", -0.015))

+    """
     if f1_thresh is not None and eer_thresh is not None:
         # get the closest possible threshold we have
         index = int(round(steps * f1_thresh))
         f1_a_priori = data_df["f1_score"][index]
@@ -375,6 +361,6 @@ def run(
     )

     # Print the a priori EER threshold
     logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")
+    """

-    return maxf1_threshold, post_eer_threshold
+    return pred_data, fig_score, maxf1_threshold, post_eer_threshold
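The new scores figure normalizes each class histogram by weighting every sample with 1/N, where N is the total number of predictions, so bar heights read as fractions of the whole subset rather than raw counts and remain comparable under class imbalance. A standalone sketch of that weighting with synthetic scores (all names and numbers here are illustrative):

import numpy
import matplotlib.pyplot as plt

# Synthetic negative/positive scores, standing in for neg_gt["likelihood"]
# and pos_gt["likelihood"] in the code above.
rng = numpy.random.default_rng(0)
neg = rng.beta(2, 5, size=300)
pos = rng.beta(5, 2, size=100)
total = len(neg) + len(pos)

fig, axes = plt.subplots(1)
axes.hist(
    [neg, pos],
    weights=[numpy.ones_like(neg) / total, numpy.ones_like(pos) / total],
    bins=100,
    color=["tab:blue", "tab:orange"],
    label=["Negatives", "Positives"],
)
axes.set_xlabel("Score")
axes.set_ylabel("Normalized counts")
axes.legend()
fig.savefig("score_histogram.pdf")  # illustrative output path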
src/ptbench/scripts/evaluate.py (+70 −7)
@@ -2,15 +2,21 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

+import os
+
+from collections import defaultdict
 from typing import Union

 import click

 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
+from matplotlib.backends.backend_pdf import PdfPages

 from ..data.datamodule import CachingDataModule
 from ..data.typing import DataLoader
+from ..utils.plot import precision_recall_f1iso, roc_curve
+from ..utils.table import performance_table

 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -117,7 +123,7 @@ def _validate_threshold(
     "the test set F1-score a priori performance",
     default=None,
     show_default=False,
-    required=False,
+    required=True,
     cls=ResourceOption,
 )
 @click.option(
@@ -159,8 +165,10 @@ def evaluate(
     if isinstance(threshold, str):
         # first run evaluation for reference dataset
         logger.info(f"Evaluating threshold on '{threshold}' set")
         f1_threshold, eer_threshold = run(
-            _, threshold, predictions_folder, steps=steps
+            name=threshold,
+            predictions_folder=predictions_folder,
+            steps=steps,
         )
         if (f1_threshold is not None) and (eer_threshold is not None):
@@ -173,17 +181,72 @@ def evaluate(
     else:
         raise ValueError("Threshold value is neither an int nor a float")

-    for k, v in dataloader.items():
+    results_dict = {  # type: ignore
+        "pred_data": defaultdict(dict),
+        "fig_score": defaultdict(dict),
+        "maxf1_threshold": defaultdict(dict),
+        "post_eer_threshold": defaultdict(dict),
+    }
+
+    for k in dataloader.keys():
         if k.startswith("_"):
             logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
             continue
         logger.info(f"Analyzing '{k}' set...")
-        run(
-            v,
+        pred_data, fig_score, maxf1_threshold, post_eer_threshold = run(
             k,
             predictions_folder,
-            output_folder,
             f1_thresh=f1_threshold,
             eer_thresh=eer_threshold,
             steps=steps,
         )
+
+        results_dict["pred_data"][k] = pred_data
+        results_dict["fig_score"][k] = fig_score
+        results_dict["maxf1_threshold"][k] = maxf1_threshold
+        results_dict["post_eer_threshold"][k] = post_eer_threshold
+
+    if output_folder is not None:
+        output_scores = os.path.join(output_folder, "scores.pdf")
+
+        if output_scores is not None:
+            output_scores = os.path.realpath(output_scores)
+            logger.info(f"Creating and saving scores at {output_scores}...")
+            os.makedirs(os.path.dirname(output_scores), exist_ok=True)
+
+            score_pdf = PdfPages(output_scores)
+
+            for fig in results_dict["fig_score"].values():
+                score_pdf.savefig(fig)
+            score_pdf.close()
+
+        data = {}
+        for subset_name in dataloader.keys():
+            data[subset_name] = {
+                "df": results_dict["pred_data"][subset_name],
+                "threshold": results_dict["post_eer_threshold"][  # type: ignore
+                    threshold
+                ].item(),
+            }
+
+        output_figure = os.path.join(output_folder, "plots.pdf")
+
+        if output_figure is not None:
+            output_figure = os.path.realpath(output_figure)
+            logger.info(f"Creating and saving plots at {output_figure}...")
+            os.makedirs(os.path.dirname(output_figure), exist_ok=True)
+            pdf = PdfPages(output_figure)
+            pdf.savefig(precision_recall_f1iso(data))
+            pdf.savefig(roc_curve(data))
+            pdf.close()
+
+        output_table = os.path.join(output_folder, "table.txt")
+        logger.info("Tabulating performance summary...")
+        table = performance_table(data, "rst")
+        click.echo(table)
+        if output_table is not None:
+            output_table = os.path.realpath(output_table)
+            logger.info(f"Saving table at {output_table}...")
+            os.makedirs(os.path.dirname(output_table), exist_ok=True)
+            with open(output_table, "w") as f:
+                f.write(table)
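The combined PDFs rely on matplotlib's PdfPages, which turns every savefig() call into one page of a single file; that is what lets the evaluate command write one scores.pdf covering all subsets instead of a per-subset file. A minimal sketch of that mechanism, with placeholder figures standing in for the ones returned by run():

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# Placeholder figures standing in for the per-subset score figures.
figures = []
for subset in ("train", "validation", "test"):
    fig, axes = plt.subplots(1)
    axes.set_title(f"Score table for {subset} subset")
    figures.append(fig)

# One page per figure, all in a single PDF.
with PdfPages("scores.pdf") as pdf:
    for fig in figures:
        pdf.savefig(fig)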