medai / software / mednet

Commit eea2f306, authored 1 year ago by André Anjos
[data.datamodule] Implements ConcatDataModule (closes #16); Streamline types (see #24)
parent 6f2383c8
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
Showing 2 changed files with 191 additions and 67 deletions:

  src/ptbench/data/datamodule.py: 174 additions, 64 deletions
  src/ptbench/data/typing.py: 17 additions, 3 deletions
src/ptbench/data/datamodule.py  +174 −64
@@ -4,6 +4,7 @@
 import collections
 import functools
+import itertools
 import logging
 import multiprocessing
 import sys
@@ -17,12 +18,14 @@ import torchvision.transforms
 import tqdm
 
 from .typing import (
+    ConcatDatabaseSplit,
     DatabaseSplit,
     DataLoader,
     Dataset,
     RawDataLoader,
     Sample,
     Transform,
+    TransformSequence,
 )
 
 logger = logging.getLogger(__name__)
@@ -72,9 +75,9 @@ class _DelayedLoadingDataset(Dataset):
     Parameters
     ----------
 
-    split
-        An iterable containing the raw dataset samples loaded from the database
-        splits.
+    raw_dataset
+        An iterable containing the raw dataset samples representing one of the
+        database split datasets.
 
     loader
         An object instance that can load samples and labels from storage.
@@ -86,11 +89,11 @@ class _DelayedLoadingDataset(Dataset):
     def __init__(
         self,
-        split: typing.Sequence[typing.Any],
+        raw_dataset: typing.Sequence[typing.Any],
         loader: RawDataLoader,
-        transforms: typing.Sequence[Transform] = [],
+        transforms: TransformSequence = [],
     ):
-        self.split = split
+        self.raw_dataset = raw_dataset
         self.loader = loader
         self.transform = torchvision.transforms.Compose(transforms)
@@ -105,14 +108,14 @@ class _DelayedLoadingDataset(Dataset):
     def labels(self) -> list[int]:
         """Returns the integer labels for all samples in the dataset."""
-        return [self.loader.label(k) for k in self.split]
+        return [self.loader.label(k) for k in self.raw_dataset]
 
     def __getitem__(self, key: int) -> Sample:
-        tensor, metadata = self.loader.sample(self.split[key])
+        tensor, metadata = self.loader.sample(self.raw_dataset[key])
         return self.transform(tensor), metadata
 
     def __len__(self):
-        return len(self.split)
+        return len(self.raw_dataset)
 
     def __iter__(self):
         for x in range(len(self)):
@@ -131,7 +134,7 @@ def _apply_loader_and_transforms(
     ----------
 
     info
-        The sample information, as loaded from its split dictionary
+        The sample information, as loaded from its raw dataset dictionary
 
     load
         The raw-data loader function to use for loading the sample
@@ -155,7 +158,7 @@ def _apply_loader_and_transforms(
 class _CachedDataset(Dataset):
     """Basically, a list of preloaded samples.
 
-    This dataset will load all samples from the split during construction
+    This dataset will load all samples from the raw dataset during construction
     instead of delaying that to the indexing. Beyong raw-data-loading,
     ``transforms`` given upon construction contribute to the cached samples.
@@ -163,9 +166,9 @@ class _CachedDataset(Dataset):
     Parameters
     ----------
 
-    split
-        An iterable containing the raw dataset samples loaded from the database
-        splits.
+    raw_dataset
+        An iterable containing the raw dataset samples representing one of the
+        database split datasets.
 
     loader
         An object instance that can load samples and labels from storage.
@@ -184,10 +187,10 @@ class _CachedDataset(Dataset):
     def __init__(
         self,
-        split: typing.Sequence[typing.Any],
+        raw_dataset: typing.Sequence[typing.Any],
         loader: RawDataLoader,
         parallel: int = -1,
-        transforms: typing.Sequence[Transform] = [],
+        transforms: TransformSequence = [],
     ):
         self.loader = functools.partial(
             _apply_loader_and_transforms,
@@ -197,14 +200,16 @@ class _CachedDataset(Dataset):
         if parallel < 0:
             self.data = [
-                self.loader(k) for k in tqdm.tqdm(split, unit="sample")
+                self.loader(k) for k in tqdm.tqdm(raw_dataset, unit="sample")
             ]
         else:
             instances = parallel or multiprocessing.cpu_count()
             logger.info(f"Caching dataset using {instances} processes...")
             with multiprocessing.Pool(instances) as p:
                 self.data = list(
-                    tqdm.tqdm(p.imap(self.loader, split), total=len(split))
+                    tqdm.tqdm(
+                        p.imap(self.loader, raw_dataset), total=len(raw_dataset)
+                    )
                 )
 
         # Estimates memory occupance
@@ -229,8 +234,42 @@ class _CachedDataset(Dataset):
         return len(self.data)
 
     def __iter__(self):
-        for x in range(len(self)):
-            yield self[x]
+        yield from self.data
+
+
+class _ConcatDataset(Dataset):
+    """A dataset that represents a concatenation of other cached or delayed
+    datasets.
+
+    Parameters
+    ----------
+
+    datasets
+        An iterable over pre-instantiated datasets.
+    """
+
+    def __init__(self, datasets: typing.Sequence[Dataset]):
+        self._datasets = datasets
+        self._indices = [
+            (i, j)  # dataset relative position, sample relative position
+            for i in range(len(datasets))
+            for j in range(len(datasets[i]))
+        ]
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        return list(itertools.chain(*[k.labels() for k in self._datasets]))
+
+    def __getitem__(self, key: int) -> Sample:
+        i, j = self._indices[key]
+        return self._datasets[i][j]
+
+    def __len__(self):
+        return sum([len(k) for k in self._datasets])
+
+    def __iter__(self):
+        for dataset in self._datasets:
+            yield from dataset
+
+
 def _make_balanced_random_sampler(
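Note on the new _ConcatDataset above: it pre-computes a flat index at construction time so that an integer key resolves to a (dataset, sample) pair. A small illustrative sketch of that bookkeeping, not part of the diff, assuming two member datasets of lengths 3 and 2:

# illustrative only: reproduce _ConcatDataset's flat-index bookkeeping for two
# hypothetical member datasets with 3 and 2 samples, respectively
lengths = [3, 2]
indices = [(i, j) for i, n in enumerate(lengths) for j in range(n)]
print(indices)     # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
print(indices[4])  # (1, 1): flat key 4 maps to sample 1 of dataset 1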
@@ -375,14 +414,15 @@ def _make_balanced_random_sampler(
     )
 
 
-class CachingDataModule(lightning.LightningDataModule):
-    """A conveninent data module with CSV or JSON protocol loading, mini-
+class ConcatDataModule(lightning.LightningDataModule):
+    """A conveninent data module with dictionary split loading, mini-
     batching,
     parallelisation and caching, all in one.
 
-    Instances of this class load data-split (a.k.a. protocol) definitions for a
-    database, and can load the data from the disk. An optional caching
-    mechanism stores the data at associated CPU memory, which can improve data
-    serving while training and evaluating models.
+    Instances of this class can load and concatenate an arbitrary number of
+    data-split (a.k.a. protocol) definitions for (possibly disjoint) databases,
+    and can manage raw data-loading from disk. An optional caching mechanism
+    stores the data at associated CPU memory, which can improve data serving
+    while training and evaluating models.
 
     This datamodule defines basic operations to handle data loading and
     mini-batch handling within this package's framework. It can return
@@ -390,31 +430,32 @@ class CachingDataModule(lightning.LightningDataModule):
     prediction and testing conditions. Parallelisation is handled by a simple
     input flag.
 
-    Users must implement the basic :py:meth:`setup` function, which is
-    parameterised by a single string enumeration containing: ``fit``,
-    ``validate``, ``test``, or ``predict``.
-
     Parameters
     ----------
 
-    database_split
-        A dictionary that contains string keys representing subset names, and
-        values that are iterables over sample representations (potentially on
-        disk). These objects are passed to the ``sample_loader`` for loading
-        the sample data (and metadata) in memory. The objects represented may
-        be of any format (e.g. list, dictionary, etc), for as long as the
-        ``sample_loader`` can properly handle it. To check the split and the
-        loader function works correctly, you may use
-        :py:func:`..dataset.check_database_split_loading`. As is, this class
-        expects at least one entry called ``train`` to exist in the input
-        dictionary. Optional entries are ``validation``, and ``test``. Entries
-        named ``monitor-...`` will be considered extra subsets that do not
-        influence any early stop criteria during training, and are just
-        monitored beyond the ``validation`` dataset.
-
-    loader
-        An object instance that can load samples and labels from storage.
+    splits
+        A dictionary that contains string keys representing dataset names, and
+        values that are iterables over a 2-tuple containing an iterable over
+        arbitrary, user-configurable sample representations (potentially on
+        disk or permanent storage), and :py:class:`RawDataLoader` (or "sample")
+        loader objects, which concretely implement a mechanism to load such
+        samples in memory, from permanent storage.
+
+        Sample representations on permanent storage may be of any iterable
+        format (e.g. list, dictionary, etc.), for as long as the assigned
+        :py:class:`RawDataLoader` can properly handle it.
+
+        .. tip::
+
+           To check the split and the loader function works correctly, you may
+           use :py:func:`..dataset.check_database_split_loading`.
+
+        This class expects at least one entry called ``train`` to exist in the
+        input dictionary. Optional entries are ``validation``, and ``test``.
+        Entries named ``monitor-...`` will be considered extra datasets that do
+        not influence any early stop criteria during training, and are just
+        monitored beyond the ``validation`` dataset.
 
     cache_samples
         If set, then issue raw data loading during ``prepare_data()``, and
@@ -486,8 +527,7 @@ class CachingDataModule(lightning.LightningDataModule):
     def __init__(
         self,
-        database_split: DatabaseSplit,
-        raw_data_loader: RawDataLoader,
+        splits: ConcatDatabaseSplit,
         cache_samples: bool = False,
         balance_sampler_by_class: bool = False,
         batch_size: int = 1,
@@ -499,8 +539,7 @@ class CachingDataModule(lightning.LightningDataModule):
         self.set_chunk_size(batch_size, batch_chunk_count)
 
-        self.database_split = database_split
-        self.raw_data_loader = raw_data_loader
+        self.splits = splits
         self.cache_samples = cache_samples
         self._train_sampler = None
         self.balance_sampler_by_class = balance_sampler_by_class
@@ -581,6 +620,14 @@ class CachingDataModule(lightning.LightningDataModule):
             If set, then modifies the random sampler used during training
             and validation to balance sample picking probability, making
             sample across classes **and** datasets equitable.
+
+            .. warning::
+
+               This method does **NOT** balance the sampler per dataset, in case
+               multiple datasets compose the same training set. It only balances
+               samples acording to their ground-truth (labels). If you'd like to
+               have samples balanced per dataset, then implement your own data
+               module inheriting from this one.
         """
         return self._train_sampler is not None
@@ -661,32 +708,45 @@ class CachingDataModule(lightning.LightningDataModule):
                 f"Not re-instantiating it."
             )
             return
 
+        datasets: list[_CachedDataset | _DelayedLoadingDataset] = []
+
         if self.cache_samples:
             logger.info(
                 f"Loading dataset:`{name}` into memory (caching)."
                 f" Trade-off: CPU RAM: more | Disk: less"
             )
-            self._datasets[name] = _CachedDataset(
-                self.database_split[name],
-                self.raw_data_loader,
-                self.parallel,
-                self.model_transforms,
-            )
+            for split, loader in self.splits[name]:
+                datasets.append(
+                    _CachedDataset(
+                        split,
+                        loader,
+                        self.parallel,
+                        self.model_transforms,
+                    )
+                )
         else:
             logger.info(
                 f"Loading dataset:`{name}` without caching."
                 f" Trade-off: CPU RAM: less | Disk: more"
             )
-            self._datasets[name] = _DelayedLoadingDataset(
-                self.database_split[name],
-                self.raw_data_loader,
-                self.model_transforms,
-            )
+            for split, loader in self.splits[name]:
+                datasets.append(
+                    _DelayedLoadingDataset(
+                        split,
+                        loader,
+                        self.model_transforms,
+                    )
+                )
+
+        if len(datasets) == 1:
+            self._datasets[name] = datasets[0]
+        else:
+            self._datasets[name] = _ConcatDataset(datasets)
 
     def _val_dataset_keys(self) -> list[str]:
         """Returns list of validation dataset names."""
         return ["validation"] + [
-            k for k in self.database_split.keys() if k.startswith("monitor-")
+            k for k in self.splits.keys() if k.startswith("monitor-")
         ]
 
     def setup(self, stage: str) -> None:
@@ -727,7 +787,7 @@ class CachingDataModule(lightning.LightningDataModule):
             self._setup_dataset("test")
         elif stage == "predict":
-            for k in self.database_split:
+            for k in self.splits:
                 self._setup_dataset(k)
 
     def teardown(self, stage: str) -> None:
@@ -826,3 +886,53 @@ class CachingDataModule(lightning.LightningDataModule):
             )
             for k in self._datasets
         }
+
+
+class CachingDataModule(ConcatDataModule):
+    """A simplified version of our data module for a single split.
+
+    Apart from construction, the behaviour of this data module is very similar
+    to its simpler counterpart, serving training, validation and test sets.
+
+    Parameters
+    ----------
+
+    database_split
+        A dictionary that contains string keys representing dataset names, and
+        values that are iterables over sample representations (potentially on
+        disk). These objects are passed to an unique :py:class:`RawDataLoader`
+        for loading the :py:class:`Sample` data (and metadata) in memory. It
+        therefore assumes the whole split is homogeneous and can be loaded in
+        the same way.
+
+        .. tip::
+
+           To check the split and the loader function works correctly, you may
+           use :py:func:`..dataset.check_database_split_loading`.
+
+        This class expects at least one entry called ``train`` to exist in the
+        input dictionary. Optional entries are ``validation``, and ``test``.
+        Entries named ``monitor-...`` will be considered extra datasets that do
+        not influence any early stop criteria during training, and are just
+        monitored beyond the ``validation`` dataset.
+
+    raw_data_loader
+        An object instance that can load samples and labels from storage.
+
+    **kwargs
+        List if named parameters matching those of
+        :py:class:`ConcatDataModule`, other than ``splits``.
+    """
+
+    def __init__(
+        self,
+        database_split: DatabaseSplit,
+        raw_data_loader: RawDataLoader,
+        **kwargs,
+    ):
+        splits = {k: [(v, raw_data_loader)] for k, v in database_split.items()}
+        super().__init__(
+            splits=splits,
+            **kwargs,
+        )
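For orientation, a hedged usage sketch of the two data modules introduced or reworked above, not part of this commit: ExampleLoader, split_a and split_b are invented stand-ins; only ConcatDataModule, CachingDataModule, RawDataLoader and Sample come from the package, and the keyword arguments shown are limited to those visible in the diff.

# hypothetical usage sketch (not part of this commit); ExampleLoader, split_a
# and split_b are invented stand-ins
import torch

from ptbench.data.datamodule import CachingDataModule, ConcatDataModule
from ptbench.data.typing import RawDataLoader, Sample


class ExampleLoader(RawDataLoader):
    """Toy loader treating each sample representation as a (name, label) pair."""

    def sample(self, item) -> Sample:
        name, label = item
        # a real loader would read `name` from disk; we fabricate a 1x8x8 tensor
        return torch.zeros(1, 8, 8), dict(name=name, label=label)

    def label(self, item) -> int:
        return item[1]


split_a = {"train": [("a/1.png", 0), ("a/2.png", 1)], "validation": [("a/3.png", 0)]}
split_b = {"train": [("b/1.png", 1), ("b/2.png", 0)]}

# concatenate the training subsets of two databases, each with its own loader
datamodule = ConcatDataModule(
    splits={
        "train": [
            (split_a["train"], ExampleLoader()),
            (split_b["train"], ExampleLoader()),
        ],
        "validation": [(split_a["validation"], ExampleLoader())],
    },
    cache_samples=False,
    batch_size=2,
)

# the single-split case keeps the previous interface and is wrapped internally
simple = CachingDataModule(database_split=split_a, raw_data_loader=ExampleLoader())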
src/ptbench/data/typing.py  +17 −3
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 """Defines most common types used in code."""
 
 import collections.abc
 
@@ -51,8 +50,23 @@ TransformSequence = typing.Sequence[Transform]
 
 DatabaseSplit = collections.abc.Mapping[str, typing.Sequence[typing.Any]]
 """The definition of a database split.
 
-A database split maps subset names to sequences of objects that, through
-RawDataLoader's eventually become Samples in the processing pipeline.
+A database split maps dataset (subset) names to sequences of objects
+that, through :py:class:`RawDataLoader`s, eventually become
+:py:class:`Sample`s in the processing pipeline.
 """
+
+ConcatDatabaseSplit = collections.abc.Mapping[
+    str,
+    typing.Sequence[tuple[typing.Sequence[typing.Any], RawDataLoader]],
+]
+"""The definition of a complex database split composed of several other splits.
+
+A database split maps dataset (subset) names to sequences of objects
+that, through :py:class:`RawDataLoader`s, eventually become
+:py:class:`Sample`s in the processing pipeline. Objects of this subtype
+allow the construction of complex splits composed of cannibalized parts
+of other splits. Each split may be assigned a different
+:py:class:`RawDataLoader`.
+"""
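To visualise the two aliases side by side, a minimal sketch, not part of the diff; the bare RawDataLoader() below is only a placeholder for a concrete loader subclass, and the sample entries are invented.

# illustrative only: minimal values matching the aliases defined above
from ptbench.data.typing import ConcatDatabaseSplit, DatabaseSplit, RawDataLoader

plain: DatabaseSplit = {
    "train": ["img-001.png", "img-002.png"],
    "test": ["img-003.png"],
}

loader = RawDataLoader()  # placeholder; real code would use a concrete subclass

concat: ConcatDatabaseSplit = {
    "train": [(plain["train"], loader)],  # one (samples, loader) pair per source
    "test": [(plain["test"], loader)],
}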