Skip to content
Snippets Groups Projects

Refactor meta to follow drop of bob-devel

Merged Samuel GAIST requested to merge drop_bob_devel into master
Files
2
@@ -2,21 +2,22 @@
# coding=utf-8
import os
import click
from bob.extension.scripts.click_helper import (
verbosity_option,
AliasedGroup,
)
import re
import shutil
import click
import numpy
import torch
import re
import pandas
import logging
logger = logging.getLogger(__name__)
from bob.extension.scripts.click_helper import (
verbosity_option,
AliasedGroup,
)
def _load(data):
"""Load prediction.csv files
@@ -36,11 +37,17 @@ def _load(data):
A dict in which keys are the names of the systems and the values are
dictionaries that contain two keys:
* ``df``: A :py:class:`pandas.DataFrame` with the predictions data
* ``df``: A :py:class:`pandas.DataFrame` with the predictions data
loaded to
"""
def _to_double_tensor(col):
"""Converts a column in a dataframe to a tensor array"""
pattern = re.compile(" +")
return col.apply(lambda cell: numpy.array(eval(pattern.sub(",", cell))))
# loads all data
retval = {}
for name, predictions_path in data.items():
@@ -48,16 +55,8 @@ def _load(data):
# Load predictions
logger.info(f"Loading predictions from {predictions_path}...")
pred_data = pandas.read_csv(predictions_path)
pred = torch.Tensor([eval(re.sub(' +', ' ', x.replace('\n', '')).replace(' ', ',')) for x in pred_data['likelihood'].values]).double()
gt = torch.Tensor([eval(re.sub(' +', ' ', x.replace('\n', '')).replace(' ', ',')) for x in pred_data['ground_truth'].values]).double()
if pred.shape[1] == 1 and gt.shape[1] == 1:
pred = torch.flatten(pred)
gt = torch.flatten(gt)
pred_data['likelihood'] = pred
pred_data['ground_truth'] = gt
pred_data['likelihood'] = _to_double_tensor(pred_data['likelihood'])
pred_data['ground_truth'] = _to_double_tensor(pred_data['ground_truth'])
retval[name] = dict(df=pred_data)
return retval
@@ -109,12 +108,12 @@ def predtojson(label_path, output_folder, **kwargs):
logger.info("Saving JSON file...")
with open(output_file, "a+", newline="") as f:
f.write('{')
for i, (name, value) in enumerate(data.items()):
if i > 0:
f.write(',')
df = value["df"]
f.write('"'+name+'": [')
for index, row in df.iterrows():
@@ -127,4 +126,4 @@ def predtojson(label_path, output_folder, **kwargs):
f.write(']')
f.write(']')
f.write('}')
Loading