{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7c316db5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mxnet import autograd, np, npx\n",
    "\n",
    "npx.set_np()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "450bf7cf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 1., 2., 3.])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Example 1: differentiate y = 2 * x^T x with respect to the vector x\n",
    "# Input vector: x = [0, 1, 2, 3]\n",
    "x = np.arange(4.0)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "7f0a2fd3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(14.)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The dot product x^T x = x0^2 + x1^2 + x2^2 + x3^2 = 0 + 1 + 4 + 9 = 14.\n",
    "# Our target function is twice that:\n",
    "#   y = 2 * x^T x = 2(x0^2 + x1^2 + x2^2 + x3^2)      (y is a SCALAR)\n",
    "# so its partial derivatives are\n",
    "#   dy/dx0 = 4 x0,  dy/dx1 = 4 x1,  dy/dx2 = 4 x2,  dy/dx3 = 4 x3\n",
    "np.dot(x, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f1143f75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 0., 0., 0.])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# attach_grad() allocates a buffer that will hold the gradient of x.\n",
    "# Once a gradient with respect to x has been computed, it is available\n",
    "# as x.grad; until then the buffer contains zeros.\n",
    "x.attach_grad()\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "742c6b22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(28.)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Operations executed inside autograd.record() are recorded, building the\n",
    "# computational graph that backward() will later traverse.\n",
    "with autograd.record():\n",
    "    y = 2 * np.dot(x, x)   # y = 2 * x^T x = 2 * 14 = 28\n",
    "\n",
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "485ae83a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.,  4.,  8., 12.])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Backpropagate from the scalar y: this fills x.grad with dy/dx,\n",
    "# one partial derivative per component of x.\n",
    "y.backward()\n",
    "\n",
    "# Display the gradient.\n",
    "# y = 2 * (x0^2 + x1^2 + x2^2 + x3^2)  =>  dy/dxi = 4 xi\n",
    "# With x = [0, 1, 2, 3]:  grad(x) = [0, 4, 8, 12]\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "4ceae63d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1., 1., 1., 1.])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Example 2: y = x0 + x1 + x2 + x3           (y is a SCALAR)\n",
    "# Every partial derivative is 1: dy/dx0 = dy/dx1 = dy/dx2 = dy/dx3 = 1\n",
    "# x = [0, 1, 2, 3]\n",
    "\n",
    "with autograd.record():\n",
    "    y = x.sum()\n",
    "y.backward()\n",
    "\n",
    "# x.grad is overwritten (not accumulated) by the newly computed gradient\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1b39ee6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 2., 4., 6.])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# What does backward() do when y = f(x) is vector-valued?\n",
    "# Mathematically, the derivative of a vector y with respect to a vector x\n",
    "# is a matrix (the Jacobian); for higher-order and higher-dimensional y and x\n",
    "# it can be a higher-order tensor.\n",
    "#\n",
    "# In practice, when we call backward() on a vector we usually hold a vector\n",
    "# of per-example losses, and what we want is not the full Jacobian but the\n",
    "# sum of the per-example gradients.\n",
    "# Therefore, if y = f(x) is a vector, MXNet treats it as f(x).sum().\n",
    "#\n",
    "# In other words: invoking backward() on a vector-valued y (a function of x)\n",
    "# first reduces y to a scalar by summing its elements, and then computes the\n",
    "# gradient of that scalar with respect to x.\n",
    "\n",
    "with autograd.record():\n",
    "    y = x * x      # y = (x0^2, x1^2, x2^2, x3^2), treated as x0^2 + x1^2 + x2^2 + x3^2\n",
    "\n",
    "# d(sum(y))/dx = (2*x0, 2*x1, 2*x2, 2*x3)\n",
    "y.backward()\n",
    "x.grad           # the gradient of sum(x * x), i.e. 2 * x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "47e3a2e0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 1., 4., 9.])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Detaching computation\n",
    "# Suppose:  y = f(x)\n",
    "#           z = g(x, y)\n",
    "# Normally dz/dx follows the chain rule through y as well.\n",
    "# Sometimes we want the gradient of z with respect to x while treating\n",
    "# y as a constant. detach() returns a new array u with the same values\n",
    "# as y but without a link back to x in the computational graph:\n",
    "\n",
    "with autograd.record():\n",
    "    y = x * x         # x = [0, 1, 2, 3], y = x * x = [0, 1, 4, 9]\n",
    "    u = y.detach()    # u has y's values but is treated as a constant\n",
    "    z = u * x         # so dz/dx = u (not 3 * x^2)\n",
    "z.backward()\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9f4f5e80",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 2., 4., 6.])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Only u was detached: detach() returned a new array and left y untouched,\n",
    "# so the computation of y from x is still recorded in the graph.\n",
    "# Therefore we can still backpropagate through y (dy/dx = 2x):\n",
    "\n",
    "y.backward()\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "08285092",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.,  3., 12., 27.])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Without detaching:\n",
    "\n",
    "with autograd.record():\n",
    "    y = x * x         # x = [0, 1, 2, 3], y = x * x = [0, 1, 4, 9]\n",
    "    z = y * x         # z = x^3 = (x0^3, x1^3, x2^3, x3^3), summed by backward()\n",
    "z.backward()          # dz/dx = (3*x0^2, 3*x1^2, 3*x2^2, 3*x3^2)\n",
    "x.grad                #       = (3*0, 3*1, 3*4, 3*9) = (0, 3, 12, 27)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "6398ab70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# *** SKIP THIS *** Useless diversion\n",
    "\n",
    "\n",
    "# Studying a weird function\n",
    "# Here is a weird function whose control flow depends on its input:\n",
    "\n",
    "def f(a):\n",
    "    \"\"\"Double `a` until its norm reaches 1000, then scale the result by 1 or 100.\n",
    "\n",
    "    The result is c = k * a, where the scalar k is a power of two (times 100\n",
    "    when the doubled value sums to <= 0). k is piecewise constant in a, so f\n",
    "    is piecewise linear and, for a scalar a != 0, f(a) / a equals f'(a).\n",
    "    \"\"\"\n",
    "    b = a * 2\n",
    "\n",
    "    # Doubling an all-zero input never increases its norm, so without the\n",
    "    # `0 <` check this loop would never terminate for a == 0.\n",
    "    while 0 < np.linalg.norm(b) < 1000:\n",
    "        b = b * 2\n",
    "\n",
    "    if b.sum() > 0:\n",
    "        c = b\n",
    "    else:\n",
    "        c = 100 * b\n",
    "    return c\n",
    "\n",
    "# This function is PIECEWISE linear !!!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "6bd4c7ca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1029.12])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use a separate name so the tutorial's `x` (used by the autograd cells) is not clobbered\n",
    "x_demo = np.array([2.01])   # 2 * 2.01 = 4.02, doubled 8 more times -> 1029.12\n",
    "f(x_demo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "4f771964",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(1585.1598)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random scalar input; no seed is set, so this value changes on every run\n",
    "a = np.random.normal()\n",
    "f(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "3c099130",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(2048.)"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compute f'(a): the gradient at a\n",
    "\n",
    "a.attach_grad()\n",
    "with autograd.record():\n",
    "    d = f(a)\n",
    "d.backward()\n",
    "\n",
    "# f(a) = k * a with k constant around a (f is piecewise linear), so\n",
    "# f'(a) = k = d / a. It is d / a, not d itself, that equals the gradient\n",
    "# (valid only for a != 0).\n",
    "d / a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "bbac1718",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(2048.)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gradient stored by d.backward(); it should match d / a\n",
    "a.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8647c84c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
