
import random
from typing import TypeVar, List, Tuple
X = TypeVar('X')  # generic type to represent a data point


# split_data(data, prob):
#
#    Split "data" into 2 sets, [1] and [2].
#    Put a fraction prob of "data" in set [1] and 1 - prob in set [2].
#

def split_data(data: List[X], prob: float) -> Tuple[List[X], List[X]]:
    """Randomly split data into two lists of fractions [prob, 1 - prob].

    The input is never modified: a shuffled copy is cut at
    ``int(len(data) * prob)``, so the first list holds that many items
    (rounded down) and the second list holds the rest.

    Args:
        data: The items to split.
        prob: Fraction of the items to place in the first list; must be
            between 0 and 1 inclusive.

    Returns:
        A ``(first, second)`` tuple of new lists that together contain
        every item of ``data`` exactly once, in shuffled order.

    Raises:
        ValueError: If ``prob`` is not between 0 and 1.
    """
    # A negative prob would give a negative cut index, which slices from the
    # end of the list and silently returns the wrong proportions.
    if not 0 <= prob <= 1:
        raise ValueError("prob must be between 0 and 1, got {!r}".format(prob))

    shuffled = list(data)             # Real copy, also for tuples and the like,
    random.shuffle(shuffled)          # because shuffle modifies its input.
    cut = int(len(shuffled) * prob)   # Number of items in the first part
    return shuffled[:cut], shuffled[cut:]

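# Alternative implementation, kept for reference: each row is assigned
# independently, so the two parts only have sizes prob / 1 - prob on average.
#
# def split_data(data, prob):
#     """split data into fractions [prob, 1 - prob]"""
#     results = [], []
#     for row in data:
#         results[0 if random.random() < prob else 1].append(row)
# #               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# #               This is an index(!!!): 0 with probability "prob", else 1
#     return results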

# Demo: split the integers 0 ... 99, asking for 10% of them in the first part.
data = list(range(100))
print(data, "\n", len(data), "\n")

o1, o2 = split_data(data, prob=0.1)
print(o1, "\n", len(o1), "\n")
print(o2, "\n", len(o2), "\n")




# train_test_split(xs, ys, test_pct):
#
#    xs = input matrix of observations
#    ys = output vector of (actual) outcomes
#    test_pct = fraction of the data (0..1) held out for the test set;
#               the remaining 1 - test_pct is used to train
#
# Output:
#     x_train, x_test, y_train, y_test
#     Matrix   Matrix  vector   vector

Y = TypeVar('Y')  # generic type to represent output variables

def train_test_split(xs: List[X],
                     ys: List[Y],
                     test_pct: float) -> Tuple[List[X], List[X], List[Y], List[Y]]:
    """Split paired inputs and outputs into training and test sets.

    The positions of ``xs`` are shuffled and split once, and the same
    positions are then used to pick from both ``xs`` and ``ys``, so every
    input stays paired with its output.

    Args:
        xs: Input observations.
        ys: Outcomes; ``ys[i]`` belongs to ``xs[i]``.
        test_pct: Fraction of the data (between 0 and 1 inclusive) to put
            in the test set; the rest goes to the training set.

    Returns:
        A tuple ``(x_train, x_test, y_train, y_test)``.

    Raises:
        ValueError: If ``xs`` and ``ys`` differ in length, or if
            ``test_pct`` is not between 0 and 1.
    """
    # Without this check a longer ys would be silently truncated and a
    # shorter one would fail later with an unrelated IndexError.
    if len(xs) != len(ys):
        raise ValueError(
            "xs and ys must have the same length, got {} and {}".format(
                len(xs), len(ys)))
    if not 0 <= test_pct <= 1:
        raise ValueError(
            "test_pct must be between 0 and 1, got {!r}".format(test_pct))

    # Generate the indices [0, 1, 2, ..., len(xs) - 1]
    idxs = list(range(len(xs)))

    # Split into 2 sets of randomized indexes: e.g.: [1,6,2,...] [4,9,...]
    # The first set gets the training share, i.e. 1 - test_pct.
    train_idxs, test_idxs = split_data(idxs, 1 - test_pct)

    return ([xs[i] for i in train_idxs],  # x_train
            [xs[i] for i in test_idxs],   # x_test
            [ys[i] for i in train_idxs],  # y_train
            [ys[i] for i in test_idxs])   # y_test


# Demo: pair every x in 0 ... 99 with y = 2x, then hold out 10% for testing.
xs = list(range(100))             # xs are 0 ... 99
ys = [2 * value for value in xs]  # each y_i is twice x_i
print(xs, "\n", len(xs), "\n")
print(ys, "\n", len(ys), "\n")

x_train, x_test, y_train, y_test = train_test_split(xs, ys, test_pct=0.1)

print(x_train, "\n", len(x_train), "\n")
print(x_test, "\n", len(x_test), "\n")
print(y_train, "\n", len(y_train), "\n")
print(y_test, "\n", len(y_test), "\n")


