mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-11 22:34:53 +08:00
[functorch] Add complexities and references for NTK implementations. (pytorch/functorch#907)
* Add complexities and references for NTK implementations.
* Fix result names; rename "outer product" -> "matrix product".
* Fix names
@@ -7,7 +7,7 @@
 "source": [
 "# Neural Tangent Kernels\n",
 "\n",
-"The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents), demonstrates how to easily compute this quantity using functorch."
+"The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), demonstrates how to easily compute this quantity using functorch."
 ]
 },
 {
@@ -79,7 +79,7 @@
 "\n",
 "functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n",
 "\n",
-"We'll use functorch's make_functional to accomplish the first step. If your module has buffers, you'll want to use make_functional_with_buffers instead."
+"We'll use functorch's `make_functional` to accomplish the first step. If your module has buffers, you'll want to use `make_functional_with_buffers` instead."
 ]
 },
 {
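As an aside for readers skimming this diff, a minimal setup sketch makes the later cells concrete. The architecture, feature sizes, and the `x_train`/`x_test` tensors below are illustrative assumptions; only `make_functional` and the single-example wrapper `fnet_single` follow what the tutorial describes.

```python
import torch
import torch.nn as nn
from functorch import make_functional

# Illustrative model and sizes -- the notebook's actual architecture may differ.
net = nn.Sequential(nn.Linear(5, 32), nn.ReLU(), nn.Linear(32, 3))
fnet, params = make_functional(net)

def fnet_single(params, x):
    # Wrap the functional model so it maps (params, one input) -> one output,
    # which is the contract the NTK routines below expect.
    return fnet(params, x.unsqueeze(0)).squeeze(0)

# Illustrative data; the notebook defines its own x_train / x_test.
x_train = torch.randn(20, 5)
x_test = torch.randn(5, 5)
```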
@@ -117,13 +117,15 @@
 "id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248",
 "metadata": {},
 "source": [
-"## Compute the NTK: method 1\n",
+"## Compute the NTK: method 1 (Jacobian contraction)\n",
 "\n",
-"We're ready to compute the empirical NTK. The empirical NTK for two data points `x1` and `x2` is defined as an inner product between the Jacobian of the model evaluated at `x1` and the Jacobian of the model evaluated at `x2`:\n",
+"We're ready to compute the empirical NTK. The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product between the Jacobian of the model evaluated at $x_1$ and the Jacobian of the model evaluated at $x_2$:\n",
 "\n",
-"$$J_{net}(x1) \\cdot J_{net}^T(x2)$$\n",
+"$$J_{net}(x_1) J_{net}^T(x_2)$$\n",
 "\n",
-"In the batched case where `x1` is a batch of data points and `x2` is a batch of data points, then we want the inner product between the Jacobians of all combinations of data points from `x1` and `x2`. Here's how to compute the NTK in the batched case:"
+"In the batched case where $x_1$ is a batch of data points and $x_2$ is a batch of data points, then we want the matrix product between the Jacobians of all combinations of data points from $x_1$ and $x_2$.\n",
+"\n",
+"The first method consists of doing just that - computing the two Jacobians, and contracting them. Here's how to compute the NTK in the batched case:"
 ]
 },
 {
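Only the first lines of the contraction cell appear in the hunks below, so here is a hedged reconstruction of the whole Jacobian-contraction approach described above; the einsum expressions and the `compute` modes ('full', 'trace', 'diagonal') are plausible details inferred from the visible fragments, not a verbatim copy of the notebook.

```python
import torch
from functorch import vmap, jacrev

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
    # Compute J(x1): per-parameter Jacobians of shape [N1, O, *param_shape],
    # then flatten the parameter dimensions -> [N1, O, p_i].
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2) the same way -> [N2, O, p_i].
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2]

    # Contract over the flattened parameter dimension and sum over parameter tensors.
    if compute == 'full':        # [N1, N2, O, O]
        einsum_expr = 'Naf,Mbf->NMab'
    elif compute == 'trace':     # [N1, N2]
        einsum_expr = 'Naf,Maf->NM'
    elif compute == 'diagonal':  # [N1, N2, O]
        einsum_expr = 'Naf,Maf->NMa'
    else:
        raise ValueError(compute)
    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    return result.sum(0)
```

With `compute='full'` the result has shape `[N1, N2, O, O]`, one `O x O` block per pair of inputs; `'trace'` collapses each block to its trace, which matches the `'trace'` call shown further down.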
@@ -133,7 +135,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def empirical_ntk(fnet_single, params, x1, x2):\n",
+"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n",
 " # Compute J(x1)\n",
 " jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
 " jac1 = [j.flatten(2) for j in jac1]\n",
@@ -163,7 +165,7 @@
 }
 ],
 "source": [
-"result = empirical_ntk(fnet_single, params, x_train, x_test)\n",
+"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n",
 "print(result.shape)"
 ]
 },
@@ -182,7 +184,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def empirical_ntk(fnet_single, params, x1, x2, compute='full'):\n",
+"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n",
 " # Compute J(x1)\n",
 " jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
 " jac1 = [j.flatten(2) for j in jac1]\n",
@@ -222,32 +224,39 @@
 }
 ],
 "source": [
-"result = empirical_ntk(fnet_single, params, x_train, x_test, 'trace')\n",
+"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n",
 "print(result.shape)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa",
+"metadata": {},
+"source": [
+"The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) $ + N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
+]
+},
 {
 "cell_type": "markdown",
 "id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa",
 "metadata": {},
 "source": [
-"## Compute the NTK: method 2\n",
+"## Compute the NTK: method 2 (NTK-vector products)\n",
 "\n",
-"The next method we will discuss is a way to compute the NTK implicitly. This has different tradeoffs compared to the previous one and it is generally more efficient when your model has large parameters; we recommend trying out both methods to see which works better.\n",
+"The next method we will discuss is a way to compute the NTK using NTK-vector products.\n",
 "\n",
-"Here's our definition of NTK:\n",
+"This method reformulates NTK as a stack of NTK-vector products applied to columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n",
 "\n",
-"$$J_{net}(x1) \\cdot J_{net}^T(x2)$$\n",
+"$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n",
+"where $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix $I_O$.\n",
 "\n",
-"The implicit computation reformulates the problem by adding an identity matrix and rearranging the matrix-multiplies:\n",
+"- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n",
+"- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n",
+"- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n",
 "\n",
-"$$= J_{net}(x1) \\cdot J_{net}^T(x2) \\cdot I$$\n",
-"$$= (J_{net}(x1) \\cdot (J_{net}^T(x2) \\cdot I))$$\n",
+"This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n",
 "\n",
-"- Let $vjps = (J_{net}^T(x2) \\cdot I)$. We can use a vector-Jacobian product to compute this.\n",
-"- Now, consider $J_{net}(x1) \\cdot vjps$. This is a Jacobian-vector product!\n",
-"\n",
-"This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK. Let's code that up:"
+"Let's code that up:"
 ]
 },
 {
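Before the renamed code cells below, here is a hedged sketch (hence the `_sketch` suffix) of how the NTK-vector-products method can be assembled from functorch's `vjp`, `jvp`, and `vmap`, following the three steps just listed. Only the signature and a few lines of the notebook's actual cell are visible in this diff, so the rest of the body, in particular the `compute` handling and the final reordering of the `O x O` block, is an assumption of this sketch.

```python
import torch
from functorch import vmap, vjp, jvp

def empirical_ntk_ntk_vps_sketch(func, params, x1, x2, compute='full'):
    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        # Reverse-mode: vjp_fn(e_o) computes J(x2)^T e_o.
        output, vjp_fn = vjp(func_x2, params)

        def get_ntk_slice(vec):
            # vec is one column e_o of the identity matrix.
            vjps = vjp_fn(vec)
            # Forward-mode: J(x1) [J(x2)^T e_o].
            _, jvps = jvp(func_x1, (params,), vjps)
            return jvps

        # Sweep over all columns of the identity matrix with vmap.
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device)
        slices = vmap(get_ntk_slice)(basis)
        # Stacking the slices as rows yields the transpose of the O x O block;
        # reorder so the layout matches the Jacobian-contraction method.
        # (This orientation detail is a choice of this sketch.)
        return slices.transpose(0, 1)

    # x1 and x2 are batched, so compute the NTK for every pair of data points.
    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

    if compute == 'full':        # [N1, N2, O, O]
        return result
    if compute == 'trace':       # [N1, N2]
        return torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':    # [N1, N2, O]
        return torch.einsum('NMKK->NMK', result)
    raise ValueError(compute)
```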
@@ -257,7 +266,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def empirical_ntk_implicit(func, params, x1, x2, compute='full'):\n",
+"def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n",
 " def get_ntk(x1, x2):\n",
 " def func_x1(params):\n",
 " return func(params, x1)\n",
@@ -280,7 +289,7 @@
 " return vmap(get_ntk_slice)(basis)\n",
 " \n",
 " # get_ntk(x1, x2) computes the NTK for a single data point x1, x2\n",
-" # Since the x1, x2 inputs to empirical_ntk_implicit are batched,\n",
+" # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n",
 " # we actually wish to compute the NTK between every pair of data points\n",
 " # between {x1} and {x2}. That's what the vmaps here do.\n",
 " result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n",
@@ -300,9 +309,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"result_implicit = empirical_ntk_implicit(fnet_single, params, x_test, x_train)\n",
-"result_explicit = empirical_ntk(fnet_single, params, x_test, x_train)\n",
-"assert torch.allclose(result_implicit, result_explicit, atol=1e-5)"
+"result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n",
+"result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n",
+"assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)"
 ]
 },
 {
@@ -310,7 +319,9 @@
 "id": "84253466-971d-4475-999c-fe3de6bd25b5",
 "metadata": {},
 "source": [
-"Our code for `empirical_ntk_implicit` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch."
+"Our code for `empirical_ntk_ntk_vps` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n",
+"\n",
+"The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1, Jacobian contraction ($N^2 O$ instead of $N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$ term, where $P$ is the total number of the model's parameters). Therefore, this method is preferable when $O P$ is large relative to $[FP]$, such as for fully-connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
 ]
 }
 ],
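To make this tradeoff concrete with illustrative numbers (assumptions for the sake of the example, not measurements from the notebook): take a fully-connected model with $N = 100$, $O = 10$, $P = 10^6$, and approximate $[FP] \approx P$ multiply-adds per example. Then

$$\text{method 1: } N O [FP] + N^2 O^2 P \approx 10^9 + 10^{12} \approx 10^{12}, \qquad \text{method 2: } N^2 O [FP] \approx 10^{11},$$

so the NTK-vector-products method comes out roughly an order of magnitude cheaper in this regime, consistent with the rule of thumb above that it wins when $O P$ is large relative to $[FP]$.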