import numpy as np
import tensorflow as tf
import cPickle as pickle
from sklearn.cluster import KMeans
import sys

feat_dim = 16
nsamples = 60
featname_dim = 42


def find(theset, key):
    """Return the positions of every element of `theset` equal to `key`.

    Positions are reported in ascending order; an empty list is returned
    when nothing matches.
    """
    return [position for position, item in enumerate(theset) if item == key]


def jaccard(s1, s2):
    """Return the Jaccard similarity |s1 & s2| / |s1 | s2| as a float.

    Both arguments are converted to sets first, so duplicates are ignored.
    When both inputs are empty the union is empty and the ratio is
    undefined; 0.0 is returned in that case instead of raising
    ZeroDivisionError, mirroring how F1 scores empty inputs.
    """
    first = set(s1)
    second = set(s2)
    union = first | second
    if not union:
        # Nothing to compare: avoid dividing by zero.
        return 0.0
    return len(first & second) * 1.0 / len(union)


def F1(s1, s2):
    """Return the F1 score between two collections, compared as sets.

    Precision is the overlap divided by the number of distinct items in
    `s2`; recall is the overlap divided by the number of distinct items in
    `s1`. A side that is empty contributes 0, and 0 is returned when both
    precision and recall are 0.
    """
    first = set(s1)
    second = set(s2)
    overlap = len(first.intersection(second))

    precision = overlap * 1.0 / len(second) if len(second) != 0 else 0
    recall = overlap * 1.0 / len(first) if len(first) != 0 else 0

    if precision == 0 and recall == 0:
        return 0
    # Harmonic mean of precision and recall.
    return 2 * precision * recall / (precision + recall)


def cal_error(set1, set2, type):
    """Score the agreement of two label assignments, averaged over both directions.

    `set1` and `set2` hold one label per position. For every label value in
    0..max(set1) the positions carrying that label are gathered from each
    assignment (labels in `set2` above max(set1) are therefore ignored).
    Each group of one assignment is then matched with its best-scoring group
    of the other one, and the best scores of both directions are averaged.

    `type` selects the pairwise score: 'Jaccard' uses jaccard(), any other
    value uses F1().
    """
    top_label = np.max(set1)
    labels = range(0, top_label + 1)

    # Positions belonging to each label, per assignment.
    groups1 = [find(set1, label) for label in labels]
    groups2 = [find(set2, label) for label in labels]

    similarity = jaccard if type == 'Jaccard' else F1

    def best_scores(sources, candidates):
        # Best score of each source group against all candidate groups.
        # The leading 0 is the floor: only strictly larger scores replace it.
        return [max([0] + [similarity(group, other) for other in candidates])
                for group in sources]

    forward = best_scores(groups1, groups2)
    backward = best_scores(groups2, groups1)
    return (sum(forward) + sum(backward)) / (2 * (top_label + 1))


def _build_split(community, feat, graph, start, stop):
    """Build (inputs, labels, adjacency) for the communities start..stop-1.

    Helper of build_dataset; see its docstring for the meaning of the three
    returned arrays.
    """
    # Unique member nodes in first-appearance order. Element 0 of every
    # community is skipped when collecting members.
    nodes = []
    seen = set()
    for j in range(start, stop):
        for raw in community[j][1:]:
            node = int(raw)
            if node not in seen:
                seen.add(node)
                nodes.append(node)

    inputs = [feat[node, :] for node in nodes]

    # One indicator row per node, one column per community of this split.
    # NOTE(review): the membership test looks at the whole community entry,
    # including element 0 that is skipped above, and compares the int node
    # id against the stored values as-is - confirm this matches the data
    # format of community.pkl.
    labels = [[1.0 if node in community[j] else 0.0
               for j in range(start, stop)]
              for node in nodes]

    # Sub-graph restricted to the selected nodes, in the same order.
    adjacency = graph[nodes, :][:, nodes].astype(np.float32)
    # Each row is divided by (its original sum + 1) after the diagonal entry
    # has been forced to 1. Rows are independent, so this is done for all
    # rows at once; the sums are taken before the diagonal is overwritten.
    scale = adjacency.sum(axis=1, keepdims=True) + 1
    np.fill_diagonal(adjacency, 1)
    adjacency /= scale

    return np.array(inputs), np.array(labels), adjacency


def build_dataset(community, feat, graph, length):
    """Split the communities into a training part and a test part.

    The first `length` communities form the training split, the remaining
    ones the test split. For each split the function gathers the distinct
    member nodes (element 0 of each community entry is not treated as a
    member) and builds:

    * inputs    - the rows ``feat[node, :]`` of the selected nodes, stacked;
    * labels    - a 0/1 matrix with one row per node and one column per
                  community of the split;
    * adjacency - the float32 sub-matrix of `graph` for the selected nodes,
                  with the diagonal set to 1 and every row divided by its
                  original row sum plus one.

    Label rows are built directly per node, so node ids are only limited by
    the sizes of `feat` and `graph`.

    Returns:
        (train_input, train_output, train_graph,
         test_input, test_output, test_graph) as numpy arrays.
    """
    train_input, train_output, train_graph = _build_split(
        community, feat, graph, 0, length)
    test_input, test_output, test_graph = _build_split(
        community, feat, graph, length, len(community))
    return train_input, \
           train_output, train_graph, \
           test_input, \
           test_output, test_graph


def model(constant_b=0.1, alpha=0.1, num_hidden=8, Epoch=10, train_number=18, path=1):
    """Train the LSTM + random-walk model and score a clustering of the test nodes.

    The data is read from 'community.pkl', 'feat.pkl' and 'Graph.pkl' in the
    current directory and split with build_dataset(): the first
    `train_number` communities are used for training, the rest for testing.
    After training, the walk-layer output of the test nodes is clustered
    with KMeans (one cluster per test community) and compared against the
    arg-max of the test labels.

    Args:
        constant_b: initial value of every entry of the softmax bias.
        alpha: currently unused; kept so existing callers keep working.
        num_hidden: LSTM output size, i.e. the embedding size.
        Epoch: number of full-batch optimisation steps.
        train_number: number of leading communities used for training.
        path: number of times the relation matrix is multiplied onto the
            pooled LSTM output (length of the random walk).

    Returns:
        (F1 score, Jaccard score) as computed by cal_error().
    """

    def load_pickle(filename):
        # Binary read-only mode is what pickle expects; the context manager
        # closes the file even if unpickling fails.
        # NOTE(review): unpickling can execute arbitrary code - only use
        # trusted data files here.
        with open(filename, 'rb') as handle:
            return pickle.load(handle)

    def body(iteration, relation, temp):
        # One walk step: propagate the embeddings along the relation matrix.
        return iteration + 1, relation, tf.matmul(relation, temp)

    def condition(iteration, relation, temp):
        return iteration < path

    # read the data from pkl files
    community = load_pickle('community.pkl')
    feat = load_pickle('feat.pkl').astype(np.float32)
    graph = load_pickle('Graph.pkl')

    train_input, train_output, train_graph, test_input, test_output, test_graph = \
        build_dataset(community, feat, graph, train_number)

    # placeholders: node features, community indicators, normalised sub-graph
    data = tf.placeholder(tf.float32, [None, feat_dim, 1])
    target = tf.placeholder(tf.float32, [None, train_number])
    relation = tf.placeholder(tf.float32, [None, None])

    # build the LSTM cell, num_hidden is the embedding size
    cell = tf.nn.rnn_cell.LSTMCell(num_hidden, state_is_tuple=True)

    # the RNN layer outputs [batch_size, max_time, cell.output_size]
    val, state = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32)

    # swap batch and time axes, then mean-pool over the time axis
    val = tf.transpose(val, [1, 0, 2])
    last = tf.reduce_mean(val, 0)
    iteration = tf.Variable(tf.constant(0))
    temp = last

    # random-walk layer: applies `relation` to the pooled output `path` times
    iteration_t, relation_t, walklayer = tf.while_loop(condition, body, [iteration, relation, temp])

    # softmax layer
    weight = tf.Variable(tf.truncated_normal([num_hidden, int(target.get_shape()[1])]))
    bias = tf.Variable(tf.constant(constant_b, shape=[target.get_shape()[1]]))
    prediction = tf.nn.softmax(tf.matmul(walklayer, weight) + bias)

    # loss: cross entropy between the community indicators and the prediction
    cross_entropy = -tf.reduce_sum(target * tf.log(prediction))
    optimizer = tf.train.AdamOptimizer()
    minimize = optimizer.minimize(cross_entropy)

    # NOTE(review): later TensorFlow 1.x releases deprecate this call in
    # favour of tf.global_variables_initializer(); kept as-is because the
    # targeted TensorFlow version is not visible from this file.
    init_op = tf.initialize_all_variables()
    sess = tf.Session()
    try:
        sess.run(init_op)

        # full-batch training
        for _ in range(Epoch):
            sess.run(minimize, {data: train_input, target: train_output, relation: train_graph})

        # embeddings of the test nodes after the walk layer
        pred = sess.run(walklayer, {data: test_input, relation: test_graph})
    finally:
        # release the session even when training or inference raises
        sess.close()

    clustering = KMeans(n_clusters=len(community) - train_number, init='k-means++')
    clustering.fit(pred)

    # reference label of each test node: index of its largest indicator
    truth = np.argmax(test_output, 1)
    return (cal_error(clustering.labels_, truth, 'F1'),
            cal_error(clustering.labels_, truth, 'Jaccard'))

# Disabled command-line driver: reads the hyper-parameters from sys.argv and
# appends them, followed by the resulting scores, to the file 'data'.
# b = float(sys.argv[1])
# alpha = float(sys.argv[2])
# num_hid = int(sys.argv[3])
# Epo = int(sys.argv[4])
# tn = int(sys.argv[5])
# f = open('data','a+')
# f.writelines('b = {:.2f}, alpha = {:.2f}, num_hidden = {:d}, Epoch = {:d}\n'.format(b, alpha, num_hid, Epo))
# (F1score, Jaccard) = model(constant_b=b, num_hidden=num_hid, Epoch=Epo, train_number=tn)

# Active run with fixed hyper-parameters: prints the F1 score, then the
# Jaccard score, each on its own line.
F1score, Jaccard = model(
    constant_b=0.1,
    num_hidden=28,
    Epoch=20,
    train_number=8,
    path=5
)
print(F1score)
print(Jaccard)
# f.writelines('F1 = {:.3f}, Jaccard = {:.3f}\n'.format(F1score, Jaccard))
# f.close()
