
from typing import List
import math
import pprint
import random
import matplotlib.pyplot as plt

from scratch.probability import inverse_normal_cdf


def random_normal() -> float:
    """Draw one sample from a standard normal distribution.

    A uniform value from ``random.random()`` is passed through
    ``inverse_normal_cdf`` (called with its default parameters).
    """
    uniform_draw = random.random()
    return inverse_normal_cdf(uniform_draw)


#===================================================

# NOTE(review): presumably importing scratch.probability leaves a figure open
# as a side effect; discard it before building the plot below -- confirm.
plt.close()		# Close the plot created in scratch.probability

# Just some random data to show off correlation scatterplots.
# num_points is the number of rows generated further down (one call to
# random_row() per point).
num_points = 100

####################################################
# random_row(): return a list of 4 random numbers
####################################################
def random_row() -> List[float]:
    """Build one row of four related random values.

    The entries are, in order:
      0. a draw from random_normal()
      1. -5 times entry 0 plus a fresh draw
      2. entry 0 plus entry 1 plus 5 times a fresh draw
      3. 6 when entry 2 is greater than -2, otherwise 0
    """
    # The three draws must happen in this order to keep seeded output stable.
    first = random_normal()
    second = -5 * first + random_normal()
    third = first + second + 5 * random_normal()

    if third > -2:
        fourth = 6
    else:
        fourth = 0

    return [first, second, third, fourth]

# Fixed seed so every run generates the same rows.
random.seed(0)

####################################################
# Rows are generated first (4 values each); the plots
# need the data by column, so it is transposed below.
####################################################
corr_rows = [random_row() for _ in range(num_points)]
pprint.pprint(corr_rows)

# Transpose: one list per column instead of one list per row.
corr_data = list(map(list, zip(*corr_rows)))

# corr_data holds one vector per series; lay out a square grid of axes so
# that every (row series, column series) pair gets its own panel.
num_vectors = len(corr_data)
fig, ax = plt.subplots(num_vectors, num_vectors)

for i in range(num_vectors):
    for j in range(num_vectors):
        panel = ax[i][j]

        if i == j:
            # Diagonal panels carry no data, only the series name.
            panel.annotate(f"series {i}", (0.5, 0.5),
                           xycoords='axes fraction',
                           ha="center", va="center")
        else:
            # Off-diagonal: series j on the x-axis against series i on the y-axis.
            panel.scatter(corr_data[j], corr_data[i])

        # Keep the x-axis only on the bottom row and the y-axis only on
        # the left column.
        if i != num_vectors - 1:
            panel.xaxis.set_visible(False)
        if j != 0:
            panel.yaxis.set_visible(False)

# The bottom-right and top-left panels contain only text, so their visible
# axis limits are wrong; copy the limits from a neighbouring data panel
# (same column for x, same row for y).
ax[-1][-1].set_xlim(ax[0][-1].get_xlim())
ax[0][0].set_ylim(ax[0][1].get_ylim())

plt.show()

