[functorch] Add complexities and references for NTK implementations. (pytorch/functorch#907)

* Add complexities and references for NTK implementations.

* Fix result names; rename "outer product" -> "matrix product".

* Fix names
Roman Novak
2022-06-24 12:06:21 -07:00
committed by Jon Janzen
parent 1b5fbf872f
commit e642a34627


@ -7,7 +7,7 @@
"source": [
"# Neural Tangent Kernels\n",
"\n",
"The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents), demonstrates how to easily compute this quantity using functorch."
"The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), demonstrates how to easily compute this quantity using functorch."
]
},
{
@ -79,7 +79,7 @@
"\n",
"functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n",
"\n",
"We'll use functorch's make_functional to accomplish the first step. If your module has buffers, you'll want to use make_functional_with_buffers instead."
"We'll use functorch's `make_functional` to accomplish the first step. If your module has buffers, you'll want to use `make_functional_with_buffers` instead."
]
},
{
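
For readers following the diff without the surrounding notebook, a minimal sketch of the setup this cell describes might look like the following (the model, sizes, and the `x_train`/`x_test` tensors are illustrative placeholders, not the notebook's actual definitions):

```python
import torch
import torch.nn as nn
from functorch import make_functional

# A small stand-in model; the notebook's own architecture may differ.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))

# make_functional splits the module into a pure function `fnet` and its `params`.
fnet, params = make_functional(model)

def fnet_single(params, x):
    # The NTK transforms expect a function of (params, single input) -> single output,
    # so we add and remove a batch dimension around the functional call.
    return fnet(params, x.unsqueeze(0)).squeeze(0)

x_train = torch.randn(20, 10)  # placeholder batches
x_test = torch.randn(5, 10)
```
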
@ -117,13 +117,15 @@
"id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248",
"metadata": {},
"source": [
"## Compute the NTK: method 1\n",
"## Compute the NTK: method 1 (Jacobian contraction)\n",
"\n",
"We're ready to compute the empirical NTK. The empirical NTK for two data points `x1` and `x2` is defined as an inner product between the Jacobian of the model evaluated at `x1` and the Jacobian of the model evaluated at `x2`:\n",
"We're ready to compute the empirical NTK. The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product between the Jacobian of the model evaluated at $x_1$ and the Jacobian of the model evaluated at $x_2$:\n",
"\n",
"$$J_{net}(x1) \\cdot J_{net}^T(x2)$$\n",
"$$J_{net}(x_1) J_{net}^T(x_2)$$\n",
"\n",
"In the batched case where `x1` is a batch of data points and `x2` is a batch of data points, then we want the inner product between the Jacobians of all combinations of data points from `x1` and `x2`. Here's how to compute the NTK in the batched case:"
"In the batched case where $x_1$ is a batch of data points and $x_2$ is a batch of data points, then we want the matrix product between the Jacobians of all combinations of data points from $x_1$ and $x_2$.\n",
"\n",
"The first method consists of doing just that - computing the two Jacobians, and contracting them. Here's how to compute the NTK in the batched case:"
]
},
{
@ -133,7 +135,7 @@
"metadata": {},
"outputs": [],
"source": [
"def empirical_ntk(fnet_single, params, x1, x2):\n",
"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n",
" # Compute J(x1)\n",
" jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
" jac1 = [j.flatten(2) for j in jac1]\n",
@ -163,7 +165,7 @@
}
],
"source": [
"result = empirical_ntk(fnet_single, params, x_train, x_test)\n",
"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n",
"print(result.shape)"
]
},
@ -182,7 +184,7 @@
"metadata": {},
"outputs": [],
"source": [
"def empirical_ntk(fnet_single, params, x1, x2, compute='full'):\n",
"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n",
" # Compute J(x1)\n",
" jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
" jac1 = [j.flatten(2) for j in jac1]\n",
@ -222,32 +224,39 @@
}
],
"source": [
"result = empirical_ntk(fnet_single, params, x_train, x_test, 'trace')\n",
"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n",
"print(result.shape)"
]
},
{
"cell_type": "markdown",
"id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa",
"metadata": {},
"source": [
"The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) $ + N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
]
},
{
"cell_type": "markdown",
"id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa",
"metadata": {},
"source": [
"## Compute the NTK: method 2\n",
"## Compute the NTK: method 2 (NTK-vector products)\n",
"\n",
"The next method we will discuss is a way to compute the NTK implicitly. This has different tradeoffs compared to the previous one and it is generally more efficient when your model has large parameters; we recommend trying out both methods to see which works better.\n",
"The next method we will discuss is a way to compute the NTK using NTK-vector products.\n",
"\n",
"Here's our definition of NTK:\n",
"This method reformulates NTK as a stack of NTK-vector products applied to columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n",
"\n",
"$$J_{net}(x1) \\cdot J_{net}^T(x2)$$\n",
"$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n",
"where $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix $I_O$.\n",
"\n",
"The implicit computation reformulates the problem by adding an identity matrix and rearranging the matrix-multiplies:\n",
"- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n",
"- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n",
"- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n",
"\n",
"$$= J_{net}(x1) \\cdot J_{net}^T(x2) \\cdot I$$\n",
"$$= (J_{net}(x1) \\cdot (J_{net}^T(x2) \\cdot I))$$\n",
"This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n",
"\n",
"- Let $vjps = (J_{net}^T(x2) \\cdot I)$. We can use a vector-Jacobian product to compute this.\n",
"- Now, consider $J_{net}(x1) \\cdot vjps$. This is a Jacobian-vector product!\n",
"\n",
"This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK. Let's code that up:"
"Let's code that up:"
]
},
{
@ -257,7 +266,7 @@
"metadata": {},
"outputs": [],
"source": [
"def empirical_ntk_implicit(func, params, x1, x2, compute='full'):\n",
"def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n",
" def get_ntk(x1, x2):\n",
" def func_x1(params):\n",
" return func(params, x1)\n",
@ -280,7 +289,7 @@
" return vmap(get_ntk_slice)(basis)\n",
" \n",
" # get_ntk(x1, x2) computes the NTK for a single data point x1, x2\n",
" # Since the x1, x2 inputs to empirical_ntk_implicit are batched,\n",
" # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n",
" # we actually wish to compute the NTK between every pair of data points\n",
" # between {x1} and {x2}. That's what the vmaps here do.\n",
" result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n",
@ -300,9 +309,9 @@
"metadata": {},
"outputs": [],
"source": [
"result_implicit = empirical_ntk_implicit(fnet_single, params, x_test, x_train)\n",
"result_explicit = empirical_ntk(fnet_single, params, x_test, x_train)\n",
"assert torch.allclose(result_implicit, result_explicit, atol=1e-5)"
"result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n",
"result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n",
"assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)"
]
},
{
@ -310,7 +319,9 @@
"id": "84253466-971d-4475-999c-fe3de6bd25b5",
"metadata": {},
"source": [
"Our code for `empirical_ntk_implicit` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch."
"Our code for `empirical_ntk_ntk_vps` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n",
"\n",
"The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1, Jacobian contraction ($N^2 O$ instead of $N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$ term, where $P$ is the total number of model's parameters). Therefore, this method is preferable when $O P$ is large relative to $[FP]$, such as fully-connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
]
}
],
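
To make the trade-off in the two complexity paragraphs concrete, here is a back-of-the-envelope sketch with made-up sizes (the numbers, and treating $[FP]$ as roughly proportional to $P$, are illustrative assumptions only):

```python
N = 100        # batch size of x1 and x2
O = 10         # model output size
P = 1_000_000  # total number of parameters
FP = P         # crude proxy for the cost of one forward pass

jacobian_contraction = N * O * FP + N**2 * O**2 * P   # method 1
ntk_vector_products = N**2 * O * FP                   # method 2

print(f"method 1 ~ {jacobian_contraction:.2e} ops")
print(f"method 2 ~ {ntk_vector_products:.2e} ops")
# With these numbers the N^2 O^2 P contraction term dominates method 1,
# so the NTK-vector-products approach comes out ahead, consistent with
# O * P being large relative to [FP] here.
```
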