Project: medai / software / mednet

Commit 993600d9, authored 1 year ago by André Anjos

[tests] Rewrite tests for hiv-tb and tb-poc

Parent: d0743428
No related branches or tags contain this commit.
Merge request: !6 "Making use of LightningDataModule and simplification of data loading"
Pipeline #76736: failed, 1 year ago (stages: qa, test, doc, dist)
Showing 3 changed files with 161 additions and 229 deletions:

- tests/conftest.py: 3 additions, 1 deletion
- tests/test_hivtb.py: 74 additions, 112 deletions
- tests/test_tbpoc.py: 84 additions, 116 deletions
tests/conftest.py (+3, -1)
@@ -162,7 +162,9 @@ class DatabaseCheckers:
            assert len(split[k]) == lengths[k]

            for s in split[k]:
-               assert any([s[0].startswith(k) for k in prefixes])
+               assert any(
+                   [s[0].startswith(k) for k in prefixes]
+               ), f"Sample with name {s[0]} does not start with any of the prefixes in {prefixes}"
                assert s[1] in possible_labels

    @staticmethod
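The rewritten tests delegate their checks to a shared `database_checkers` fixture defined in tests/conftest.py; only the small hunk above is part of this commit. The following is a minimal sketch of how that fixture and its two helpers could fit together, inferred from the hunk and from the helper functions removed below. The `check_loaded_batch` signature and the fixture wiring are assumptions for illustration, not code from this commit:

```python
# Hypothetical sketch of tests/conftest.py; only the check_split hunk above is
# visible in this commit.  The rest is inferred from how the new tests call it.
import pytest
import torch


class DatabaseCheckers:
    """Helpers shared by the per-database test modules."""

    @staticmethod
    def check_split(split, lengths, prefixes, possible_labels):
        # Split must declare exactly the expected datasets, with the expected
        # sizes, and every sample name must match one of the known prefixes.
        assert len(split) == len(lengths)
        for k in lengths.keys():
            assert k in split
            assert len(split[k]) == lengths[k]
            for s in split[k]:
                assert any(
                    [s[0].startswith(k) for k in prefixes]
                ), f"Sample with name {s[0]} does not start with any of the prefixes in {prefixes}"
                assert s[1] in possible_labels

    @staticmethod
    def check_loaded_batch(batch, batch_size, color_planes, prefixes, possible_labels):
        # Mirrors the checks of the removed _check_loaded_batch helpers.
        assert len(batch) == 2  # data, metadata
        assert isinstance(batch[0], torch.Tensor)
        assert batch[0].shape[0] == batch_size  # mini-batch size
        assert batch[0].shape[1] == color_planes  # color planes (1 = grayscale)
        assert batch[0].shape[2] == batch[0].shape[3]  # image is square
        assert isinstance(batch[1], dict)  # metadata
        assert "label" in batch[1]
        assert all([k in possible_labels for k in batch[1]["label"]])
        assert "name" in batch[1]
        assert all(
            [any([k.startswith(p) for p in prefixes]) for k in batch[1]["name"]]
        )


@pytest.fixture
def database_checkers():
    # Exposes the helper class to tests that request the fixture by name.
    return DatabaseCheckers
```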
tests/test_hivtb.py (+74, -112)
@@ -3,127 +3,89 @@

Removed (previous version of the tests):

# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for HIV-TB dataset."""

import pytest
import torch

from ptbench.data.hivtb.datamodule import make_split


def _check_split(
    split_filename: str,
    lengths: dict[str, int],
    prefix: str = "HIV-TB_Algorithm_study_X-rays/",
    extension: str = ".BMP",
    possible_labels: list[int] = [0, 1],
):
    """Runs a simple consistence check on the data split.

    Parameters
    ----------

    split_filename
        This is the split we will check
    lenghts
        A dictionary that contains keys matching those of the split (this will
        be checked). The values of the dictionary should correspond to the
        sizes of each of the datasets in the split.
    prefix
        Each file named in a split should start with this prefix.
    extension
        Each file named in a split should end with this extension.
    possible_labels
        These are the list of possible labels contained in any split.
    """
    split = make_split(split_filename)
    assert len(split) == len(lengths)
    for k in lengths.keys():
        # dataset must have been declared
        assert k in split
        assert len(split[k]) == lengths[k]
        for s in split[k]:
            assert s[0].startswith(prefix)
            assert s[0].endswith(extension)
            assert s[1] in possible_labels


def _check_loaded_batch(
    batch,
    size: int = 1,
    prefix: str = "HIV-TB_Algorithm_study_X-rays/",
    extension: str = ".BMP",
    possible_labels: list[int] = [0, 1],
):
    """Checks the consistence of an individual (loaded) batch.

    Parameters
    ----------

    batch
        The loaded batch to be checked.
    prefix
        Each file named in a split should start with this prefix.
    extension
        Each file named in a split should end with this extension.
    possible_labels
        These are the list of possible labels contained in any split.
    """
    assert len(batch) == 2  # data, metadata
    assert isinstance(batch[0], torch.Tensor)
    assert batch[0].shape[0] == size  # mini-batch size
    assert batch[0].shape[1] == 1  # grayscale images
    assert batch[0].shape[2] == batch[0].shape[3]  # image is square
    assert isinstance(batch[1], dict)  # metadata
    assert len(batch[1]) == 2  # label and name
    assert "label" in batch[1]
    assert all([k in possible_labels for k in batch[1]["label"]])
    assert "name" in batch[1]
    assert all([k.startswith(prefix) for k in batch[1]["name"]])
    assert all([k.endswith(extension) for k in batch[1]["name"]])


def test_protocol_consistency():
    # Cross-validation fold 0-2
    for k in range(3):
        _check_split(
            f"fold-{k}.json",
            lengths=dict(train=174, validation=44, test=25),
        )

    # Cross-validation fold 3-9
    for k in range(3, 10):
        _check_split(
            f"fold-{k}.json",
            lengths=dict(train=175, validation=44, test=24),
        )


@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
def test_loading():
    from ptbench.data.hivtb.fold_0 import datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    for loader in datamodule.predict_dataloader().values():
        limit = 5  # limit load checking
        for batch in loader:
            if limit == 0:
                break
            _check_loaded_batch(batch)
            limit -= 1


Added (rewritten tests):

# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for HIV-TB dataset."""

import importlib

import pytest


def id_function(val):
    if isinstance(val, dict):
        return str(val)
    return repr(val)


@pytest.mark.parametrize(
    "split,lenghts",
    [
        ("fold-0", dict(train=174, validation=44, test=25)),
        ("fold-1", dict(train=174, validation=44, test=25)),
        ("fold-2", dict(train=174, validation=44, test=25)),
        ("fold-3", dict(train=175, validation=44, test=24)),
        ("fold-4", dict(train=175, validation=44, test=24)),
        ("fold-5", dict(train=175, validation=44, test=24)),
        ("fold-6", dict(train=175, validation=44, test=24)),
        ("fold-7", dict(train=175, validation=44, test=24)),
        ("fold-8", dict(train=175, validation=44, test=24)),
        ("fold-9", dict(train=175, validation=44, test=24)),
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers, split: str, lenghts: dict[str, int]
):
    from ptbench.data.hivtb.datamodule import make_split

    database_checkers.check_split(
        make_split(f"{split}.json"),
        lengths=lenghts,
        prefixes=("HIV-TB_Algorithm_study_X-rays",),
        possible_labels=(0, 1),
    )


@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
@pytest.mark.parametrize(
    "dataset",
    [
        "train",
        "validation",
        "test",
    ],
)
@pytest.mark.parametrize(
    "name",
    [
        "fold_0",
        "fold_1",
        "fold_2",
        "fold_3",
        "fold_4",
        "fold_5",
        "fold_6",
        "fold_7",
        "fold_8",
        "fold_9",
    ],
)
def test_loading(database_checkers, name: str, dataset: str):
    datamodule = importlib.import_module(
        f".{name}", "ptbench.data.hivtb"
    ).datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    loader = datamodule.predict_dataloader()[dataset]

    limit = 3  # limit load checking
    for batch in loader:
        if limit == 0:
            break
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=1,
            prefixes=("HIV-TB_Algorithm_study_X-rays",),
            possible_labels=(0, 1),
        )
        limit -= 1
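Both `test_loading` functions are guarded by a custom `skip_if_rc_var_not_set` marker, whose implementation is not part of this commit. Below is a hypothetical sketch of how such a marker might be registered and honoured from `tests/conftest.py`, assuming a `load_rc()` helper that returns the project's runtime configuration as a mapping; both the hook wiring and `load_rc` are assumptions for illustration:

```python
# Hypothetical sketch; the real marker is defined elsewhere in the project.
import pytest


def load_rc() -> dict:
    """Stand-in for the project's real RC loader (assumption for this sketch)."""
    return {}


def pytest_configure(config):
    # Register the marker so pytest does not warn about an unknown mark.
    config.addinivalue_line(
        "markers",
        "skip_if_rc_var_not_set(name): skip test if the named RC variable is not set",
    )


def pytest_runtest_setup(item):
    # Skip any test carrying the marker when the RC variable is missing,
    # e.g. when "datadir.hivtb" or "datadir.tbpoc" is not configured locally.
    rc = load_rc()
    for mark in item.iter_markers(name="skip_if_rc_var_not_set"):
        name = mark.args[0]
        if not rc.get(name):
            pytest.skip(f"RC variable '{name}' is not set")
```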
tests/test_tbpoc.py (+84, -116)
@@ -3,127 +3,95 @@

Removed (previous version of the tests):

# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for TB-POC dataset."""

import pytest
import torch

from ptbench.data.tbpoc.datamodule import make_split


def _check_split(
    split_filename: str,
    lengths: dict[str, int],
    prefix: str = "TBPOC_CXR/",
    extension: str = ".jpeg",
    possible_labels: list[int] = [0, 1],
):
    """Runs a simple consistence check on the data split.

    Parameters
    ----------

    split_filename
        This is the split we will check
    lenghts
        A dictionary that contains keys matching those of the split (this will
        be checked). The values of the dictionary should correspond to the
        sizes of each of the datasets in the split.
    prefix
        Each file named in a split should start with this prefix.
    extension
        Each file named in a split should end with this extension.
    possible_labels
        These are the list of possible labels contained in any split.
    """
    split = make_split(split_filename)
    assert len(split) == len(lengths)
    for k in lengths.keys():
        # dataset must have been declared
        assert k in split
        assert len(split[k]) == lengths[k]
        for s in split[k]:
            # assert s[0].startswith(prefix)
            assert s[0].endswith(extension)
            assert s[1] in possible_labels


def _check_loaded_batch(
    batch,
    size: int = 1,
    prefix: str = "TBPOC_CXR/",
    extension: str = ".jpeg",
    possible_labels: list[int] = [0, 1],
):
    """Checks the consistence of an individual (loaded) batch.

    Parameters
    ----------

    batch
        The loaded batch to be checked.
    prefix
        Each file named in a split should start with this prefix.
    extension
        Each file named in a split should end with this extension.
    possible_labels
        These are the list of possible labels contained in any split.
    """
    assert len(batch) == 2  # data, metadata
    assert isinstance(batch[0], torch.Tensor)
    assert batch[0].shape[0] == size  # mini-batch size
    assert batch[0].shape[1] == 1  # grayscale images
    assert batch[0].shape[2] == batch[0].shape[3]  # image is square
    assert isinstance(batch[1], dict)  # metadata
    assert len(batch[1]) == 2  # label and name
    assert "label" in batch[1]
    assert all([k in possible_labels for k in batch[1]["label"]])
    assert "name" in batch[1]
    # assert all([k.startswith(prefix) for k in batch[1]["name"]])
    assert all([k.endswith(extension) for k in batch[1]["name"]])


def test_protocol_consistency():
    # Cross-validation fold 0-6
    for k in range(7):
        _check_split(
            f"fold-{k}.json",
            lengths=dict(train=292, validation=74, test=41),
        )

    # Cross-validation fold 7-9
    for k in range(7, 10):
        _check_split(
            f"fold-{k}.json",
            lengths=dict(train=293, validation=74, test=40),
        )


@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
def test_loading():
    from ptbench.data.tbpoc.fold_0 import datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    for loader in datamodule.predict_dataloader().values():
        limit = 5  # limit load checking
        for batch in loader:
            if limit == 0:
                break
            _check_loaded_batch(batch)
            limit -= 1


Added (rewritten tests):

# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for TB-POC dataset."""

import importlib

import pytest


def id_function(val):
    if isinstance(val, dict):
        return str(val)
    return repr(val)


@pytest.mark.parametrize(
    "split,lenghts",
    [
        ("fold-0", dict(train=292, validation=74, test=41)),
        ("fold-1", dict(train=292, validation=74, test=41)),
        ("fold-2", dict(train=292, validation=74, test=41)),
        ("fold-3", dict(train=292, validation=74, test=41)),
        ("fold-4", dict(train=292, validation=74, test=41)),
        ("fold-5", dict(train=292, validation=74, test=41)),
        ("fold-6", dict(train=292, validation=74, test=41)),
        ("fold-7", dict(train=293, validation=74, test=40)),
        ("fold-8", dict(train=293, validation=74, test=40)),
        ("fold-9", dict(train=293, validation=74, test=40)),
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers, split: str, lenghts: dict[str, int]
):
    from ptbench.data.tbpoc.datamodule import make_split

    database_checkers.check_split(
        make_split(f"{split}.json"),
        lengths=lenghts,
        prefixes=(
            "TBPOC_CXR/TBPOC-",
            "TBPOC_CXR/tbpoc-",
        ),
        possible_labels=(0, 1),
    )


@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
@pytest.mark.parametrize(
    "dataset",
    [
        "train",
        "validation",
        "test",
    ],
)
@pytest.mark.parametrize(
    "name",
    [
        "fold_0",
        "fold_1",
        "fold_2",
        "fold_3",
        "fold_4",
        "fold_5",
        "fold_6",
        "fold_7",
        "fold_8",
        "fold_9",
    ],
)
def test_loading(database_checkers, name: str, dataset: str):
    datamodule = importlib.import_module(
        f".{name}", "ptbench.data.tbpoc"
    ).datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    loader = datamodule.predict_dataloader()[dataset]

    limit = 3  # limit load checking
    for batch in loader:
        if limit == 0:
            break
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=1,
            prefixes=(
                "TBPOC_CXR/TBPOC-",
                "TBPOC_CXR/tbpoc-",
            ),
            possible_labels=(0, 1),
        )
        limit -= 1
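Both rewritten test files rely on the same two pytest features: an `id_function` passed as `ids=` so that parametrized test names stay readable, and stacked `@pytest.mark.parametrize` decorators, which pytest expands into the cross product of folds and dataset splits (10 folds times 3 splits, so 30 `test_loading` cases per database). The following is a small self-contained illustration of that pattern; it is a toy example, not ptbench code:

```python
# Toy illustration of the parametrization pattern used in the tests above.
import pytest


def id_function(val):
    # Same helper as in the tests: dicts are printed as plain strings,
    # everything else via repr(), which keeps quotes around string parameters.
    if isinstance(val, dict):
        return str(val)
    return repr(val)


@pytest.mark.parametrize("dataset", ["train", "validation", "test"])
@pytest.mark.parametrize(
    "name, lengths",
    [
        ("fold_0", dict(train=4, validation=2, test=1)),
        ("fold_1", dict(train=3, validation=2, test=2)),
    ],
    ids=id_function,
)
def test_cross_product(name, lengths, dataset):
    # pytest generates 2 folds x 3 datasets = 6 test cases from the stacked
    # decorators; each case receives one (name, lengths) pair and one dataset.
    assert dataset in lengths
    assert sum(lengths.values()) == 7
```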