import yaml 
from yaml import Loader 
from easydict import EasyDict
from datasets import load_from_disk
from multiprocessing import cpu_count
import torch
from torchmetrics.functional import pairwise_cosine_similarity
import matplotlib.pyplot as plt
import numpy as np
import os 
from pathlib import Path
from utils import get_time
from PIL import Image

# Load the experiment configuration.
# NOTE: `Loader` is PyYAML's full (unsafe) loader, which can instantiate
# arbitrary Python objects from tagged YAML. Acceptable for this local config
# file; prefer `yaml.safe_load` if the file could ever come from elsewhere.
with open("./transferred_vs_orig.yaml", encoding="utf-8") as file:
    config = yaml.load(file, Loader=Loader)
config = EasyDict(config)

# Each run writes its outputs into a fresh folder named by `get_time()`.
output_folder = Path(f"./transferred_vs_orig/{get_time()}")
output_folder.mkdir(parents=True, exist_ok=True)

# Worker count for `datasets` multiprocessing: leave one core free, but never
# drop to 0 (when cpu_count() == 1), which is not a valid process count.
num_workers = max(1, cpu_count() - 1)

# Keep only rows whose `status_<key>` flag is truthy, and return the
# `config.original.key` column as torch tensors (all other columns are still
# returned, unformatted).
original = load_from_disk(config.original.path).filter(
    lambda sample: sample['status_' + config.original.key], num_proc=num_workers
)
original.set_format('torch', columns=[config.original.key], output_all_columns=True)

print(original.shape)

# Same filtering/formatting for the generated ("transferred") dataset.
generated = load_from_disk(config.trans.path).filter(
    lambda sample: sample['status_' + config.trans.key], num_proc=num_workers
)
generated.set_format('torch', columns=[config.trans.key], output_all_columns=True)

# Names of the two columns that `file_name` adds to every row: `name` is the
# key used to pair generated samples with original ones, `id` is the identity
# (the image's parent folder).
final_path_key = "name"
id_key = "id"


def file_name(sample):
    """Add the pairing-key and identity columns to one dataset row.

    ``sample[final_path_key]`` becomes the last 13 characters of the path stored
    under ``config.path_key``. ``sample[id_key]`` becomes the second-to-last
    '/'-separated component of that path, i.e. the folder holding the image.
    The row is updated in place and returned, as ``Dataset.map`` expects.
    """
    image_path = sample[config.path_key]
    # NOTE(review): this is a fixed 13-character suffix, not a path-component
    # split; presumably it covers "<id folder>/<file name>" for this data set.
    # Confirm it is unique and identical between original and generated paths.
    sample[final_path_key] = image_path[-13:]
    sample[id_key] = image_path.rsplit('/', 2)[-2]
    return sample

# Never pass 0 workers to `datasets` (cpu_count() - 1 is 0 on one core).
num_workers = max(1, cpu_count() - 1)

# Add the pairing-key and id columns to both datasets.
generated = generated.map(file_name, num_proc=num_workers)
pd_gen = generated.select_columns(column_names=[final_path_key]).to_pandas()
original = original.map(file_name, num_proc=num_workers)
pd_original = original.select_columns(column_names=[final_path_key]).to_pandas()

# Keep only samples whose key occurs in both datasets. Sets give O(1)
# membership tests; scanning a list per row made this quadratic.
generated_file_names = set(generated[final_path_key])
original = original.filter(
    lambda sample: sample[final_path_key] in generated_file_names, num_proc=num_workers
)
original_file_names = set(original[final_path_key])
generated = generated.filter(
    lambda sample: sample[final_path_key] in original_file_names, num_proc=num_workers
)

# Sorting both datasets by the same key lines up row i of `generated` with
# row i of `original` (assuming the keys are unique within each dataset).
print("Sorting")
generated = generated.sort(column_names=[final_path_key])
original = original.sort(column_names=[final_path_key])

ids_pd = generated.select_columns(column_names=[id_key]).to_pandas()
unique_ids = ids_pd[id_key].unique()
print(unique_ids)

# Within-id ("positive pair") cosine similarities. For every identity, every
# pair of its samples is compared, separately for the original and the
# generated vectors. Only the strict upper triangle of each similarity matrix
# is kept. Called with one argument, `pairwise_cosine_similarity` zeroes the
# diagonal, so self-pairs would add a spurious spike at 0, and the lower
# triangle would count every pair a second time.
all_orig_embeds = original[config.original.key]  # materialize once, not per id
all_gen_embeds = generated[config.trans.key]
id_column = ids_pd[id_key].to_numpy()

# Seed with empty float32 tensors so the result keeps the old dtype and is
# still defined when there are no ids.
orig_chunks = [torch.zeros(0, dtype=torch.float32)]
gen_chunks = [torch.zeros(0, dtype=torch.float32)]
for unique_id in unique_ids:
    print(f"Cosines for {unique_id}")
    # Row positions of this id. `ids_pd` comes from the name-sorted
    # `generated`, and `original` was sorted by the same key.
    indices = np.flatnonzero(id_column == unique_id)

    print(f"Selecting {indices} for {unique_id}")
    selected_orig = all_orig_embeds[indices]
    selected_gen = all_gen_embeds[indices]

    print(selected_orig.shape)
    print(selected_gen.shape)

    rows, cols = torch.triu_indices(len(indices), len(indices), offset=1)
    orig_chunks.append(pairwise_cosine_similarity(selected_orig)[rows, cols])
    gen_chunks.append(pairwise_cosine_similarity(selected_gen)[rows, cols])

# A single concatenation at the end; growing the tensor inside the loop was quadratic.
inter_id_orig = torch.cat(orig_chunks)
inter_id_gen = torch.cat(gen_chunks)

# Row i of `original` must describe the same image as row i of `generated`.
# Check every row: checking only index 100 missed other mismatches and failed
# outright with 100 rows or fewer.
if original[final_path_key] != generated[final_path_key]:
    raise RuntimeError("original and generated datasets are not row-aligned by name")

orig_embeds = original[config.original.key]
gen_embeds = generated[config.trans.key]
print(orig_embeds.shape)
print(gen_embeds.shape)

print("Calc Cosine Similarities ... ")
# Row-wise cosine between each original vector and its paired generated
# vector. This matches (up to float rounding) the diagonal of the N x N
# pairwise matrix, without allocating O(N^2) memory.
cosines = torch.nn.functional.cosine_similarity(orig_embeds, gen_embeds, dim=1)
distant_indices = torch.where(cosines < config.distant_threshold)[0]
close_indices = torch.where(cosines > config.close_threshold)[0]

# Density histogram of the per-pair (original vs. generated) cosine similarity.
# NOTE(review): the edges only span [0, 1], so np.histogram does not count
# pairs with negative similarity, and `density=True` renormalises over the
# rest. Confirm that is intended.
counts, bins = np.histogram(cosines, bins=np.linspace(0, 1, config.nbins))
left_edges = bins[:-1]

plt.hist(left_edges, bins=bins, weights=counts, density=True)
plt.title(f"Cosine similarity, pairs of Unrealistic and Realistic images, {config.original.key}")
plt.xlabel("Cosine Similarity")
plt.grid()
plt.tight_layout()
plt.savefig(output_folder / "hist.png", dpi=300, bbox_inches='tight')

# Wipe the axes and the figure so the next plot starts from a blank canvas.
plt.cla()
plt.clf()

# Overlaid density histograms of the within-id similarities on [-1, 1]:
# generated vectors first, then the original ones.
bins = np.linspace(-1, 1, config.nbins // 2)
counts_inter_id_gen, bins = np.histogram(inter_id_gen, bins)
counts_inter_id_orig, bins = np.histogram(inter_id_orig, bins)
left_edges = bins[:-1]
for series_label, series_counts in (
    ("Generated", counts_inter_id_gen),
    ("Original", counts_inter_id_orig),
):
    plt.hist(left_edges, bins=bins, weights=series_counts, density=True, label=series_label, alpha=0.3)

plt.xlabel("Cosine Similarity")
plt.title("Positive Pair Distribution")
plt.grid()
plt.legend()
plt.tight_layout()
plt.savefig(output_folder / "positive.png", dpi=300, bbox_inches='tight')

# Image paths of up to ten example pairs on each side of the thresholds. These
# are the first matches in row order (torch.where returns ascending indices),
# not necessarily the most extreme similarities.
distant_indices = distant_indices[:10]
distant_orig_paths, distant_gen_paths = (
    dataset[distant_indices][config.path_key] for dataset in (original, generated)
)

close_indices = close_indices[:10]
close_orig_paths, close_gen_paths = (
    dataset[close_indices][config.path_key] for dataset in (original, generated)
)

def image_pair_gen(orig_paths, gen_paths, prefix="", size=(256, 256)):
    """Save side-by-side comparison images for paired original/generated files.

    Each pair is resized to ``size`` and pasted onto one RGB canvas of
    ``(2 * width, height)``: the original on the left, the generated image on
    the right. The canvas is written to ``output_folder / f"{prefix}_{i}.jpg"``,
    where ``i`` is the pair's position. Iteration stops at the shorter input.

    Args:
        orig_paths: Iterable of paths to the original images.
        gen_paths: Iterable of paths to the generated images, aligned
            one-to-one with ``orig_paths``.
        prefix: Filename prefix for the saved comparison images.
        size: ``(width, height)`` each image is resized to. Defaults to the
            previously hard-coded 256x256.
    """
    w, h = size
    for idx, (orig_path, gen_path) in enumerate(zip(orig_paths, gen_paths)):
        canvas = Image.new("RGB", (2 * w, h))
        # Context managers close the files that Image.open keeps a handle to.
        with Image.open(orig_path) as orig_img:
            canvas.paste(orig_img.resize((w, h)), (0, 0))
        with Image.open(gen_path) as gen_img:
            canvas.paste(gen_img.resize((w, h)), (w, 0))
        canvas.save(output_folder / f"{prefix}_{idx}.jpg")

# Write the side-by-side comparison images: distant pairs first, then close ones.
for pair_prefix, pair_orig_paths, pair_gen_paths in (
    (f"distant_{config.distant_threshold}", distant_orig_paths, distant_gen_paths),
    (f"close_{config.close_threshold}", close_orig_paths, close_gen_paths),
):
    image_pair_gen(pair_orig_paths, pair_gen_paths, pair_prefix)