Skip to content
Snippets Groups Projects
Commit 640950de authored by André Anjos :speech_balloon:
Browse files

[data.shenzhen] Use right split name; separate split creation so it is reusable

parent 2e26ec26
No related branches found
No related tags found
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
...@@ -13,6 +13,7 @@ from ...utils.rc import load_rc ...@@ -13,6 +13,7 @@ from ...utils.rc import load_rc
from ..datamodule import CachingDataModule from ..datamodule import CachingDataModule
from ..image_utils import remove_black_borders from ..image_utils import remove_black_borders
from ..split import JSONDatabaseSplit from ..split import JSONDatabaseSplit
from ..typing import DatabaseSplit
from ..typing import RawDataLoader as _BaseRawDataLoader from ..typing import RawDataLoader as _BaseRawDataLoader
from ..typing import Sample from ..typing import Sample
...@@ -93,6 +94,14 @@ class RawDataLoader(_BaseRawDataLoader): ...@@ -93,6 +94,14 @@ class RawDataLoader(_BaseRawDataLoader):
return sample[1] return sample[1]
def make_split(basename: str) -> DatabaseSplit:
"""Returns a database split for the Shenzhen database."""
return JSONDatabaseSplit(
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
)
class DataModule(CachingDataModule): class DataModule(CachingDataModule):
"""Shenzhen datamodule for computer-aided diagnosis. """Shenzhen datamodule for computer-aided diagnosis.
...@@ -128,10 +137,6 @@ class DataModule(CachingDataModule): ...@@ -128,10 +137,6 @@ class DataModule(CachingDataModule):
def __init__(self, split_filename: str): def __init__(self, split_filename: str):
super().__init__( super().__init__(
database_split=JSONDatabaseSplit( database_split=make_split(split_filename),
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
split_filename
)
),
raw_data_loader=RawDataLoader(), raw_data_loader=RawDataLoader(),
) )
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("default.json.bz2") datamodule = DataModule("default.json")
"""Default Shenzhen TB database split. """Default Shenzhen TB database split.
* Training samples: 64% of TB and healthy CXR (including labels) * Training samples: 64% of TB and healthy CXR (including labels)
......
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_0.json.bz2") datamodule = DataModule("fold_0.json")
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_1.json.bz2") datamodule = DataModule("fold_1.json")
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_2.json.bz2") datamodule = DataModule("fold_2.json")
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_3.json.bz2") datamodule = DataModule("fold_3.json")
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_4.json.bz2") datamodule = DataModule("fold_4.json")
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_5.json.bz2") datamodule = DataModule("fold_5.json")
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_6.json.bz2") datamodule = DataModule("fold_6.json")
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_7.json.bz2") datamodule = DataModule("fold_7.json")
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_8.json.bz2") datamodule = DataModule("fold_8.json")
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
from .datamodule import DataModule from .datamodule import DataModule
datamodule = DataModule("fold_9.json.bz2") datamodule = DataModule("fold_9.json")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment