Skip to content
Snippets Groups Projects
test_database_split.py 1.45 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Test code for datasets."""

from mednet.data.split import CSVDatabaseSplit, JSONDatabaseSplit


def test_csv_loading(datadir):
    """Check that ``CSVDatabaseSplit`` loads the Iris Flower CSV splits.

    Both the ``iris-train`` and ``iris-test`` splits go through the same
    checks: each must hold 75 samples, every sample must expose five string
    columns, and the fifth column must be one of the three Iris species
    names.

    Parameters
    ----------
    datadir
        Location of the test data, handed unchanged to ``CSVDatabaseSplit``
        (presumably supplied by a pytest fixture).
    """
    species = ("Iris-setosa", "Iris-versicolor", "Iris-virginica")
    database_split = CSVDatabaseSplit(datadir)

    # Apply identical checks to both splits so neither is validated less
    # strictly than the other.
    for split_name in ("iris-train", "iris-test"):
        samples = database_split[split_name]
        assert len(samples) == 75

        for sample in samples:
            # csv only loads strings, so even the four measurement columns
            # are expected to be ``str`` here
            for column in range(4):
                assert isinstance(sample[column], str)
            assert isinstance(sample[4], str)
            assert sample[4] in species


def test_json_loading(datadir):
    """Check that ``JSONDatabaseSplit`` loads the Iris Flower JSON file.

    The ``train`` and ``test`` splits of ``iris.json`` must each hold 75
    samples whose first four entries are numbers and whose fifth entry is a
    string.

    Parameters
    ----------
    datadir
        Path-like location of the test data; ``iris.json`` is looked up
        inside it (presumably supplied by a pytest fixture).
    """
    numeric = (int, float)
    database_split = JSONDatabaseSplit(datadir / "iris.json")

    for split_name in ("train", "test"):
        samples = database_split[split_name]
        assert len(samples) == 75

        for sample in samples:
            # unlike CSV, JSON keeps the measurements as numbers
            for column in range(4):
                assert isinstance(sample[column], numeric)
            assert isinstance(sample[4], str)