{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Custom Layers\n",
    "\n",
    "One factor behind deep learning's success\n",
    "is the availability of a wide range of layers\n",
    "that can be composed in creative ways\n",
    "to design architectures suitable\n",
    "for a wide variety of tasks.\n",
    "For instance, researchers have invented layers\n",
    "specifically for handling images, text,\n",
    "looping over sequential data,\n",
    "and\n",
    "performing dynamic programming.\n",
    "Sooner or later, you will encounter or invent\n",
    "a layer that does not exist yet in the deep learning framework.\n",
    "In these cases, you must build a custom layer.\n",
    "In this section, we show you how.\n",
    "\n",
    "## (**Layers without Parameters**)\n",
    "\n",
    "To start, we construct a custom layer\n",
    "that does not have any parameters of its own.\n",
    "This should look familiar if you recall our\n",
    "introduction to blocks in :numref:`sec_model_construction`.\n",
    "The following `CenteredLayer` class simply\n",
    "subtracts the mean from its input.\n",
    "To build it, we only need to inherit\n",
    "from the base layer class and implement the forward propagation function.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "origin_pos": 1,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "from mxnet import np, npx\n",
    "from mxnet.gluon import nn\n",
    "\n",
    "npx.set_np()\n",
    "\n",
    "class CenteredLayer(nn.Block):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def forward(self, X):\n",
    "        # Subtract the mean taken over all elements of X\n",
    "        return X - X.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "Let us verify that our layer works as intended by feeding some data through it.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "origin_pos": 5,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer = CenteredLayer()\n",
    "layer(np.array([1, 2, 3, 4, 5]))"
   ]
  },
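  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check by hand, the input $[1, 2, 3, 4, 5]$ has mean $3$,\n",
    "so subtracting it elementwise should reproduce the output shown above.\n",
    "This sketch assumes the mean is taken over all elements,\n",
    "as `X.mean()` does when called without arguments:\n",
    "\n",
    "$$[1, 2, 3, 4, 5] - \\frac{1 + 2 + 3 + 4 + 5}{5} = [1, 2, 3, 4, 5] - 3 = [-2, -1, 0, 1, 2].$$\n"
   ]
  },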
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "We can now [**incorporate our layer as a component\n",
    "in constructing more complex models.**]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "origin_pos": 9,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "net = nn.Sequential()\n",
    "net.add(nn.Dense(128), CenteredLayer())\n",
    "net.initialize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "As an extra sanity check, we can send random data\n",
    "through the network and check that the mean is in fact 0.\n",
    "Because we are dealing with floating point numbers,\n",
    "we may still see a very small nonzero number\n",
    "due to rounding error.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "origin_pos": 13,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(3.783498e-10)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y = net(np.random.uniform(size=(4, 8)))\n",
    "Y.mean()"
   ]
  },
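  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the mean should vanish in exact arithmetic,\n",
    "write the $n$ entries produced by the `Dense` layer as $x_1, \\ldots, x_n$\n",
    "and their mean as $\\mu$ (symbols introduced here for illustration only).\n",
    "Centering then gives\n",
    "\n",
    "$$\\frac{1}{n} \\sum_{i=1}^{n} (x_i - \\mu) = \\frac{1}{n} \\sum_{i=1}^{n} x_i - \\mu = \\mu - \\mu = 0.$$\n",
    "\n",
    "The tiny value printed above is therefore an artifact of finite-precision arithmetic,\n",
    "not a flaw in the layer.\n"
   ]
  },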
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "## [**Layers with Parameters**]\n",
    "\n",
    "Now that we know how to define simple layers,\n",
    "let us move on to defining layers with parameters\n",
    "that can be adjusted through training.\n",
    "We can use built-in functions to create parameters, which\n",
    "provide some basic housekeeping functionality.\n",
    "In particular, they govern access, initialization,\n",
    "sharing, saving, and loading model parameters.\n",
    "This way, among other benefits, we will not need to write\n",
    "custom serialization routines for every custom layer.\n",
    "\n",
    "Now let us implement our own version of the fully-connected layer.\n",
    "Recall that this layer requires two parameters,\n",
    "one to represent the weight and the other for the bias.\n",
    "In this implementation, we bake in the ReLU activation as a default.\n",
    "This layer requires two input arguments: `in_units` and `units`, which\n",
    "denote the number of inputs and outputs, respectively.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "origin_pos": 17,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "class MyDense(nn.Block):\n",
    "    def __init__(self, units, in_units, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        # Register both parameters in the block's parameter dictionary\n",
    "        self.weight = self.params.get('weight', shape=(in_units, units))\n",
    "        self.bias = self.params.get('bias', shape=(units,))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Fetch each parameter on the same device as the input\n",
    "        linear = np.dot(x, self.weight.data(ctx=x.ctx)) + self.bias.data(\n",
    "            ctx=x.ctx)\n",
    "        return npx.relu(linear)"
   ]
  },
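  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols, the forward pass of `MyDense` can be sketched as follows.\n",
    "The names $\\mathbf{X}$, $\\mathbf{W}$, $\\mathbf{b}$, and $\\mathbf{Y}$ are notation assumed here\n",
    "rather than identifiers from the code; $n$ is the batch size,\n",
    "$d$ stands for `in_units`, and $q$ stands for `units`:\n",
    "\n",
    "$$\\mathbf{Y} = \\max(0, \\mathbf{X} \\mathbf{W} + \\mathbf{b}), \\qquad \\mathbf{X} \\in \\mathbb{R}^{n \\times d},\\ \\mathbf{W} \\in \\mathbb{R}^{d \\times q},\\ \\mathbf{b} \\in \\mathbb{R}^{q},\\ \\mathbf{Y} \\in \\mathbb{R}^{n \\times q}.$$\n",
    "\n",
    "The bias is broadcast across the $n$ rows, and the maximum is applied elementwise.\n"
   ]
  },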
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 20,
    "tab": [
     "mxnet"
    ]
   },
   "source": [
    "Next, we instantiate the `MyDense` class\n",
    "and access its model parameters.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "origin_pos": 22,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "mydense0_ (\n",
       "  Parameter mydense0_weight (shape=(5, 3), dtype=<class 'numpy.float32'>)\n",
       "  Parameter mydense0_bias (shape=(3,), dtype=<class 'numpy.float32'>)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dense = MyDense(units=3, in_units=5)\n",
    "dense.params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 25
   },
   "source": [
    "We can [**directly carry out forward propagation calculations using custom layers.**]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "origin_pos": 26,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.        , 0.01633355, 0.        ],\n",
       "       [0.        , 0.01581812, 0.        ]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dense.initialize()\n",
    "dense(np.random.uniform(size=(2, 5)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 29
   },
   "source": [
    "We can also (**construct models using custom layers.**)\n",
    "Once defined, a custom layer can be used just like the built-in fully-connected layer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "origin_pos": 30,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.06508517],\n",
       "       [0.0615553 ]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = nn.Sequential()\n",
    "net.add(MyDense(8, in_units=64),\n",
    "        MyDense(1, in_units=8))\n",
    "net.initialize()\n",
    "net(np.random.uniform(size=(2, 64)))"
   ]
  },
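  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tracing shapes through this two-layer network is a useful sanity check.\n",
    "With a batch of $2$ examples and $64$ features,\n",
    "and writing $\\mathbf{W}_1$, $\\mathbf{W}_2$ for the two weight matrices\n",
    "and $\\mathbf{H}$ for the hidden output (notation assumed here), the shapes compose as\n",
    "\n",
    "$$\\mathbf{X} \\in \\mathbb{R}^{2 \\times 64} \\;\\xrightarrow{\\mathbf{W}_1 \\in \\mathbb{R}^{64 \\times 8}}\\; \\mathbf{H} \\in \\mathbb{R}^{2 \\times 8} \\;\\xrightarrow{\\mathbf{W}_2 \\in \\mathbb{R}^{8 \\times 1}}\\; \\mathbf{Y} \\in \\mathbb{R}^{2 \\times 1},$$\n",
    "\n",
    "which matches the $2 \\times 1$ output above.\n",
    "Because `MyDense` applies ReLU in every layer, including the last one,\n",
    "the entries of $\\mathbf{Y}$ are always nonnegative.\n"
   ]
  },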
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 33
   },
   "source": [
    "## Summary\n",
    "\n",
    "* We can design custom layers via the basic layer class. This allows us to define flexible new layers that behave differently from any existing layers in the library.\n",
    "* Once defined, custom layers can be invoked in arbitrary contexts and architectures.\n",
    "* Layers can have local parameters, which can be created through built-in functions.\n",
    "\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Design a layer that takes an input and computes a tensor reduction,\n",
    "   i.e., it returns $y_k = \\sum_{i, j} W_{ijk} x_i x_j$.\n",
    "1. Design a layer that returns the leading half of the Fourier coefficients of the data.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 34,
    "tab": [
     "mxnet"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/58)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}