"""Compare embeddings of original images with those of their transferred counterparts.

The script reads ``./transferred_vs_orig.yaml``, loads the two datasets it
points to, pairs samples through a key derived from their image path, and
writes into a time-stamped folder under ``./transferred_vs_orig/``:

* ``hist.png``      - histogram of the cosine similarity of each
  (original, transferred) pair;
* ``positive.png``  - histograms of the cosine similarity between different
  images that share the same id, for both datasets;
* ``distant_*.jpg`` / ``close_*.jpg`` - side-by-side images for up to ten pairs
  below ``distant_threshold`` / above ``close_threshold``.
"""
import os
from multiprocessing import cpu_count
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from datasets import load_from_disk
from easydict import EasyDict
from pandas import Series
from PIL import Image
from torchmetrics.functional import pairwise_cosine_similarity
from yaml import Loader

from utils import get_time

# NOTE(review): yaml.load with the full ``Loader`` can construct arbitrary
# Python objects. Only use this with a trusted config file; ``yaml.safe_load``
# is the restricted alternative if the config only holds plain data.
with open("./transferred_vs_orig.yaml") as file:
    config = yaml.load(file, Loader=Loader)
config = EasyDict(config)

output_folder = Path(f"./transferred_vs_orig/{get_time()}")
output_folder.mkdir(parents=True, exist_ok=True)

# Leave one core free, but never ask for zero workers on a single-core machine
# (datasets rejects ``num_proc <= 0``).
NUM_PROC = max(1, cpu_count() - 1)

# Keep only the samples whose ``status_<key>`` flag is truthy.
original = load_from_disk(config.original.path).filter(
    lambda sample: sample["status_" + config.original.key], num_proc=NUM_PROC
)
original.set_format("torch", columns=[config.original.key], output_all_columns=True)
print(original.shape)

generated = load_from_disk(config.trans.path).filter(
    lambda sample: sample["status_" + config.trans.key], num_proc=NUM_PROC
)
generated.set_format("torch", columns=[config.trans.key], output_all_columns=True)

# Extra columns used to find the (original, transferred) pairs.
final_path_key = "name"
id_key = "id"


def file_name(sample):
    """Add the pairing key and the id to *sample* and return it.

    The pairing key is the last 13 characters of the image path; the id is the
    name of the folder that directly contains the image.
    NOTE(review): the fixed 13-character suffix assumes a specific path layout
    shared by both datasets - confirm against the data.
    """
    path = sample[config.path_key]
    sample[final_path_key] = path[-13:]
    sample[id_key] = path.split("/")[-2]
    return sample


generated = generated.map(file_name, num_proc=NUM_PROC)
pd_gen = generated.select_columns(column_names=[final_path_key]).to_pandas()
original = original.map(file_name, num_proc=NUM_PROC)
pd_original = original.select_columns(column_names=[final_path_key]).to_pandas()

# Keep only the samples present in both datasets. Sets make each membership
# test O(1) instead of scanning a list for every sample.
generated_file_names = set(generated[final_path_key])
original = original.filter(
    lambda sample: sample[final_path_key] in generated_file_names, num_proc=NUM_PROC
)
original_file_names = set(original[final_path_key])
generated = generated.filter(
    lambda sample: sample[final_path_key] in original_file_names, num_proc=NUM_PROC
)

print("Sorting")
generated = generated.sort(column_names=[final_path_key])
original = original.sort(column_names=[final_path_key])

# Everything below addresses both datasets with the same row indices, so the
# two must be aligned row by row. Checking the whole column (instead of a single
# hard-coded row) also works for small datasets and catches duplicated keys.
if list(generated[final_path_key]) != list(original[final_path_key]):
    raise ValueError(
        "original and generated datasets are not aligned after filtering and "
        f"sorting on '{final_path_key}' (duplicated or mismatching keys)"
    )

ids_pd = generated.select_columns(column_names=[id_key]).to_pandas()
unique_ids = ids_pd[id_key].unique()
print(unique_ids)

# Materialise the embedding columns once instead of once per id.
orig_embeds = original[config.original.key]
gen_embeds = generated[config.trans.key]


def _unique_pairs(similarity):
    """Return the entries strictly above the diagonal of a square matrix.

    ``pairwise_cosine_similarity`` called with a single input zeroes the
    diagonal by default, and the matrix is symmetric; keeping only the upper
    triangle drops those artificial zeros and the mirrored duplicates.
    """
    rows, cols = torch.triu_indices(similarity.shape[0], similarity.shape[1], offset=1)
    return similarity[rows, cols]


# Similarities between different images of the same id, collected per id and
# concatenated once at the end. The leading empty tensor keeps the concat
# valid when there are no ids at all.
orig_chunks = [torch.zeros(0, dtype=torch.float32)]
gen_chunks = [torch.zeros(0, dtype=torch.float32)]
for unique_id in unique_ids:
    print(f"Cosines for {unique_id}")
    indices: np.ndarray = ids_pd.loc[ids_pd[id_key] == unique_id].index.to_numpy()
    print(f"Selecting {indices} for {unique_id}")
    selected_orig = orig_embeds[indices]
    selected_gen = gen_embeds[indices]
    print(selected_orig.shape)
    print(selected_gen.shape)
    orig_chunks.append(_unique_pairs(pairwise_cosine_similarity(selected_orig)))
    gen_chunks.append(_unique_pairs(pairwise_cosine_similarity(selected_gen)))
inter_id_orig = torch.concat(orig_chunks)
inter_id_gen = torch.concat(gen_chunks)

print(orig_embeds.shape)
print(gen_embeds.shape)

print("Calc Cosine Similarities ... \n")
# Row-wise similarity of each (original, transferred) pair. This is the
# diagonal of the full pairwise matrix without allocating N x N values.
cosines = torch.nn.functional.cosine_similarity(orig_embeds, gen_embeds, dim=1)

distant_indices = torch.where(cosines < config.distant_threshold)[0]
close_indices = torch.where(cosines > config.close_threshold)[0]

bins = np.linspace(0, 1, config.nbins)
counts, bins = np.histogram(cosines, bins)
plt.hist(bins[:-1], bins=bins, weights=counts, density=True)
plt.title(
    f"Cosine similarity, pairs of Unrealistic and Realistic images, {config.original.key}"
)
plt.xlabel("Cosine Similarity")
plt.grid()
plt.tight_layout()
plt.savefig(output_folder / "hist.png", dpi=300, bbox_inches="tight")

# Reset the axes and figure before drawing the second plot.
plt.cla()
plt.clf()

bins = np.linspace(-1, 1, config.nbins // 2)
counts_inter_id_gen, bins = np.histogram(inter_id_gen, bins)
plt.hist(
    bins[:-1], bins=bins, weights=counts_inter_id_gen, density=True,
    label="Generated", alpha=0.3,
)
counts_inter_id_orig, bins = np.histogram(inter_id_orig, bins)
plt.hist(
    bins[:-1], bins=bins, weights=counts_inter_id_orig, density=True,
    label="Original", alpha=0.3,
)
plt.xlabel("Cosine Similarity")
plt.title("Positive Pair Distribution")
plt.grid()
plt.legend()
plt.tight_layout()
plt.savefig(output_folder / "positive.png", dpi=300, bbox_inches="tight")

# Keep at most ten examples of each kind; plain int lists are used to index
# the datasets.
distant_indices = distant_indices[:10].tolist()
distant_orig_paths = original[distant_indices][config.path_key]
distant_gen_paths = generated[distant_indices][config.path_key]

close_indices = close_indices[:10].tolist()
close_orig_paths = original[close_indices][config.path_key]
close_gen_paths = generated[close_indices][config.path_key]


def image_pair_gen(orig_paths, gen_paths, prefix=""):
    """Save each (original, generated) image pair side by side as a JPEG.

    Both images are resized to 256x256, the original is pasted on the left and
    the generated one on the right, and the result is written to
    ``output_folder / f"{prefix}_{index}.jpg"`` with ``index`` counting from 0.

    Args:
        orig_paths: Paths of the original images.
        gen_paths: Paths of the generated images, in the same order.
        prefix: File-name prefix of the saved images.
    """
    size = (256, 256)
    w, h = size
    for index, (orig_path, gen_path) in enumerate(zip(orig_paths, gen_paths)):
        # Context managers close the underlying image files promptly.
        with Image.open(orig_path) as orig, Image.open(gen_path) as gen:
            final = Image.new("RGB", (2 * w, h))
            final.paste(orig.resize(size), (0, 0))
            final.paste(gen.resize(size), (w, 0))
        final.save(output_folder / f"{prefix}_{index}.jpg")


image_pair_gen(distant_orig_paths, distant_gen_paths, f"distant_{config.distant_threshold}")
image_pair_gen(close_orig_paths, close_gen_paths, f"close_{config.close_threshold}")