Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
9d279cac
Commit
9d279cac
authored
Jun 06, 2019
by
Amir MOHAMMADI
Browse files
Add a pixel wise loss
parent
1a2392d8
Changes
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/loss/pixel_wise.py
0 → 100644
View file @
9d279cac
from
..dataset
import
tf_repeat
from
.utils
import
(
balanced_softmax_cross_entropy_loss_weights
,
balanced_sigmoid_cross_entropy_loss_weights
,
)
import
tensorflow
as
tf
class PixelWise:
    """A pixel wise loss which is just a cross entropy loss but applied to all pixels.

    The incoming ``logits`` are flattened to one value per pixel and a
    cross-entropy loss (softmax when ``n_one_hot_labels`` is set, sigmoid
    otherwise) is averaged over every pixel of every sample.  Per-sample
    class-balancing weights can optionally be computed from ``labels`` and
    broadcast to all pixels.

    Attributes (set in ``__init__``):
      balance_weights   -- if truthy, per-sample weights are computed from the
                           labels via the project-local ``balanced_*_loss_weights``
                           helpers; otherwise a constant weight of 1.0 is used.
      n_one_hot_labels  -- number of classes when labels are one-hot encoded;
                           ``None``/falsy selects the sigmoid (binary) path.
      label_smoothing   -- forwarded to the TF1 ``tf.losses`` cross-entropy call.
    """

    def __init__(
        self,
        balance_weights=True,
        n_one_hot_labels=None,
        label_smoothing=0.5,
        **kwargs
    ):
        """Store the loss configuration.

        Parameters
        ----------
        balance_weights : bool
            Whether to weight each sample by the (inverse) label frequency
            computed by the ``balanced_*_loss_weights`` helpers.
        n_one_hot_labels : int or None
            Number of classes if ``labels`` are one-hot; ``None`` means the
            binary sigmoid cross-entropy path is taken in ``__call__``.
        label_smoothing : float
            Label-smoothing factor passed straight to the ``tf.losses``
            cross-entropy function.
        **kwargs
            Forwarded to ``super().__init__`` for cooperative multiple
            inheritance (this class itself only inherits from ``object``).
        """
        super(PixelWise, self).__init__(**kwargs)
        self.balance_weights = balance_weights
        self.n_one_hot_labels = n_one_hot_labels
        self.label_smoothing = label_smoothing

    def __call__(self, labels, logits):
        """Compute the mean pixel-wise cross-entropy loss.

        Parameters
        ----------
        labels : tf.Tensor
            Ground-truth labels.  NOTE(review): the required shape depends on
            the configuration — one-hot ``(batch, n_one_hot_labels)`` for the
            softmax path, per-sample binary labels for the sigmoid path;
            confirm against the callers.
        logits : tf.Tensor
            Raw network outputs; flattened here to ``(batch, n_pixels)``.

        Returns
        -------
        tf.Tensor
            Scalar loss (``Reduction.MEAN`` over all pixels), also logged via
            ``tf.summary.scalar("loss_pixel_wise", ...)`` as a side effect.
        """
        with tf.name_scope("PixelWiseLoss"):
            # Collapse all non-batch dimensions so each pixel becomes one
            # logit column: logits -> (batch, n_pixels).
            flatten = tf.keras.layers.Flatten()
            logits = flatten(logits)
            # Static pixel count taken from the flattened shape; requires the
            # non-batch dimensions to be statically known.
            n_pixels = logits.get_shape()[-1]
            weights = 1.0
            if self.balance_weights and self.n_one_hot_labels:
                # use labels to figure out the required loss
                weights = balanced_softmax_cross_entropy_loss_weights(
                    labels, dtype=logits.dtype
                )
                # repeat weights for all pixels
                weights = tf_repeat(weights[:, None], [1, n_pixels])
                # Flattened to 1-D because the softmax path below also reshapes
                # labels/logits to (-1, n_one_hot_labels), i.e. one row per pixel.
                weights = tf.reshape(weights, (-1,))
            elif self.balance_weights and not self.n_one_hot_labels:
                # use labels to figure out the required loss
                weights = balanced_sigmoid_cross_entropy_loss_weights(
                    labels, dtype=logits.dtype
                )
                # repeat weights for all pixels
                # NOTE(review): unlike the softmax branch, these weights are
                # kept 2-D — the sigmoid path keeps logits as (batch, n_pixels),
                # so the shapes line up without the extra flatten.
                weights = tf_repeat(weights[:, None], [1, n_pixels])
            if self.n_one_hot_labels:
                # Tile the per-sample one-hot labels once per pixel, then view
                # everything as one (pixel, class) row per pixel.
                labels = tf_repeat(labels, [n_pixels, 1])
                labels = tf.reshape(labels, (-1, self.n_one_hot_labels))
                # reshape logits too as softmax_cross_entropy is buggy and cannot really
                # handle higher dimensions
                logits = tf.reshape(logits, (-1, self.n_one_hot_labels))
                loss_fn = tf.losses.softmax_cross_entropy
            else:
                # Binary path: broadcast each sample's scalar label to all of
                # its pixels -> (batch, n_pixels), matching the flattened logits.
                labels = tf.reshape(labels, (-1, 1))
                labels = tf_repeat(labels, [n_pixels, 1])
                labels = tf.reshape(labels, (-1, n_pixels))
                loss_fn = tf.losses.sigmoid_cross_entropy
            # TF1 losses take the targets positionally (onehot_labels /
            # multi_class_labels respectively); both accept the same keyword
            # arguments used here.
            loss_pixel_wise = loss_fn(
                labels,
                logits=logits,
                weights=weights,
                label_smoothing=self.label_smoothing,
                reduction=tf.losses.Reduction.MEAN,
            )
            tf.summary.scalar("loss_pixel_wise", loss_pixel_wise)
            return loss_pixel_wise
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment