Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
mednet
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
medai
software
mednet
Commits
1f0e3938
Commit
1f0e3938
authored
10 months ago
by
Daniel CARRON
Browse files
Options
Downloads
Patches
Plain Diff
[model] Fix balancing of multiclass targets
parent
b32ebbe5
No related branches found
No related tags found
1 merge request
!38
Replace sampler balancing by loss balancing
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/mednet/models/loss_weights.py
+143
-49
143 additions, 49 deletions
src/mednet/models/loss_weights.py
src/mednet/models/model.py
+4
-3
4 additions, 3 deletions
src/mednet/models/model.py
with
147 additions
and
52 deletions
src/mednet/models/loss_weights.py
+
143
−
49
View file @
1f0e3938
...
@@ -3,86 +3,180 @@
...
@@ -3,86 +3,180 @@
# SPDX-License-Identifier: GPL-3.0-or-later
# SPDX-License-Identifier: GPL-3.0-or-later
import
logging
import
logging
import
typing
from
collections
import
Counter
import
torch
import
torch
import
torch.utils.data
import
torch.utils.data
from
..data.typing
import
DataLoader
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def
_get_label_weights
(
def
compute_binary_weights
(
targets
):
dataloader
:
torch
.
utils
.
data
.
DataLoader
,
"""
Compute the positive weights when using binary targets.
)
->
torch
.
Tensor
:
"""
Compute the weights of each class of a DataLoader.
This function inputs a pytorch DataLoader and computes the ratio between
Parameters
number of negative and positive samples (scalar). The weight can be used
----------
to adjust minimisation criteria to in cases there is a huge data imbalance.
targets
A tensor of integer values of length n.
It returns a vector with weights (inverse counts) for each label.
Returns
-------
The positive weights per class.
"""
class_sample_count
=
[
float
((
targets
==
t
).
sum
().
item
())
for
t
in
torch
.
unique
(
targets
,
sorted
=
True
)
]
# Divide negatives by positives
return
torch
.
tensor
(
[
class_sample_count
[
0
]
/
class_sample_count
[
1
]],
).
reshape
(
-
1
)
def
compute_multiclass_weights
(
targets
):
"""
Compute the positive weights when using exclusive, multiclass targets.
Parameters
Parameters
----------
----------
dataloader
targets
A DataLoader from which to compute the positive weights. Entries must
A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.
be a dictionary which must contain a ``label`` key.
Returns
Returns
-------
-------
torch.Tensor
The positive weights per class.
The positive weight of each class in the dataset given as input.
"""
"""
targets
=
torch
.
tensor
(
[
sample
for
batch
in
dataloader
for
sample
in
batch
[
1
][
"
label
"
]],
class_sample_count
=
torch
.
sum
(
targets
,
dim
=
1
)
negative_class_sample_count
=
(
torch
.
full
((
targets
.
size
()[
0
],),
float
(
targets
.
size
()[
1
]))
-
class_sample_count
)
)
# Binary labels
return
negative_class_sample_count
/
(
if
len
(
list
(
targets
.
shape
))
==
1
:
class_sample_count
+
negative_class_sample_count
class_sample_count
=
[
)
float
((
targets
==
t
).
sum
().
item
())
for
t
in
torch
.
unique
(
targets
,
sorted
=
True
)
]
# Divide negatives by positives
positive_weights
=
torch
.
tensor
(
[
class_sample_count
[
0
]
/
class_sample_count
[
1
]],
).
reshape
(
-
1
)
# Multiclass labels
def
compute_non_exclusive_multiclass_weights
(
targets
):
else
:
"""
Compute the positive weights when using non-exclusive, multiclass targets.
class_sample_count
=
torch
.
sum
(
targets
,
dim
=
0
)
negative_class_sample_count
=
(
torch
.
full
((
targets
.
size
()[
1
],),
float
(
targets
.
size
()[
0
]))
-
class_sample_count
)
positive_weights
=
negative_class_sample_count
/
(
Parameters
class_sample_count
+
negative_class_sample_count
----------
)
targets
A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.
return
positive_weights
Returns
-------
The positive weights per class.
"""
raise
ValueError
(
"
Computing weights of multi-class, non-exclusive labels is not yet supported.
"
)
def
make_balanced_bcewithlogitsloss
(
def
is_multicalss_exclusive
(
targets
:
torch
.
Tensor
)
->
bool
:
dataloader
:
DataLoader
,
"""
Given a [C x n] tensor of integer targets, checks whether samples can only belong to a single class.
)
->
torch
.
nn
.
BCEWithLogitsLoss
:
"""
Return a balanced binary-cross-entropy loss.
The loss is weighted using the ratio between positives and total examples
Parameters
available.
----------
targets
A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.
Returns
-------
True if all samples belong to a single class, False otherwise (a sample can belong to multiple classes).
"""
max_counts
=
[]
transposed_targets
=
torch
.
transpose
(
targets
,
0
,
1
)
for
t
in
transposed_targets
:
filtered_list
=
[
i
for
i
in
t
.
tolist
()
if
i
!=
2
]
counts
=
Counter
(
filtered_list
)
max_counts
.
append
(
max
(
counts
.
values
()))
if
set
(
max_counts
)
==
{
1
}:
return
True
return
False
def
tensor_to_list
(
tensor
)
->
list
[
typing
.
Any
]:
"""
Convert a torch.Tensor to a list.
This is necessary, as torch.tolist returns an int when then tensor contains a single value.
Parameters
----------
tensor
The tensor to convert to a list.
Returns
-------
The tensor converted to a list.
"""
tensor
=
tensor
.
tolist
()
if
isinstance
(
tensor
,
int
):
return
[
tensor
]
return
tensor
def
get_positive_weights
(
dataloader
:
torch
.
utils
.
data
.
DataLoader
,
)
->
torch
.
Tensor
:
"""
Compute the weights of each class of a DataLoader.
This function inputs a pytorch DataLoader and computes the ratio between
number of negative and positive samples (scalar). The weight can be used
to adjust minimisation criteria to in cases there is a huge data imbalance.
It returns a vector with weights (inverse counts) for each label.
Parameters
Parameters
----------
----------
dataloader
dataloader
The DataLoader to use to compute the BCE weights.
A DataLoader from which to compute the positive weights. Entries must
be a dictionary which must contain a ``label`` key.
Returns
Returns
-------
-------
torch.nn.BCEWithLogitsLoss
The positive weight of each class in the dataset given as input.
An instance of the weighted loss.
"""
"""
weights
=
_get_label_weights
(
dataloader
)
from
collections
import
defaultdict
return
torch
.
nn
.
BCEWithLogitsLoss
(
pos_weight
=
weights
)
targets
=
defaultdict
(
list
)
for
batch
in
dataloader
:
for
class_idx
,
class_targets
in
enumerate
(
batch
[
1
][
"
label
"
]):
# Targets are either a single tensor (binary case) or a list of tensors (multilabel)
if
isinstance
(
batch
[
1
][
"
label
"
],
list
):
targets
[
class_idx
].
extend
(
tensor_to_list
(
class_targets
))
else
:
targets
[
0
].
extend
(
tensor_to_list
(
class_targets
))
targets_list
=
[]
for
k
in
sorted
(
list
(
targets
.
keys
())):
targets_list
.
append
(
targets
[
k
])
targets_tensor
=
torch
.
tensor
(
targets_list
)
if
len
(
list
(
targets_tensor
.
shape
))
==
1
:
logger
.
info
(
"
Computing positive weights assuming binary labels.
"
)
positive_weights
=
compute_binary_weights
(
targets_tensor
)
else
:
if
is_multicalss_exclusive
(
targets_tensor
):
logger
.
info
(
"
Computing positive weights assuming multiclass, exclusive labels.
"
)
positive_weights
=
compute_multiclass_weights
(
targets_tensor
)
else
:
logger
.
info
(
"
Computing positive weights assuming multiclass, non-exclusive labels.
"
)
positive_weights
=
compute_non_exclusive_multiclass_weights
(
targets_tensor
)
return
positive_weights
This diff is collapsed.
Click to expand it.
src/mednet/models/model.py
+
4
−
3
View file @
1f0e3938
...
@@ -13,7 +13,7 @@ import torch.utils.data
...
@@ -13,7 +13,7 @@ import torch.utils.data
import
torchvision.transforms
import
torchvision.transforms
from
..data.typing
import
TransformSequence
from
..data.typing
import
TransformSequence
from
.loss_weights
import
_
get_
label
_weights
from
.loss_weights
import
get_
positive
_weights
from
.typing
import
Checkpoint
from
.typing
import
Checkpoint
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -152,6 +152,7 @@ class Model(pl.LightningModule):
...
@@ -152,6 +152,7 @@ class Model(pl.LightningModule):
datamodule
datamodule
Instance of a datamodule.
Instance of a datamodule.
"""
"""
logger
.
info
(
f
"
Balancing training loss function
{
self
.
_train_loss
}
.
"
)
logger
.
info
(
f
"
Balancing training loss function
{
self
.
_train_loss
}
.
"
)
try
:
try
:
getattr
(
self
.
_train_loss
,
"
pos_weight
"
)
getattr
(
self
.
_train_loss
,
"
pos_weight
"
)
...
@@ -160,7 +161,7 @@ class Model(pl.LightningModule):
...
@@ -160,7 +161,7 @@ class Model(pl.LightningModule):
"
Training loss does not posess a
'
pos_weight
'
attribute and will not be balanced.
"
"
Training loss does not posess a
'
pos_weight
'
attribute and will not be balanced.
"
)
)
else
:
else
:
train_weights
=
_
get_
label
_weights
(
datamodule
.
train_dataloader
())
train_weights
=
get_
positive
_weights
(
datamodule
.
train_dataloader
())
setattr
(
self
.
_train_loss
,
"
pos_weight
"
,
train_weights
)
setattr
(
self
.
_train_loss
,
"
pos_weight
"
,
train_weights
)
logger
.
info
(
logger
.
info
(
...
@@ -185,7 +186,7 @@ class Model(pl.LightningModule):
...
@@ -185,7 +186,7 @@ class Model(pl.LightningModule):
)
)
for
val_dataset_key
in
datamodule_validation_keys
:
for
val_dataset_key
in
datamodule_validation_keys
:
validation_weights
=
_
get_
label
_weights
(
validation_weights
=
get_
positive
_weights
(
datamodule
.
val_dataloader
()[
val_dataset_key
]
datamodule
.
val_dataloader
()[
val_dataset_key
]
)
)
new_validation_losses
.
append
(
new_validation_losses
.
append
(
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment