Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
mednet
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
medai
software
mednet
Commits
1f0e3938
Commit
1f0e3938
authored
10 months ago
by
Daniel CARRON
Browse files
Options
Downloads
Patches
Plain Diff
[model] Fix balancing of multiclass targets
parent
b32ebbe5
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!38
Replace sampler balancing by loss balancing
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/mednet/models/loss_weights.py
+143
-49
143 additions, 49 deletions
src/mednet/models/loss_weights.py
src/mednet/models/model.py
+4
-3
4 additions, 3 deletions
src/mednet/models/model.py
with
147 additions
and
52 deletions
src/mednet/models/loss_weights.py
+
143
−
49
View file @
1f0e3938
...
...
@@ -3,86 +3,180 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import
logging
import
typing
from
collections
import
Counter
import
torch
import
torch.utils.data
from
..data.typing
import
DataLoader
logger
=
logging
.
getLogger
(
__name__
)
def
_get_label_weights
(
dataloader
:
torch
.
utils
.
data
.
DataLoader
,
)
->
torch
.
Tensor
:
"""
Compute the weights of each class of a DataLoader.
def compute_binary_weights(targets):
    """
    Compute the positive weights when using binary targets.

    The weight is the ratio between the number of negative and positive
    samples. It can be used to adjust minimisation criteria in cases where
    there is a huge data imbalance.

    Parameters
    ----------
    targets
        A tensor of integer values of length n.

    Returns
    -------
        The positive weights per class, as a tensor holding a single value.

    Raises
    ------
    ValueError
        If ``targets`` contains fewer than two distinct values, in which case
        the negative/positive ratio is undefined.
    """

    # One pass over the data: ``counts[i]`` is the number of occurrences of the
    # i-th smallest distinct value found in ``targets``.
    _, counts = torch.unique(targets, sorted=True, return_counts=True)
    counts = counts.tolist()

    if len(counts) < 2:
        raise ValueError(
            "Cannot compute binary positive weights: targets must contain "
            f"both negative and positive samples, but only {len(counts)} "
            "distinct value(s) were found."
        )

    # The smallest value is taken as the negative class and the next one as
    # the positive class, so this divides negatives by positives.
    negatives = float(counts[0])
    positives = float(counts[1])

    return torch.tensor([negatives / positives]).reshape(-1)
def
compute_multiclass_weights
(
targets
):
"""
Compute the positive weights when using exclusive, multiclass targets.
Parameters
----------
dataloader
A DataLoader from which to compute the positive weights. Entries must
be a dictionary which must contain a ``label`` key.
targets
A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.
Returns
-------
torch.Tensor
The positive weight of each class in the dataset given as input.
The positive weights per class.
"""
targets
=
torch
.
tensor
(
[
sample
for
batch
in
dataloader
for
sample
in
batch
[
1
][
"
label
"
]],
class_sample_count
=
torch
.
sum
(
targets
,
dim
=
1
)
negative_class_sample_count
=
(
torch
.
full
((
targets
.
size
()[
0
],),
float
(
targets
.
size
()[
1
]))
-
class_sample_count
)
# Binary labels
if
len
(
list
(
targets
.
shape
))
==
1
:
class_sample_count
=
[
float
((
targets
==
t
).
sum
().
item
())
for
t
in
torch
.
unique
(
targets
,
sorted
=
True
)
]
return
negative_class_sample_count
/
(
class_sample_count
+
negative_class_sample_count
)
# Divide negatives by positives
positive_weights
=
torch
.
tensor
(
[
class_sample_count
[
0
]
/
class_sample_count
[
1
]],
).
reshape
(
-
1
)
# Multiclass labels
else
:
class_sample_count
=
torch
.
sum
(
targets
,
dim
=
0
)
negative_class_sample_count
=
(
torch
.
full
((
targets
.
size
()[
1
],),
float
(
targets
.
size
()[
0
]))
-
class_sample_count
)
def
compute_non_exclusive_multiclass_weights
(
targets
):
"""
Compute the positive weights when using non-exclusive, multiclass targets.
positive_weights
=
negative_class_sample_count
/
(
class_sample_count
+
negative_class_sample_count
)
Parameters
----------
targets
A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.
return
positive_weights
Returns
-------
The positive weights per class.
"""
raise
ValueError
(
"
Computing weights of multi-class, non-exclusive labels is not yet supported.
"
)
def
make_balanced_bcewithlogitsloss
(
dataloader
:
DataLoader
,
)
->
torch
.
nn
.
BCEWithLogitsLoss
:
"""
Return a balanced binary-cross-entropy loss.
def is_multicalss_exclusive(targets: torch.Tensor) -> bool:
    """
    Given a [C x n] tensor of integer targets, checks whether samples can only belong to a single class.

    Each sample (a column of ``targets``) is inspected on its own: entries
    equal to 2 are discarded, the occurrences of every remaining value are
    counted, and the highest count is recorded. The targets are reported as
    exclusive only when that highest count is 1 for every sample.

    Parameters
    ----------
    targets
        A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.

    Returns
    -------
        True if all samples belong to a single class, False otherwise (a sample can belong to multiple classes).

    Raises
    ------
    ValueError
        If every entry of a sample is equal to 2 (nothing is left to count).
    """

    # NOTE(review): because occurrences of *every* value are counted (zeros
    # included), a one-hot sample with three or more classes repeats the value
    # 0 and is therefore reported as non-exclusive. Presumably only the
    # positive entries were meant to be counted; confirm the intent and the
    # meaning of the value 2 before changing this.
    samples = torch.transpose(targets, 0, 1).tolist()

    highest_frequencies = set()
    for sample in samples:
        frequencies = Counter(value for value in sample if value != 2)
        highest_frequencies.add(max(frequencies.values()))

    return highest_frequencies == {1}
def tensor_to_list(tensor) -> list[typing.Any]:
    """
    Convert a torch.Tensor to a list.

    This is necessary, as torch.tolist returns a bare Python scalar (int,
    float or bool) instead of a list when the tensor contains a single,
    zero-dimensional value.

    Parameters
    ----------
    tensor
        The tensor to convert to a list.

    Returns
    -------
        The tensor converted to a list.
    """

    values = tensor.tolist()

    # Wrap any scalar, not only ints: a zero-dimensional float tensor would
    # otherwise be returned as a bare float and break callers that extend a
    # list with the result.
    if not isinstance(values, list):
        return [values]

    return values
def
get_positive_weights
(
dataloader
:
torch
.
utils
.
data
.
DataLoader
,
)
->
torch
.
Tensor
:
"""
Compute the weights of each class of a DataLoader.
This function inputs a pytorch DataLoader and computes the ratio between
number of negative and positive samples (scalar). The weight can be used
to adjust minimisation criteria in cases where there is a huge data imbalance.
It returns a vector with weights (inverse counts) for each label.
Parameters
----------
dataloader
The DataLoader to use to compute the BCE weights.
A DataLoader from which to compute the positive weights. Entries must
be a dictionary which must contain a ``label`` key.
Returns
-------
torch.nn.BCEWithLogitsLoss
An instance of the weighted loss.
The positive weight of each class in the dataset given as input.
"""
weights
=
_get_label_weights
(
dataloader
)
return
torch
.
nn
.
BCEWithLogitsLoss
(
pos_weight
=
weights
)
from
collections
import
defaultdict
targets
=
defaultdict
(
list
)
for
batch
in
dataloader
:
for
class_idx
,
class_targets
in
enumerate
(
batch
[
1
][
"
label
"
]):
# Targets are either a single tensor (binary case) or a list of tensors (multilabel)
if
isinstance
(
batch
[
1
][
"
label
"
],
list
):
targets
[
class_idx
].
extend
(
tensor_to_list
(
class_targets
))
else
:
targets
[
0
].
extend
(
tensor_to_list
(
class_targets
))
targets_list
=
[]
for
k
in
sorted
(
list
(
targets
.
keys
())):
targets_list
.
append
(
targets
[
k
])
targets_tensor
=
torch
.
tensor
(
targets_list
)
if
len
(
list
(
targets_tensor
.
shape
))
==
1
:
logger
.
info
(
"
Computing positive weights assuming binary labels.
"
)
positive_weights
=
compute_binary_weights
(
targets_tensor
)
else
:
if
is_multicalss_exclusive
(
targets_tensor
):
logger
.
info
(
"
Computing positive weights assuming multiclass, exclusive labels.
"
)
positive_weights
=
compute_multiclass_weights
(
targets_tensor
)
else
:
logger
.
info
(
"
Computing positive weights assuming multiclass, non-exclusive labels.
"
)
positive_weights
=
compute_non_exclusive_multiclass_weights
(
targets_tensor
)
return
positive_weights
This diff is collapsed.
Click to expand it.
src/mednet/models/model.py
+
4
−
3
View file @
1f0e3938
...
...
@@ -13,7 +13,7 @@ import torch.utils.data
import
torchvision.transforms
from
..data.typing
import
TransformSequence
from
.loss_weights
import
_
get_
label
_weights
from
.loss_weights
import
get_
positive
_weights
from
.typing
import
Checkpoint
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -152,6 +152,7 @@ class Model(pl.LightningModule):
datamodule
Instance of a datamodule.
"""
logger
.
info
(
f
"
Balancing training loss function
{
self
.
_train_loss
}
.
"
)
try
:
getattr
(
self
.
_train_loss
,
"
pos_weight
"
)
...
...
@@ -160,7 +161,7 @@ class Model(pl.LightningModule):
"
Training loss does not posess a
'
pos_weight
'
attribute and will not be balanced.
"
)
else
:
train_weights
=
_
get_
label
_weights
(
datamodule
.
train_dataloader
())
train_weights
=
get_
positive
_weights
(
datamodule
.
train_dataloader
())
setattr
(
self
.
_train_loss
,
"
pos_weight
"
,
train_weights
)
logger
.
info
(
...
...
@@ -185,7 +186,7 @@ class Model(pl.LightningModule):
)
for
val_dataset_key
in
datamodule_validation_keys
:
validation_weights
=
_
get_
label
_weights
(
validation_weights
=
get_
positive
_weights
(
datamodule
.
val_dataloader
()[
val_dataset_key
]
)
new_validation_losses
.
append
(
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment