# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import glob
import os

from collections import defaultdict
from typing import Any

import pandas

from tensorboard.backend.event_processing.event_accumulator import (
    EventAccumulator,
)


def get_scalars(logdir: str) -> pandas.DataFrame:
    """Return the scalars stored in tensorboard event files.

    Every file in ``logdir`` whose name starts with ``events.out.tfevents.``
    is read, in sorted filename order, and its scalar events are merged into a
    single table.  If the same tag is found more than once for the same step
    (within a file or across files), the value read last is kept.

    Parameters
    ----------

    logdir:
        Directory containing the event files.

    Returns
    -------

    data:
        Pandas dataframe containing the results.  There is one row per logged
        step (the step is used as index and also stored in the ``step``
        column), sorted by increasing step.  There is one column per scalar
        tag; a tag that was not logged at a given step is left as a missing
        value in that row.  The dataframe is empty if no event file, or no
        scalar, is found.
    """
    tensorboard_logs = sorted(
        glob.glob(os.path.join(logdir, "events.out.tfevents.*"))
    )

    # Maps each step to its row, i.e. a ``{column name: value}`` dictionary.
    rows: dict[int, dict[str, Any]] = defaultdict(dict)

    for logfile in tensorboard_logs:
        event_accumulator = EventAccumulator(logfile)
        event_accumulator.Reload()

        # Log files do not need to share the same tags: rows are keyed by
        # step, and columns missing from a row are filled in by pandas.
        for scalar_tag in event_accumulator.Tags()["scalars"]:
            for event in event_accumulator.Scalars(scalar_tag):
                row = rows[event.step]
                row["step"] = event.step
                row[scalar_tag] = event.value

    # Rows are inserted tag by tag and file by file, so they are not
    # guaranteed to be in step order; sort them to get a deterministic,
    # monotonically increasing index.
    return pandas.DataFrame.from_dict(rows, orient="index").sort_index()