From 978f04e896c1f74399bc50eaabb67d7c3396891b Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Thu, 27 Jul 2023 20:59:03 +0200 Subject: [PATCH] [data.indian] Use right split name; separate split creation so it is reusable --- src/ptbench/data/indian/datamodule.py | 15 ++++++++++----- src/ptbench/data/indian/default.py | 2 +- src/ptbench/data/indian/fold_0.py | 2 +- src/ptbench/data/indian/fold_1.py | 2 +- src/ptbench/data/indian/fold_2.py | 2 +- src/ptbench/data/indian/fold_3.py | 2 +- src/ptbench/data/indian/fold_4.py | 2 +- src/ptbench/data/indian/fold_5.py | 2 +- src/ptbench/data/indian/fold_6.py | 2 +- src/ptbench/data/indian/fold_7.py | 2 +- src/ptbench/data/indian/fold_8.py | 2 +- src/ptbench/data/indian/fold_9.py | 2 +- 12 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/ptbench/data/indian/datamodule.py b/src/ptbench/data/indian/datamodule.py index 7042a0c4..34fb4b5a 100644 --- a/src/ptbench/data/indian/datamodule.py +++ b/src/ptbench/data/indian/datamodule.py @@ -7,6 +7,15 @@ import importlib.resources from ..datamodule import CachingDataModule from ..shenzhen.datamodule import RawDataLoader from ..split import JSONDatabaseSplit +from ..typing import DatabaseSplit + + +def make_split(basename: str) -> DatabaseSplit: + """Returns a database split for the Indian database.""" + + return JSONDatabaseSplit( + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + ) class DataModule(CachingDataModule): @@ -42,10 +51,6 @@ class DataModule(CachingDataModule): def __init__(self, split_filename: str): super().__init__( - database_split=JSONDatabaseSplit( - importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( - split_filename - ) - ), + database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), ) diff --git a/src/ptbench/data/indian/default.py b/src/ptbench/data/indian/default.py index 7fe993a9..2b8a8fb2 100644 --- a/src/ptbench/data/indian/default.py +++ b/src/ptbench/data/indian/default.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("default.json.bz2") +datamodule = DataModule("default.json") diff --git a/src/ptbench/data/indian/fold_0.py b/src/ptbench/data/indian/fold_0.py index c810e85c..3d114d07 100644 --- a/src/ptbench/data/indian/fold_0.py +++ b/src/ptbench/data/indian/fold_0.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_0.json.bz2") +datamodule = DataModule("fold_0.json") diff --git a/src/ptbench/data/indian/fold_1.py b/src/ptbench/data/indian/fold_1.py index 736a778d..cd3a8cb6 100644 --- a/src/ptbench/data/indian/fold_1.py +++ b/src/ptbench/data/indian/fold_1.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_1.json.bz2") +datamodule = DataModule("fold_1.json") diff --git a/src/ptbench/data/indian/fold_2.py b/src/ptbench/data/indian/fold_2.py index 48df1bfe..44eeda80 100644 --- a/src/ptbench/data/indian/fold_2.py +++ b/src/ptbench/data/indian/fold_2.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_2.json.bz2") +datamodule = DataModule("fold_2.json") diff --git a/src/ptbench/data/indian/fold_3.py b/src/ptbench/data/indian/fold_3.py index 9967e4ea..f24fb314 100644 --- a/src/ptbench/data/indian/fold_3.py +++ b/src/ptbench/data/indian/fold_3.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_3.json.bz2") +datamodule = DataModule("fold_3.json") diff --git a/src/ptbench/data/indian/fold_4.py b/src/ptbench/data/indian/fold_4.py index 8630ee09..58456d38 100644 --- a/src/ptbench/data/indian/fold_4.py +++ b/src/ptbench/data/indian/fold_4.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_4.json.bz2") +datamodule = DataModule("fold_4.json") diff --git a/src/ptbench/data/indian/fold_5.py b/src/ptbench/data/indian/fold_5.py index 0c7504c5..92796746 100644 --- a/src/ptbench/data/indian/fold_5.py +++ b/src/ptbench/data/indian/fold_5.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_5.json.bz2") +datamodule = DataModule("fold_5.json") diff --git a/src/ptbench/data/indian/fold_6.py b/src/ptbench/data/indian/fold_6.py index 2f8e8e32..9566b7cf 100644 --- a/src/ptbench/data/indian/fold_6.py +++ b/src/ptbench/data/indian/fold_6.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_6.json.bz2") +datamodule = DataModule("fold_6.json") diff --git a/src/ptbench/data/indian/fold_7.py b/src/ptbench/data/indian/fold_7.py index 389e7f4e..25cbfe1b 100644 --- a/src/ptbench/data/indian/fold_7.py +++ b/src/ptbench/data/indian/fold_7.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_7.json.bz2") +datamodule = DataModule("fold_7.json") diff --git a/src/ptbench/data/indian/fold_8.py b/src/ptbench/data/indian/fold_8.py index a9480359..fb5332ce 100644 --- a/src/ptbench/data/indian/fold_8.py +++ b/src/ptbench/data/indian/fold_8.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_8.json.bz2") +datamodule = DataModule("fold_8.json") diff --git a/src/ptbench/data/indian/fold_9.py b/src/ptbench/data/indian/fold_9.py index daa85e03..d1626586 100644 --- a/src/ptbench/data/indian/fold_9.py +++ b/src/ptbench/data/indian/fold_9.py @@ -4,4 +4,4 @@ from .datamodule import DataModule -datamodule = DataModule("fold_9.json.bz2") +datamodule = DataModule("fold_9.json") -- GitLab