Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
deepdraw
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
This is an archived project. Repository and other project resources are read-only.
Show more breadcrumbs
medai
software
deepdraw
Commits
e760ac7d
Commit
e760ac7d
authored
5 years ago
by
André Anjos
Browse files
Options
Downloads
Patches
Plain Diff
[data/imagefolderinference] Accept globs
parent
4d1d4867
No related branches found
No related tags found
1 merge request
!9
Minor fixes
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bob/ip/binseg/utils/plot.py
+102
-96
102 additions, 96 deletions
bob/ip/binseg/utils/plot.py
with
102 additions
and
96 deletions
bob/ip/binseg/utils/plot.py
+
102
−
96
View file @
e760ac7d
...
...
@@ -3,7 +3,7 @@
import
numpy
as
np
import
os
import
csv
import
csv
import
pandas
as
pd
import
PIL
from
PIL
import
Image
,
ImageFont
,
ImageDraw
...
...
@@ -13,62 +13,62 @@ import torch
def
precision_recall_f1iso
(
precision
,
recall
,
names
,
title
=
None
):
"""
Author: Andre Anjos (andre.anjos@idiap.ch).
Creates a precision-recall plot of the given data.
Creates a precision-recall plot of the given data.
The plot will be annotated with F1-score iso-lines (in which the F1-score
maintains the same value)
maintains the same value)
Parameters
----------
----------
precision : :py:class:`numpy.ndarray` or :py:class:`list`
A list of 1D np arrays containing the Y coordinates of the plot, or
the precision, or a 2D np array in which the rows correspond to each
of the system
'
s precision coordinates.
of the system
'
s precision coordinates.
recall : :py:class:`numpy.ndarray` or :py:class:`list`
A list of 1D np arrays containing the X coordinates of the plot, or
the recall, or a 2D np array in which the rows correspond to each
of the system
'
s recall coordinates.
of the system
'
s recall coordinates.
names : :py:class:`list`
An iterable over the names of each of the systems along the rows of
``precision`` and ``recall``
``precision`` and ``recall``
title : :py:class:`str`, optional
A title for the plot. If not set, omits the title
A title for the plot. If not set, omits the title
Returns
-------
-------
matplotlib.figure.Figure
A matplotlib figure you can save or display
"""
A matplotlib figure you can save or display
"""
import
matplotlib
matplotlib
.
use
(
'
agg
'
)
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
from
itertools
import
cycle
fig
,
ax1
=
plt
.
subplots
(
1
)
fig
,
ax1
=
plt
.
subplots
(
1
)
lines
=
[
"
-
"
,
"
--
"
,
"
-.
"
,
"
:
"
]
linecycler
=
cycle
(
lines
)
for
p
,
r
,
n
in
zip
(
precision
,
recall
,
names
):
for
p
,
r
,
n
in
zip
(
precision
,
recall
,
names
):
# Plots only from the point where recall reaches its maximum, otherwise, we
# don't see a curve...
i
=
r
.
argmax
()
pi
=
p
[
i
:]
ri
=
r
[
i
:]
ri
=
r
[
i
:]
valid
=
(
pi
+
ri
)
>
0
f1
=
2
*
(
pi
[
valid
]
*
ri
[
valid
])
/
(
pi
[
valid
]
+
ri
[
valid
])
f1
=
2
*
(
pi
[
valid
]
*
ri
[
valid
])
/
(
pi
[
valid
]
+
ri
[
valid
])
# optimal point along the curve
argmax
=
f1
.
argmax
()
opi
=
pi
[
argmax
]
ori
=
ri
[
argmax
]
# Plot Recall/Precision as threshold changes
ax1
.
plot
(
ri
[
pi
>
0
],
pi
[
pi
>
0
],
next
(
linecycler
),
label
=
'
[F={:.4f}] {}
'
.
format
(
f1
.
max
(),
n
),)
ax1
.
plot
(
ri
[
pi
>
0
],
pi
[
pi
>
0
],
next
(
linecycler
),
label
=
'
[F={:.4f}] {}
'
.
format
(
f1
.
max
(),
n
),)
ax1
.
plot
(
ori
,
opi
,
marker
=
'
o
'
,
linestyle
=
None
,
markersize
=
3
,
color
=
'
black
'
)
ax1
.
grid
(
linestyle
=
'
--
'
,
linewidth
=
1
,
color
=
'
gray
'
,
alpha
=
0.2
)
ax1
.
grid
(
linestyle
=
'
--
'
,
linewidth
=
1
,
color
=
'
gray
'
,
alpha
=
0.2
)
if
len
(
names
)
>
1
:
plt
.
legend
(
loc
=
'
lower left
'
,
framealpha
=
0.5
)
plt
.
legend
(
loc
=
'
lower left
'
,
framealpha
=
0.5
)
ax1
.
set_xlabel
(
'
Recall
'
)
ax1
.
set_ylabel
(
'
Precision
'
)
ax1
.
set_xlim
([
0.0
,
1.0
])
ax1
.
set_ylim
([
0.0
,
1.0
])
if
title
is
not
None
:
ax1
.
set_title
(
title
)
ax1
.
set_ylim
([
0.0
,
1.0
])
if
title
is
not
None
:
ax1
.
set_title
(
title
)
# Annotates plot with F1-score iso-lines
ax2
=
ax1
.
twinx
()
f_scores
=
np
.
linspace
(
0.1
,
0.9
,
num
=
9
)
...
...
@@ -79,70 +79,70 @@ def precision_recall_f1iso(precision, recall, names, title=None):
y
=
f_score
*
x
/
(
2
*
x
-
f_score
)
l
,
=
plt
.
plot
(
x
[
y
>=
0
],
y
[
y
>=
0
],
color
=
'
green
'
,
alpha
=
0.1
)
tick_locs
.
append
(
y
[
-
1
])
tick_labels
.
append
(
'
%.1f
'
%
f_score
)
tick_labels
.
append
(
'
%.1f
'
%
f_score
)
ax2
.
tick_params
(
axis
=
'
y
'
,
which
=
'
both
'
,
pad
=
0
,
right
=
False
,
left
=
False
)
ax2
.
set_ylabel
(
'
iso-F
'
,
color
=
'
green
'
,
alpha
=
0.3
)
ax2
.
set_ylim
([
0.0
,
1.0
])
ax2
.
yaxis
.
set_label_coords
(
1.015
,
0.97
)
ax2
.
set_yticks
(
tick_locs
)
#notice these are invisible
ax2
.
yaxis
.
set_label_coords
(
1.015
,
0.97
)
ax2
.
set_yticks
(
tick_locs
)
#notice these are invisible
for
k
in
ax2
.
set_yticklabels
(
tick_labels
):
k
.
set_color
(
'
green
'
)
k
.
set_alpha
(
0.3
)
k
.
set_size
(
8
)
k
.
set_size
(
8
)
# we should see some of axes 1 axes
ax1
.
spines
[
'
right
'
].
set_visible
(
False
)
ax1
.
spines
[
'
top
'
].
set_visible
(
False
)
ax1
.
spines
[
'
left
'
].
set_position
((
'
data
'
,
-
0.015
))
ax1
.
spines
[
'
bottom
'
].
set_position
((
'
data
'
,
-
0.015
))
ax1
.
spines
[
'
bottom
'
].
set_position
((
'
data
'
,
-
0.015
))
# we shouldn't see any of axes 2 axes
ax2
.
spines
[
'
right
'
].
set_visible
(
False
)
ax2
.
spines
[
'
top
'
].
set_visible
(
False
)
ax2
.
spines
[
'
left
'
].
set_visible
(
False
)
ax2
.
spines
[
'
bottom
'
].
set_visible
(
False
)
plt
.
tight_layout
()
return
fig
ax2
.
spines
[
'
bottom
'
].
set_visible
(
False
)
plt
.
tight_layout
()
return
fig
def
precision_recall_f1iso_confintval
(
precision
,
recall
,
pr_upper
,
pr_lower
,
re_upper
,
re_lower
,
names
,
title
=
None
):
"""
Author: Andre Anjos (andre.anjos@idiap.ch).
Creates a precision-recall plot of the given data.
Creates a precision-recall plot of the given data.
The plot will be annotated with F1-score iso-lines (in which the F1-score
maintains the same value)
maintains the same value)
Parameters
----------
----------
precision : :py:class:`numpy.ndarray` or :py:class:`list`
A list of 1D np arrays containing the Y coordinates of the plot, or
the precision, or a 2D np array in which the rows correspond to each
of the system
'
s precision coordinates.
of the system
'
s precision coordinates.
recall : :py:class:`numpy.ndarray` or :py:class:`list`
A list of 1D np arrays containing the X coordinates of the plot, or
the recall, or a 2D np array in which the rows correspond to each
of the system
'
s recall coordinates.
of the system
'
s recall coordinates.
names : :py:class:`list`
An iterable over the names of each of the systems along the rows of
``precision`` and ``recall``
``precision`` and ``recall``
title : :py:class:`str`, optional
A title for the plot. If not set, omits the title
A title for the plot. If not set, omits the title
Returns
-------
-------
matplotlib.figure.Figure
A matplotlib figure you can save or display
"""
A matplotlib figure you can save or display
"""
import
matplotlib
matplotlib
.
use
(
'
agg
'
)
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
from
itertools
import
cycle
fig
,
ax1
=
plt
.
subplots
(
1
)
fig
,
ax1
=
plt
.
subplots
(
1
)
lines
=
[
"
-
"
,
"
--
"
,
"
-.
"
,
"
:
"
]
colors
=
[
'
#1f77b4
'
,
'
#ff7f0e
'
,
'
#2ca02c
'
,
'
#d62728
'
,
'
#9467bd
'
,
'
#8c564b
'
,
'
#e377c2
'
,
'
#7f7f7f
'
,
'
#bcbd22
'
,
'
#17becf
'
]
colorcycler
=
cycle
(
colors
)
linecycler
=
cycle
(
lines
)
for
p
,
r
,
pu
,
pl
,
ru
,
rl
,
n
in
zip
(
precision
,
recall
,
pr_upper
,
pr_lower
,
re_upper
,
re_lower
,
names
):
for
p
,
r
,
pu
,
pl
,
ru
,
rl
,
n
in
zip
(
precision
,
recall
,
pr_upper
,
pr_lower
,
re_upper
,
re_lower
,
names
):
# Plots only from the point where recall reaches its maximum, otherwise, we
# don't see a curve...
i
=
r
.
argmax
()
...
...
@@ -151,24 +151,24 @@ def precision_recall_f1iso_confintval(precision, recall, pr_upper, pr_lower, re_
pui
=
pu
[
i
:]
pli
=
pl
[
i
:]
rui
=
ru
[
i
:]
rli
=
rl
[
i
:]
rli
=
rl
[
i
:]
valid
=
(
pi
+
ri
)
>
0
f1
=
2
*
(
pi
[
valid
]
*
ri
[
valid
])
/
(
pi
[
valid
]
+
ri
[
valid
])
f1
=
2
*
(
pi
[
valid
]
*
ri
[
valid
])
/
(
pi
[
valid
]
+
ri
[
valid
])
# optimal point along the curve
argmax
=
f1
.
argmax
()
opi
=
pi
[
argmax
]
ori
=
ri
[
argmax
]
# Plot Recall/Precision as threshold changes
ax1
.
plot
(
ri
[
pi
>
0
],
pi
[
pi
>
0
],
next
(
linecycler
),
label
=
'
[F={:.4f}] {}
'
.
format
(
f1
.
max
(),
n
),)
ax1
.
plot
(
ri
[
pi
>
0
],
pi
[
pi
>
0
],
next
(
linecycler
),
label
=
'
[F={:.4f}] {}
'
.
format
(
f1
.
max
(),
n
),)
ax1
.
plot
(
ori
,
opi
,
marker
=
'
o
'
,
linestyle
=
None
,
markersize
=
3
,
color
=
'
black
'
)
# Plot confidence
# Upper bound
#ax1.plot(r95ui[p95ui>0], p95ui[p95ui>0])
#ax1.plot(r95ui[p95ui>0], p95ui[p95ui>0])
# Lower bound
#ax1.plot(r95li[p95li>0], p95li[p95li>0])
# create the limiting polygon
vert_x
=
np
.
concatenate
((
rui
[
pui
>
0
],
rli
[
pli
>
0
][::
-
1
]))
vert_y
=
np
.
concatenate
((
pui
[
pui
>
0
],
pli
[
pli
>
0
][::
-
1
]))
vert_y
=
np
.
concatenate
((
pui
[
pui
>
0
],
pli
[
pli
>
0
][::
-
1
]))
# hacky workaround to plot 2nd human
if
np
.
isclose
(
np
.
mean
(
rui
),
rui
[
1
],
rtol
=
1e-05
):
print
(
'
found human
'
)
...
...
@@ -177,14 +177,14 @@ def precision_recall_f1iso_confintval(precision, recall, pr_upper, pr_lower, re_
p
=
plt
.
Polygon
(
np
.
column_stack
((
vert_x
,
vert_y
)),
facecolor
=
next
(
colorcycler
),
alpha
=
.
2
,
edgecolor
=
'
none
'
,
lw
=
.
2
)
ax1
.
add_artist
(
p
)
ax1
.
grid
(
linestyle
=
'
--
'
,
linewidth
=
1
,
color
=
'
gray
'
,
alpha
=
0.2
)
ax1
.
grid
(
linestyle
=
'
--
'
,
linewidth
=
1
,
color
=
'
gray
'
,
alpha
=
0.2
)
if
len
(
names
)
>
1
:
plt
.
legend
(
loc
=
'
lower left
'
,
framealpha
=
0.5
)
plt
.
legend
(
loc
=
'
lower left
'
,
framealpha
=
0.5
)
ax1
.
set_xlabel
(
'
Recall
'
)
ax1
.
set_ylabel
(
'
Precision
'
)
ax1
.
set_xlim
([
0.0
,
1.0
])
ax1
.
set_ylim
([
0.0
,
1.0
])
if
title
is
not
None
:
ax1
.
set_title
(
title
)
ax1
.
set_ylim
([
0.0
,
1.0
])
if
title
is
not
None
:
ax1
.
set_title
(
title
)
# Annotates plot with F1-score iso-lines
ax2
=
ax1
.
twinx
()
f_scores
=
np
.
linspace
(
0.1
,
0.9
,
num
=
9
)
...
...
@@ -195,45 +195,45 @@ def precision_recall_f1iso_confintval(precision, recall, pr_upper, pr_lower, re_
y
=
f_score
*
x
/
(
2
*
x
-
f_score
)
l
,
=
plt
.
plot
(
x
[
y
>=
0
],
y
[
y
>=
0
],
color
=
'
green
'
,
alpha
=
0.1
)
tick_locs
.
append
(
y
[
-
1
])
tick_labels
.
append
(
'
%.1f
'
%
f_score
)
tick_labels
.
append
(
'
%.1f
'
%
f_score
)
ax2
.
tick_params
(
axis
=
'
y
'
,
which
=
'
both
'
,
pad
=
0
,
right
=
False
,
left
=
False
)
ax2
.
set_ylabel
(
'
iso-F
'
,
color
=
'
green
'
,
alpha
=
0.3
)
ax2
.
set_ylim
([
0.0
,
1.0
])
ax2
.
yaxis
.
set_label_coords
(
1.015
,
0.97
)
ax2
.
set_yticks
(
tick_locs
)
#notice these are invisible
ax2
.
yaxis
.
set_label_coords
(
1.015
,
0.97
)
ax2
.
set_yticks
(
tick_locs
)
#notice these are invisible
for
k
in
ax2
.
set_yticklabels
(
tick_labels
):
k
.
set_color
(
'
green
'
)
k
.
set_alpha
(
0.3
)
k
.
set_size
(
8
)
k
.
set_size
(
8
)
# we should see some of axes 1 axes
ax1
.
spines
[
'
right
'
].
set_visible
(
False
)
ax1
.
spines
[
'
top
'
].
set_visible
(
False
)
ax1
.
spines
[
'
left
'
].
set_position
((
'
data
'
,
-
0.015
))
ax1
.
spines
[
'
bottom
'
].
set_position
((
'
data
'
,
-
0.015
))
ax1
.
spines
[
'
bottom
'
].
set_position
((
'
data
'
,
-
0.015
))
# we shouldn't see any of axes 2 axes
ax2
.
spines
[
'
right
'
].
set_visible
(
False
)
ax2
.
spines
[
'
top
'
].
set_visible
(
False
)
ax2
.
spines
[
'
left
'
].
set_visible
(
False
)
ax2
.
spines
[
'
bottom
'
].
set_visible
(
False
)
plt
.
tight_layout
()
return
fig
ax2
.
spines
[
'
bottom
'
].
set_visible
(
False
)
plt
.
tight_layout
()
return
fig
def
loss_curve
(
df
,
title
):
"""
Creates a loss curve given a Dataframe with column names:
``[
'
avg. loss
'
,
'
median loss
'
,
'
lr
'
,
'
max memory
'
]``
Parameters
----------
df : :py:class:`pandas.DataFrame`
Returns
-------
matplotlib.figure.Figure
"""
"""
import
matplotlib
matplotlib
.
use
(
'
agg
'
)
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
ax1
=
df
.
plot
(
y
=
"
median loss
"
,
grid
=
True
)
ax1
.
set_title
(
title
)
ax1
.
set_ylabel
(
'
median loss
'
)
...
...
@@ -241,7 +241,7 @@ def loss_curve(df, title):
ax2
=
df
[
'
lr
'
].
plot
(
secondary_y
=
True
,
legend
=
True
,
grid
=
True
,)
ax2
.
set_ylabel
(
'
lr
'
)
ax1
.
set_xlabel
(
'
epoch
'
)
plt
.
tight_layout
()
plt
.
tight_layout
()
fig
=
ax1
.
get_figure
()
return
fig
...
...
@@ -249,12 +249,12 @@ def loss_curve(df, title):
def
read_metricscsv
(
file
):
"""
Read precision and recall from csv file
Parameters
----------
file : str
path to file
Returns
-------
:py:class:`numpy.ndarray`
...
...
@@ -283,7 +283,7 @@ def read_metricscsv(file):
def
plot_overview
(
outputfolders
,
title
):
"""
Plots comparison chart of all trained models
Parameters
----------
outputfolder : list
...
...
@@ -303,7 +303,7 @@ def plot_overview(outputfolders,title):
names
=
[]
params
=
[]
for
folder
in
outputfolders
:
# metrics
# metrics
metrics_path
=
os
.
path
.
join
(
folder
,
'
results/Metrics.csv
'
)
pr
,
re
,
pr_upper
,
pr_lower
,
re_upper
,
re_lower
=
read_metricscsv
(
metrics_path
)
precisions
.
append
(
pr
)
...
...
@@ -335,7 +335,7 @@ def metricsviz(dataset
,
overlayed
=
True
):
"""
Visualizes true positives, false positives and false negatives
Default colors TP: Gray, FP: Cyan, FN: Orange
Parameters
----------
dataset : :py:class:`torch.utils.data.Dataset`
...
...
@@ -354,27 +354,27 @@ def metricsviz(dataset
name
=
sample
[
0
]
img
=
VF
.
to_pil_image
(
sample
[
1
])
# PIL Image
gt
=
sample
[
2
].
byte
()
# byte tensor
# read metrics
# read metrics
metrics
=
pd
.
read_csv
(
os
.
path
.
join
(
output_path
,
'
results
'
,
'
Metrics.csv
'
))
optimal_threshold
=
metrics
[
'
threshold
'
][
metrics
[
'
f1_score
'
].
idxmax
()]
# read probability output
# read probability output
pred
=
Image
.
open
(
os
.
path
.
join
(
output_path
,
'
images
'
,
name
))
pred
=
pred
.
convert
(
mode
=
'
L
'
)
pred
=
VF
.
to_tensor
(
pred
)
binary_pred
=
torch
.
gt
(
pred
,
optimal_threshold
).
byte
()
# calc metrics
# equals and not-equals
equals
=
torch
.
eq
(
binary_pred
,
gt
)
# tensor
notequals
=
torch
.
ne
(
binary_pred
,
gt
)
# tensor
# true positives
notequals
=
torch
.
ne
(
binary_pred
,
gt
)
# tensor
# true positives
tp_tensor
=
(
gt
*
binary_pred
)
# tensor
tp_pil
=
VF
.
to_pil_image
(
tp_tensor
.
float
())
tp_pil_colored
=
PIL
.
ImageOps
.
colorize
(
tp_pil
,
(
0
,
0
,
0
),
tp_color
)
# false positives
fp_tensor
=
torch
.
eq
((
binary_pred
+
tp_tensor
),
1
)
# false positives
fp_tensor
=
torch
.
eq
((
binary_pred
+
tp_tensor
),
1
)
fp_pil
=
VF
.
to_pil_image
(
fp_tensor
.
float
())
fp_pil_colored
=
PIL
.
ImageOps
.
colorize
(
fp_pil
,
(
0
,
0
,
0
),
fp_color
)
# false negatives
...
...
@@ -385,7 +385,7 @@ def metricsviz(dataset
# paste together
tp_pil_colored
.
paste
(
fp_pil_colored
,
mask
=
fp_pil
)
tp_pil_colored
.
paste
(
fn_pil_colored
,
mask
=
fn_pil
)
if
overlayed
:
tp_pil_colored
=
PIL
.
Image
.
blend
(
img
,
tp_pil_colored
,
0.4
)
img_metrics
=
pd
.
read_csv
(
os
.
path
.
join
(
output_path
,
'
results
'
,
name
+
'
.csv
'
))
...
...
@@ -396,15 +396,17 @@ def metricsviz(dataset
fnt
=
ImageFont
.
truetype
(
'
FreeMono.ttf
'
,
fnt_size
)
draw
.
text
((
0
,
0
),
"
F1: {:.4f}
"
.
format
(
f1
),(
255
,
255
,
255
),
font
=
fnt
)
# save to disk
# save to disk
overlayed_path
=
os
.
path
.
join
(
output_path
,
'
tpfnfpviz
'
)
if
not
os
.
path
.
exists
(
overlayed_path
):
os
.
makedirs
(
overlayed_path
)
tp_pil_colored
.
save
(
os
.
path
.
join
(
overlayed_path
,
name
))
fullpath
=
os
.
path
.
join
(
overlayed_path
,
name
)
fulldir
=
os
.
path
.
dirname
(
fullpath
)
if
not
os
.
path
.
exists
(
fulldir
):
os
.
makedirs
(
fulldir
)
tp_pil_colored
.
save
(
fullpath
)
def
overlay
(
dataset
,
output_path
):
"""
Overlays prediction probabilities vessel tree with original test image.
Parameters
----------
dataset : :py:class:`torch.utils.data.Dataset`
...
...
@@ -416,8 +418,8 @@ def overlay(dataset, output_path):
# get sample
name
=
sample
[
0
]
img
=
VF
.
to_pil_image
(
sample
[
1
])
# PIL Image
# read probability output
# read probability output
pred
=
Image
.
open
(
os
.
path
.
join
(
output_path
,
'
images
'
,
name
)).
convert
(
mode
=
'
L
'
)
# color and overlay
pred_green
=
PIL
.
ImageOps
.
colorize
(
pred
,
(
0
,
0
,
0
),
(
0
,
255
,
0
))
...
...
@@ -430,14 +432,16 @@ def overlay(dataset, output_path):
#draw.text((0, 0),"F1: {:.4f}".format(f1),(255,255,255),font=fnt)
# save to disk
overlayed_path
=
os
.
path
.
join
(
output_path
,
'
overlayed
'
)
if
not
os
.
path
.
exists
(
overlayed_path
):
os
.
makedirs
(
overlayed_path
)
overlayed
.
save
(
os
.
path
.
join
(
overlayed_path
,
name
))
fullpath
=
os
.
path
.
join
(
overlayed_path
,
name
)
fulldir
=
os
.
path
.
dirname
(
fullpath
)
if
not
os
.
path
.
exists
(
fulldir
):
os
.
makedirs
(
fulldir
)
overlayed
.
save
(
fullpath
)
def
savetransformedtest
(
dataset
,
output_path
):
"""
Save the test images as they are fed into the neural network.
"""
Save the test images as they are fed into the neural network.
Makes it easier to create overlay animations (e.g. slide)
Parameters
----------
dataset : :py:class:`torch.utils.data.Dataset`
...
...
@@ -449,8 +453,10 @@ def savetransformedtest(dataset, output_path):
# get sample
name
=
sample
[
0
]
img
=
VF
.
to_pil_image
(
sample
[
1
])
# PIL Image
# save to disk
testimg_path
=
os
.
path
.
join
(
output_path
,
'
transformedtestimages
'
)
if
not
os
.
path
.
exists
(
testimg_path
):
os
.
makedirs
(
testimg_path
)
img
.
save
(
os
.
path
.
join
(
testimg_path
,
name
))
fullpath
=
os
.
path
.
join
(
testimg_path
,
name
)
fulldir
=
os
.
path
.
dirname
(
fullpath
)
if
not
os
.
path
.
exists
(
fulldir
):
os
.
makedirs
(
fulldir
)
img
.
save
(
fullpath
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment