Commit dc220169
Authored 1 year ago by Daniel CARRON; committed 1 year ago by André Anjos

Updated shenzhen tests
Parent: 87a65d97
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
Changes: 1 changed file, tests/test_ch.py (+102 additions, −41 deletions)
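In short, the tests stop importing a module-level dataset object (from ptbench.data.shenzhen import dataset) and instead resolve one module per protocol, each exposing a datamodule built on LightningDataModule. A minimal sketch of the new access pattern, distilled from the diff below; the behaviour of database_split.subsets and of individual split entries is inferred from the test assertions, not from separate documentation:

    import importlib

    # Resolve the protocol-specific module; each one exposes a
    # `datamodule` object (pattern taken from the diff below).
    datamodule = importlib.import_module(
        "ptbench.data.shenzhen.default"
    ).datamodule

    # `database_split.subsets` is used like a mapping with "train",
    # "validation" and "test" keys (the tests index it this way).
    subsets = datamodule.database_split.subsets

    # Split entries appear to be indexable (key, label) pairs rather than
    # objects with `.key`/`.label` attributes, which is why the assertions
    # change from `s.key`/`s.label` to `s[0]`/`s[1]`.
    for s in subsets["train"]:
        assert s[0].startswith("CXR_png/CHNCXR_0")
        assert s[1] in (0.0, 1.0)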
--- a/tests/test_ch.py
+++ b/tests/test_ch.py
@@ -4,133 +4,194 @@
 """
 Tests for Shenzhen dataset.
 """
 
-import pytest
+import importlib
 
-from ptbench.data.shenzhen import dataset
+import pytest
 
 
 def test_protocol_consistency():
     # Default protocol
-    subset = dataset.subsets("default")
+    datamodule = importlib.import_module(
+        "ptbench.data.shenzhen.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+
     assert len(subset) == 3
 
     assert "train" in subset
     assert len(subset["train"]) == 422
     for s in subset["train"]:
-        assert s.key.startswith("CXR_png/CHNCXR_0")
+        assert s[0].startswith("CXR_png/CHNCXR_0")
 
     assert "validation" in subset
     assert len(subset["validation"]) == 107
     for s in subset["validation"]:
-        assert s.key.startswith("CXR_png/CHNCXR_0")
+        assert s[0].startswith("CXR_png/CHNCXR_0")
 
     assert "test" in subset
     assert len(subset["test"]) == 133
     for s in subset["test"]:
-        assert s.key.startswith("CXR_png/CHNCXR_0")
+        assert s[0].startswith("CXR_png/CHNCXR_0")
 
     # Check labels
     for s in subset["train"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["validation"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["test"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     # Cross-validation folds 0-1
     for f in range(2):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.shenzhen.fold_{str(f)}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
         assert "train" in subset
         assert len(subset["train"]) == 476
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 119
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 67
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
     # Cross-validation folds 2-9
     for f in range(2, 10):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.shenzhen.fold_{str(f)}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
         assert "train" in subset
         assert len(subset["train"]) == 476
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 120
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 66
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
 def test_loading():
-    def _check_size(size):
-        if (
-            size[0] >= 1130
-            and size[0] <= 3001
-            and size[1] >= 948
-            and size[1] <= 3001
-        ):
+    import torch
+    import torchvision.transforms
+
+    from ptbench.data.datamodule import _DelayedLoadingDataset
+
+    def _check_size(shape):
+        if shape[0] == 1 and shape[1] == 512 and shape[2] == 512:
             return True
         return False
 
     def _check_sample(s):
-        data = s.data
+        assert len(s) == 2
 
-        assert isinstance(data, dict)
-        assert len(data) == 2
+        data = s[0]
+        metadata = s[1]
 
-        assert "data" in data
-        assert _check_size(data["data"].size)  # Check size
-        assert data["data"].mode == "L"  # Check colors
+        assert isinstance(data, torch.Tensor)
 
-        assert "label" in data
-        assert data["label"] in [0, 1]  # Check labels
+        print(data.shape)
+        assert _check_size(data.shape)  # Check size
+
+        assert (
+            torchvision.transforms.ToPILImage()(data).mode == "L"
+        )  # Check colors
+
+        assert "label" in metadata
+        assert metadata["label"] in [0, 1]  # Check labels
 
     limit = 30  # use this to limit testing to first images only, else None
 
-    subset = dataset.subsets("default")
-    for s in subset["train"][:limit]:
-        _check_sample(s)
+    datamodule = importlib.import_module(
+        "ptbench.data.shenzhen.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+    raw_data_loader = datamodule.raw_data_loader
+
+    # Need to use private function so we can limit the number of samples to use
+    dataset = _DelayedLoadingDataset(
+        subset["train"][:limit],
+        raw_data_loader,
+    )
+
+    for s in dataset:
+        _check_sample(s)
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
 def test_check():
-    assert dataset.check() == 0
+    from ptbench.data.split import check_database_split_loading
+
+    limit = 30  # use this to limit testing to first images only, else 0
+
+    # Default protocol
+    datamodule = importlib.import_module(
+        "ptbench.data.shenzhen.default"
+    ).datamodule
+
+    database_split = datamodule.database_split
+    raw_data_loader = datamodule.raw_data_loader
+
+    assert (
+        check_database_split_loading(
+            database_split, raw_data_loader, limit=limit
+        )
+        == 0
+    )
+
+    # Folds
+    for f in range(10):
+        datamodule = importlib.import_module(
+            f"ptbench.data.shenzhen.fold_{f}"
+        ).datamodule
+
+        database_split = datamodule.database_split
+        raw_data_loader = datamodule.raw_data_loader
+
+        assert (
+            check_database_split_loading(
+                database_split, raw_data_loader, limit=limit
+            )
+            == 0
+        )
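Both data-dependent tests carry the project's custom skip_if_rc_var_not_set("datadir.shenzhen") marker, which presumably skips them unless the runner's configuration points datadir.shenzhen at a local copy of the raw images; test_protocol_consistency only inspects split metadata and should run everywhere. A plausible invocation is the plain pytest call, with the marker handling the skipping:

    # run only this file's tests; -s makes the print(data.shape)
    # output in _check_sample visible
    pytest -sv tests/test_ch.py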