'''
Author: Carl Yang
Function: Evaluate the embeddings
Command: library
'''
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from datetime import datetime
import numpy as np


class Evaluation(object):
    """Score embeddings (or raw features) against ground-truth neighbor data.

    The metrics run are selected by ``params.eval``:

    * ``'knn'``: precision/recall at 1..``params.k_max`` of exact nearest
      neighbors under squared Euclidean distance (see ``eval_knn``).
    * ``'pairwise'``: fraction of (anchor, positive, negative) triplets in
      which the anchor is closer to the positive (see ``eval_pair``).
    * ``'all'``: both of the above.

    ``params`` is read for ``raw``, ``dimension``, ``eval``, ``verbose``,
    ``device`` and ``k_max``.
    """

    # Initialize the Evaluation object with current parameters.
    def __init__(self, params):
        self.params = params

    # Main access of the evaluations.
    # Can evaluate both with and without a trained embedding model.
    def evaluate(self, dataset, embed=None, dim=None):
        """Evaluate either batch-normalized raw test features or model embeddings.

        Args:
            dataset: object whose ``feat_test`` is an ``(ids, features)`` pair;
                ``eval_vec`` additionally reads ``truth_knn``, ``truth_pair``
                and calls ``write_embedding``.
            embed: optional model with a ``compute(feat_test)`` method returning
                an ``(ids, embeddings)`` pair. When ``None`` (or when
                ``params.raw`` is set) the raw test features are evaluated.
            dim: indexable pair; only ``dim[1] - dim[0]`` is used, as the
                BatchNorm width. Defaults to ``params.dimension``.
        """
        print('{}: Evaluating the model...'.format(datetime.now()))
        if self.params.raw or embed is None:
            if dim is None:
                dim = self.params.dimension
            # A freshly built BatchNorm1d is in training mode, so it normalizes
            # each column with the statistics of this batch.
            norm = torch.nn.BatchNorm1d(dim[1] - dim[0])(dataset.feat_test[1])
            ids = torch.LongTensor([int(x) for x in dataset.feat_test[0]])
            self.eval_vec(dataset, [ids, norm])
        else:
            pred = embed.compute(dataset.feat_test)
            self.eval_vec(dataset, pred)

    # Compute the evaluation metrics based on the input embedding or feature vectors.
    def eval_vec(self, dataset, vec):
        """Run the metrics selected by ``params.eval`` on ``vec``.

        ``vec`` is an ``(ids, embeddings)`` pair. Results are stored on
        ``self`` (``pre_list``/``rec_list``/``pre_avg``/``rec_avg`` for knn,
        ``pair`` for pairwise). In verbose mode they are printed and ``vec`` is
        handed to ``dataset.write_embedding``.
        """
        if self.params.eval in ('knn', 'all'):
            if self.params.verbose:
                print('{}: evaluating test embeddings with knn...'
                    .format(datetime.now()))
            self.eval_knn(vec, dataset.truth_knn)
        if self.params.eval in ('pairwise', 'all'):
            if self.params.verbose:
                print('{}: evaluating test embeddings with pairwise...'
                    .format(datetime.now()))
            self.eval_pair(vec, dataset.truth_pair)
        if self.params.verbose:
            if self.params.eval in ('knn', 'all'):
                print(self.pre_list.tolist())
                print(self.rec_list.tolist())
                print(self.pre_avg)
                print(self.rec_avg)
            if self.params.eval in ('pairwise', 'all'):
                print(self.pair)
            # NOTE(review): embeddings are only written in verbose mode --
            # confirm this coupling is intended.
            dataset.write_embedding(vec)

    # Build a lookup from id value to embedding row.
    @staticmethod
    def _row_index(ids):
        """Map each value of the 1-D tensor ``ids`` to the row where it first occurs."""
        rows = {}
        for row, value in enumerate(ids.cpu().numpy().tolist()):
            # setdefault keeps the first occurrence, like list.index would.
            rows.setdefault(value, row)
        return rows

    # Compute the pair-wise accuracy.
    def eval_pair(self, pred, truth):
        """Compute the fraction of correctly ordered (anchor, pos, neg) triplets.

        Args:
            pred: ``(ids, embeddings)`` pair; row ``r`` of ``embeddings``
                belongs to ``ids[r]``.
            truth: mapping ``anchor_id -> (positive_ids, negative_ids)``; ids
                may be anything ``int()`` accepts.

        A triplet is correct when the anchor's squared Euclidean distance to
        the positive is strictly smaller than to the negative. Positive or
        negative ids without an embedding row are represented by the zero
        vector.

        Sets ``self.pair`` to the accuracy, or NaN when no triplets exist.

        Raises:
            ValueError: if an anchor id has no embedding row.
        """
        rows = self._row_index(pred[0])
        emb = pred[1].detach()
        # Stand-in embedding for ids that are absent from the predictions.
        zero = emb.new_zeros(emb.shape[1:])

        def lookup(node):
            row = rows.get(int(node))
            return zero if row is None else emb[row, :]

        n_total = 0
        n_correct = 0
        for key in truth.keys():
            anchor = int(key)
            if anchor not in rows:
                raise ValueError('anchor id {} not found in predictions'.format(anchor))
            x_emb = emb[rows[anchor], :]
            positives, negatives = truth[key][0], truth[key][1]
            # Negative distances are shared by every positive of this anchor.
            neg_dists = [(x_emb - lookup(neg)).pow(2).sum().item() for neg in negatives]
            for pos in positives:
                pos_dist = (x_emb - lookup(pos)).pow(2).sum().item()
                for neg_dist in neg_dists:
                    n_total += 1
                    if pos_dist < neg_dist:
                        n_correct += 1
        if self.params.verbose:
            print('{}: finished pairwise evaluation of {} pairs'.format(
                datetime.now(),
                n_total
            ))
        # With no triplets the accuracy is undefined; avoid ZeroDivisionError.
        self.pair = n_correct * 1.0 / n_total if n_total else float('nan')

    # Compute the precision and recall of exact knn.
    def eval_knn(self, pred, truth):
        """Compute precision@k and recall@k of exact nearest neighbors.

        Args:
            pred: ``(ids, embeddings)`` pair; row ``r`` of ``embeddings``
                belongs to ``ids[r]``.
            truth: mapping ``node_id -> collection of true neighbor ids``.
                Only nodes listed in ``truth`` take part, both as queries and
                as candidate neighbors.

        For every query, the other queried nodes are ranked by squared
        Euclidean distance (computed on ``params.device``). Precision@k and
        recall@k are recorded for k = 1..``params.k_max``.

        Sets ``pre_list``/``rec_list`` (per-k means over queries, numpy) and
        ``pre_avg``/``rec_avg`` (overall means, floats).

        Raises:
            ValueError: if a node id in ``truth`` has no embedding row.
            ZeroDivisionError: if a node's neighbor collection is empty.
        """
        rows = self._row_index(pred[0])
        emb_id = []
        selected = []
        neighbor_sets = []
        n_true = []
        for key, neighbors in truth.items():
            node = int(key)
            if node not in rows:
                raise ValueError('node id {} not found in predictions'.format(node))
            emb_id.append(node)
            selected.append(rows[node])
            neighbor_sets.append(set(neighbors))
            # Recall denominator uses the original collection length.
            n_true.append(len(neighbors))

        feats = pred[1].detach().cpu().numpy()
        # float64 matches the precision of the list -> np.asarray round trip.
        emb = torch.from_numpy(
            np.asarray(feats[selected], dtype=np.float64)).to(self.params.device)
        x_norm = (emb ** 2).sum(1).view(-1, 1)
        dist = x_norm + x_norm.view(1, -1) - 2.0 * torch.mm(emb, emb.t())
        # Exclude every point from its own ranking explicitly rather than
        # assuming it sorts first. Duplicate embeddings or rounding in the
        # expansion above can place another point at or below the
        # self-distance.
        dist.fill_diagonal_(float('inf'))
        _, indices = torch.sort(dist)
        indices = indices.to(torch.device('cpu'))

        n_query = len(emb_id)
        k_max = self.params.k_max
        # Only n_query - 1 genuine candidates exist per query.
        n_cand = n_query - 1
        pre = torch.zeros((n_query, k_max))
        rec = torch.zeros((n_query, k_max))
        for i in range(n_query):
            correct = 0.0
            for k in range(k_max):
                if k < n_cand and emb_id[indices[i, k].item()] in neighbor_sets[i]:
                    correct += 1.0
                pre[i, k] = correct / (k + 1)
                rec[i, k] = correct / n_true[i]
        if self.params.verbose:
            print('{}: finished knn evaluation of {} places'.format(
                datetime.now(),
                len(emb_id)
            ))
        self.pre_list = pre.mean(dim=0).numpy()
        self.pre_avg = pre.mean().item()
        self.rec_list = rec.mean(dim=0).numpy()
        self.rec_avg = rec.mean().item()
