Skip to content
Snippets Groups Projects
test_database_split.py 1.42 KiB
Newer Older
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Test code for datasets."""

from ptbench.data.split import CSVDatabaseSplit, JSONDatabaseSplit


def test_csv_loading(datadir):
    """Check that :class:`CSVDatabaseSplit` loads the Iris Flower CSV splits.

    Parameters
    ----------
    datadir
        Directory holding the test data files; presumably a pytest fixture
        pointing at the CSV files that back the ``iris-train`` and
        ``iris-test`` splits (TODO confirm against the fixture definition).
    """
    database_split = CSVDatabaseSplit(datadir)
    labels = ("Iris-setosa", "Iris-versicolor", "Iris-virginica")

    for split_name in ("iris-train", "iris-test"):
        split = database_split[split_name]
        assert len(split) == 75
        for k in split:
            # The CSV reader does not convert types, so the four feature
            # columns come back as strings, just like the class label.
            for f in range(4):
                assert isinstance(k[f], str)
            assert isinstance(k[4], str)
            assert k[4] in labels


def test_json_loading(datadir):
    """Check that :class:`JSONDatabaseSplit` loads the Iris Flower JSON split.

    Parameters
    ----------
    datadir
        Path-like directory holding the test data files; it must contain
        ``iris.json`` with ``train`` and ``test`` splits (presumably a pytest
        fixture — TODO confirm against the fixture definition).
    """
    database_split = JSONDatabaseSplit(datadir / "iris.json")

    for split_name in ("train", "test"):
        split = database_split[split_name]
        assert len(split) == 75
        for k in split:
            # JSON preserves numeric types, so the four feature columns are
            # numbers while the class label remains a string.
            for f in range(4):
                assert isinstance(k[f], (int, float))
            assert isinstance(k[4], str)