
import random

# Choosing the right step size is more of an art than a science.
#
# Popular options include:
#
#   (1) Using a fixed step size
#   (2) Gradually shrinking the step size over time
#   (3) At each step, choosing the step size that minimizes
#       the value of the objective function (very computationally
#       intensive; see the small sketch after the step_sizes list below)

step_sizes = [100, 10, 1, 0.1, 0.01, 0.001, 0.0001, 0.00001]
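
# A minimal sketch of option (3): from the current point, try every candidate
# step size and keep whichever one gives the smallest objective value.
# (f, df, and w below are made-up illustrations, not part of the original.)

f  = lambda w: (w - 3) ** 2                        # toy 1-D objective
df = lambda w: 2 * (w - 3)                         # its derivative
w  = 10
candidates = [w - s * df(w) for s in step_sizes]   # one gradient step per candidate step size
w  = min(candidates, key=f)                        # keep the best candidate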

# It is possible that certain step sizes will result in invalid inputs 
# for our function. 
# So we’ll need to create a “safe apply” function that returns infinity 
# (which should never be the minimum of anything) for invalid inputs:

# safe(f):
#
#    Input: function f
#    Output: safe function that calls f

def safe(f):
    """return a new function that's the same as f,
       except that it outputs infinity whenever f produces an error
    """

    def safe_f(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except:
            return float('inf') # this means "infinity" in Python

    return safe_f

# NOTE:
#
# You call safe(f) like this:
#
#     f = some_function   # f refers to a function object
#     f = safe(f)         # we pass f to safe, which builds and returns a
#                         # NEW function, safe_f, that wraps the original.
#                         # safe_f is a closure: it keeps a reference to the
#                         # f that was passed in, even after safe returns.
#
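
# A small usage sketch of safe (math.log is just an illustrative choice; it
# raises a ValueError on non-positive inputs):

import math

safe_log = safe(math.log)
print(safe_log(10))    # ~2.3026
print(safe_log(-1))    # inf, instead of raising a ValueError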

# ###################################################
# step(v, direction, step_size):
#
#      Return the point reached by starting at v and moving
#      "step_size" in the direction "direction"
#
#      v is a vector
#      direction is typically the gradient

def step(v, direction, step_size):
    """move step_size in the direction from v"""
    return [v_i + step_size * direction_i
            for v_i, direction_i in zip(v, direction)]
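
# For example (illustrative values, not from the original):
#
#   step([1, 2], [3, 4], 0.1)   # ≈ [1.3, 2.4]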



# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Stochastic gradient descent
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

def in_random_order(data):
    """
    generator that returns the elements of data in random order
    """

    indexes = [i for i, _ in enumerate(data)] # create a list of indexes
					      # indexes = [0,1,2,...]

    random.shuffle(indexes) # shuffle them
                            # indexes = [4,2,1,9,... randomized]

    for i in indexes:       # return the data in that order
        yield data[i]       # returns data[4], then data[2]... etc
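
# minimize_stochastic below uses vector_subtract and scalar_multiply, which
# are not defined in this section. A minimal sketch of the two helpers, in
# case they are not imported from elsewhere:

def vector_subtract(v, w):
    """componentwise difference of two vectors"""
    return [v_i - w_i for v_i, w_i in zip(v, w)]

def scalar_multiply(c, v):
    """multiply every element of v by the scalar c"""
    return [c * v_i for v_i in v]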


# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Stochastic gradient descent alg:
#
# What the heck is x and y ?????????????????????????????????????????

def minimize_stochastic(target_fn, gradient_fn, x, y, theta_0, alpha_0=0.01):

    data = list(zip(x, y))  # materialize, since we iterate over it repeatedly

    theta = theta_0 # initial guess
    alpha = alpha_0 # initial step size

    min_theta, min_value = None, float("inf") 	# the minimum so far

    iterations_with_no_improvement = 0

    # if we ever go 100 iterations with no improvement, stop
    while iterations_with_no_improvement < 100:

        value = sum( target_fn(x_i, y_i, theta) for x_i, y_i in data )

        if value < min_value:
            # if we've found a new minimum, remember it
            min_theta, min_value = theta, value

            # and go back to the original step size
            alpha = alpha_0

            iterations_with_no_improvement = 0
        else:
            # otherwise we're not improving, so try shrinking the step size
            alpha *= 0.9

            iterations_with_no_improvement += 1

        # and take a gradient step for each of the data points
        for x_i, y_i in in_random_order(data):
            gradient_i = gradient_fn(x_i, y_i, theta)
            theta = vector_subtract(theta, scalar_multiply(alpha, gradient_i))
    return min_theta
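
# A hedged usage sketch of minimize_stochastic: fit y ≈ alpha + beta * x by
# least squares. The dataset and the squared_error helpers below are
# illustrative, not part of the original.

def squared_error(x_i, y_i, theta):
    alpha, beta = theta
    return (y_i - (alpha + beta * x_i)) ** 2

def squared_error_gradient(x_i, y_i, theta):
    alpha, beta = theta
    error = y_i - (alpha + beta * x_i)
    return [-2 * error, -2 * error * x_i]

xs = [1, 2, 3, 4, 5]
ys = [3, 5, 7, 9, 11]              # exactly y = 1 + 2x
theta_hat = minimize_stochastic(squared_error, squared_error_gradient,
                                xs, ys, [0, 0], alpha_0=0.01)
print(theta_hat)                   # should be close to [1, 2]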

# ###################################################################
# How to use the "safe" gradient descent with optimal step size:

def sum_of_squares(v):
    """computes the sum of squared elements in v"""
    return sum(v_i ** 2 for v_i in v)

def sum_of_squares_gradient(v):
    return [2 * v_i for v_i in v]
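
# minimize_batch is not defined in this section. A minimal sketch, consistent
# with the approach described above (wrap target_fn with safe, and at every
# iteration pick the step size from step_sizes that minimizes the objective):

def minimize_batch(target_fn, gradient_fn, theta_0, tolerance=0.000001):
    """use gradient descent to find a theta that minimizes target_fn"""
    theta = theta_0                # current guess
    target_fn = safe(target_fn)    # safe version: errors become infinity
    value = target_fn(theta)       # value we're trying to minimize

    while True:
        gradient = gradient_fn(theta)
        # take one candidate step per step size (minus sign: move downhill)
        next_thetas = [step(theta, gradient, -step_size)
                       for step_size in step_sizes]
        # keep the candidate with the smallest (safe) objective value
        next_theta = min(next_thetas, key=target_fn)
        next_value = target_fn(next_theta)

        # stop when the improvement becomes negligible
        if abs(value - next_value) < tolerance:
            return theta
        theta, value = next_theta, next_value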


v = [2,6,9,8]
min_v = minimize_batch(sum_of_squares, sum_of_squares_gradient, v)

print(v, min_v)

