pytorch/functorch/docs/source/notebooks/neural_tangent_kernels.ipynb
Klaus Zimmermann f16053f0c9 Switch to standard pep517 sdist generation (#152098)
Generate the source tarball with PEP 517-conformant build tools instead of the custom routine currently in place.

Closes #150461.

The current procedure for generating the source tarball consists of creating a source tree by manually copying and pruning source files.

This PR replaces that with a call to the standard [build tool](https://build.pypa.io/en/stable/), which works with the build backend to produce an sdist. For that to work correctly, the build backend also needs to be configured. In the case of PyTorch, the backend is currently (the legacy version of) the setuptools backend, whose source-distribution side is configured mostly via the `MANIFEST.in` file.
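
For illustration, producing the sdist then boils down to invoking the build tool, either as `python -m build --sdist` on the command line or programmatically; the following is a minimal sketch (the output directory name is an arbitrary choice, not mandated by this PR):

```python
# Sketch: build an sdist with the PEP 517 `build` frontend.
# Assumes the `build` package is installed and we run from the
# project root containing pyproject.toml / setup.py.
from build import ProjectBuilder

builder = ProjectBuilder(".")
sdist = builder.build("sdist", "dist")  # e.g. dist/torch-{version}.tar.gz
print(sdist)
```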

The resulting source distribution can be used to install directly from source with `pip install ./torch-{version}.tar.gz` or to build wheels directly from source with `pip wheel ./torch-{version}.tar.gz`; both should be considered experimental for now.

## Issues

### sdist name
According to PEP 517, the name of the source distribution file must coincide with the project name, or [more precisely](https://peps.python.org/pep-0517/#source-distributions), the source distribution of a project that generates `{NAME}-{...}.whl` wheels is required to be named `{NAME}-{...}.tar.gz`. Currently, the source tarball is called `pytorch-{...}.tar.gz`, but the generated wheels and Python package are called `torch-{...}`.

### Symbolic Links
The source tree currently contains a small number of symbolic links. This [has been seen as problematic](https://github.com/pypa/pip/issues/5919), largely because of the lack of support on Windows, but also because of [a problem in setuptools](https://github.com/pypa/setuptools/issues/4937). Particularly unfortunate is a circular symlink in the third-party `ittapi` module, which cannot be resolved by replacing it with a copy.

PEP 721 (now integrated in the [Source Distribution Format Specification](https://packaging.python.org/en/latest/specifications/source-distribution-format/#source-distribution-archive-features)) allows for symbolic links, but only if they don't point outside the destination directory and if they don't contain `../` in their target.
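
Links violating these rules can be flagged mechanically; the following is a minimal sketch (the function name and output format are ours, not part of this PR's tooling):

```python
import os

def find_disallowed_symlinks(root):
    # Flag symlinks the sdist spec (PEP 721) disallows:
    # absolute targets or targets containing a ".." component.
    # os.walk does not follow symlinks by default, so the circular
    # ittapi link cannot cause infinite recursion here.
    for dirpath, dirnames, filenames in os.walk(root):
        for name in dirnames + filenames:
            path = os.path.join(dirpath, name)
            if not os.path.islink(path):
                continue
            target = os.readlink(path)
            if os.path.isabs(target) or ".." in target.split("/"):
                print(f"{path} -> {target}")

find_disallowed_symlinks(".")
```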

The list of symbolic links currently is as follows:

<details>

|source|target|problem|solution|
|-|-|-|-|
| `.dockerignore` | `.gitignore` |  ok (individual file) ||
| `docs/requirements.txt` | `../.ci/docker/requirements-docs.txt` |`..` in target|swap source and target[^1]|
| `functorch/docs/source/notebooks` | `../../notebooks/` |`..` in target|swap source and target[^1]|
| `.github/ci_commit_pins/triton.txt` | `../../.ci/docker/ci_commit_pins/triton.txt` |  ok (omitted from sdist)||
| `third_party/flatbuffers/docs/source/CONTRIBUTING.md` | `../../CONTRIBUTING.md` |`..` in target|omit from sdist[^2]|
| `third_party/flatbuffers/java/src/test/java/DictionaryLookup` | `../../../../tests/DictionaryLookup` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/MyGame` | `../../../../tests/MyGame` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/NamespaceA` | `../../../../tests/namespace_test/NamespaceA` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/NamespaceC` | `../../../../tests/namespace_test/NamespaceC` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/optional_scalars` | `../../../../tests/optional_scalars` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/union_vector` | `../../../../tests/union_vector` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/kotlin/benchmark/src/jvmMain/java` | `../../../../java/src/main/java` |`..` in target|omit from sdist[^3]|
| `third_party/ittapi/rust/ittapi-sys/c-library` | `../../` |`..` in target|omit from sdist[^4]|
| `third_party/ittapi/rust/ittapi-sys/LICENSES` | `../../LICENSES` |`..` in target|omit from sdist[^4]|
| `third_party/opentelemetry-cpp/buildscripts/pre-merge-commit` | `./pre-commit` | ok (individual file)||
| `third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-cmake/sample_client.cc` | `../../push/tests/integration/sample_client.cc` |`..` in target|omit from sdist[^5]|
| `third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-cmake/sample_server.cc` | `../../pull/tests/integration/sample_server.cc` |`..` in target|omit from sdist[^5]|
| `third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-pkgconfig/sample_client.cc` | `../../push/tests/integration/sample_client.cc` |`..` in target|omit from sdist[^5]|
| `third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-pkgconfig/sample_server.cc` | `../../pull/tests/integration/sample_server.cc` |`..` in target|omit from sdist[^5]|
| `third_party/XNNPACK/tools/xngen` | `xngen.py` |  ok (individual file)||

</details>

The introduction of symbolic links inside the `.ci/docker` folder creates a new problem, however, because Docker's `COPY` command does not handle symlinks of this kind. We work around this by using `tar ch` to dereference the symlinks before handing the files to `docker build`.
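
The same dereferencing can also be expressed with Python's `tarfile` module; a minimal sketch (the archive name and paths are assumptions), after which the archive can be fed to `docker build -` on stdin:

```python
import tarfile

# Create a build context archive with all symlinks resolved to
# regular files, analogous to `tar ch`.
with tarfile.open("context.tar", "w", dereference=True) as tar:
    tar.add(".ci/docker", arcname=".")
```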

[^1]: These resources can naturally be considered part of the docs, so moving the actual files into the place of the current symlinks and replacing the original locations with (unproblematic) symlinks arguably improves the semantics as well.

[^2]: The flatbuffers docs already use the original file rather than the symlink, and in the most recent releases (starting with flatbuffers-25.1.21) the symlink has been replaced by the actual file thanks to a documentation overhaul.

[^3]: These resources are flatbuffers tests for java and kotlin and can be omitted from our sdist.

[^4]: We don't need to ship the rust bindings for ittapi.

[^5]: These are demonstration examples for how to link to prometheus-cpp using cmake and can be omitted.

### NCCL
NCCL used to be included as a submodule. However, with #146073 (first released in v2.7.0-rc1), the submodule was removed and replaced with a build-time checkout procedure in `tools/build_pytorch_libs.py`, which checks out the required version of NCCL from the upstream repository based on a commit pin recorded in `.ci/docker/ci_commit_pins/nccl-cu{11,12}.txt`.
This means that a crucial third-party dependency is missing from the source distribution, and since the `.ci` folder is omitted from the sdist, the build-time download cannot be used either.
However, it *is* possible to use a system-provided NCCL by setting the `USE_SYSTEM_NCCL` environment variable (e.g. `USE_SYSTEM_NCCL=1 pip install ./torch-{version}.tar.gz`), which is now also the default for the official PyTorch wheels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152098
Approved by: https://github.com/atalman
2025-06-30 19:07:34 +00:00

{
"cells": [
{
"cell_type": "markdown",
"id": "b687b169-ec83-493d-a7c5-f8c6cd402ea3",
"metadata": {},
"source": [
"# Neural Tangent Kernels\n",
"\n",
"<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/neural_tangent_kernels.ipynb\">\n",
" <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>\n",
"\n",
"The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), demonstrates how to easily compute this quantity using functorch."
]
},
{
"cell_type": "markdown",
"id": "77f41c65-f070-4b60-b3d0-1c8f56ed4f64",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, some setup. Let's define a simple CNN that we wish to compute the NTK of."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "855fa70b-5b63-4973-94df-41be57ab6ecf",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from functorch import make_functional, vmap, vjp, jvp, jacrev\n",
"device = 'cuda'\n",
"\n",
"class CNN(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(3, 32, (3, 3))\n",
" self.conv2 = nn.Conv2d(32, 32, (3, 3))\n",
" self.conv3 = nn.Conv2d(32, 32, (3, 3))\n",
" self.fc = nn.Linear(21632, 10)\n",
" \n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = x.relu()\n",
" x = self.conv2(x)\n",
" x = x.relu()\n",
" x = self.conv3(x)\n",
" x = x.flatten(1)\n",
" x = self.fc(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "52c600e9-207a-41ec-93b4-5d940827bda0",
"metadata": {},
"source": [
"And let's generate some random data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0001a907-f5c9-4532-9ee9-2e94b8487d08",
"metadata": {},
"outputs": [],
"source": [
"x_train = torch.randn(20, 3, 32, 32, device=device)\n",
"x_test = torch.randn(5, 3, 32, 32, device=device)"
]
},
{
"cell_type": "markdown",
"id": "8af210fe-9613-48ee-a96c-d0836458b0f1",
"metadata": {},
"source": [
"## Create a function version of the model\n",
"\n",
"functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n",
"\n",
"We'll use functorch's `make_functional` to accomplish the first step. If your module has buffers, you'll want to use `make_functional_with_buffers` instead."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e6b4bb59-bdde-46cd-8a28-7fd00a37a387",
"metadata": {},
"outputs": [],
"source": [
"net = CNN().to(device)\n",
"fnet, params = make_functional(net)"
]
},
{
"cell_type": "markdown",
"id": "319276a4-da45-499a-af47-0677107559b6",
"metadata": {},
"source": [
"Keep in mind that the model was originally written to accept a batch of input data points. In our CNN example, there are no inter-batch operations. That is, each data point in the batch is independent of other data points. With this assumption in mind, we can easily generate a function that evaluates the model on a single data point:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0b8b4021-eb10-4a50-9d99-3817cb0ce4cc",
"metadata": {},
"outputs": [],
"source": [
"def fnet_single(params, x):\n",
" return fnet(params, x.unsqueeze(0)).squeeze(0)"
]
},
{
"cell_type": "markdown",
"id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248",
"metadata": {},
"source": [
"## Compute the NTK: method 1 (Jacobian contraction)\n",
"\n",
"We're ready to compute the empirical NTK. The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product between the Jacobian of the model evaluated at $x_1$ and the Jacobian of the model evaluated at $x_2$:\n",
"\n",
"$$J_{net}(x_1) J_{net}^T(x_2)$$\n",
"\n",
"In the batched case where $x_1$ is a batch of data points and $x_2$ is a batch of data points, then we want the matrix product between the Jacobians of all combinations of data points from $x_1$ and $x_2$.\n",
"\n",
"The first method consists of doing just that - computing the two Jacobians, and contracting them. Here's how to compute the NTK in the batched case:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "99a38a4b-64d3-4e13-bd63-2d71e8dd6840",
"metadata": {},
"outputs": [],
"source": [
"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n",
" # Compute J(x1)\n",
" jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
" jac1 = [j.flatten(2) for j in jac1]\n",
" \n",
" # Compute J(x2)\n",
" jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n",
" jac2 = [j.flatten(2) for j in jac2]\n",
" \n",
" # Compute J(x1) @ J(x2).T\n",
" result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])\n",
" result = result.sum(0)\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cbf54d2b-c4bc-46bd-9e55-e1471d639a4e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([20, 5, 10, 10])\n"
]
}
],
"source": [
"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n",
"print(result.shape)"
]
},
{
"cell_type": "markdown",
"id": "ea844f45-98fb-4cba-8056-644292b968ab",
"metadata": {},
"source": [
"In some cases, you may only want the diagonal or the trace of this quantity, especially if you know beforehand that the network architecture results in an NTK where the non-diagonal elements can be approximated by zero. It's easy to adjust the above function to do that:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "aae760c9-e906-4fda-b490-1126a86b7e96",
"metadata": {},
"outputs": [],
"source": [
"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n",
" # Compute J(x1)\n",
" jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
" jac1 = [j.flatten(2) for j in jac1]\n",
" \n",
" # Compute J(x2)\n",
" jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n",
" jac2 = [j.flatten(2) for j in jac2]\n",
" \n",
" # Compute J(x1) @ J(x2).T\n",
" einsum_expr = None\n",
" if compute == 'full':\n",
" einsum_expr = 'Naf,Mbf->NMab'\n",
" elif compute == 'trace':\n",
" einsum_expr = 'Naf,Maf->NM'\n",
" elif compute == 'diagonal':\n",
" einsum_expr = 'Naf,Maf->NMa'\n",
" else:\n",
" assert False\n",
" \n",
" result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])\n",
" result = result.sum(0)\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "42d974f3-1f9d-4953-8677-5ee22cfc67eb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([20, 5])\n"
]
}
],
"source": [
"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n",
"print(result.shape)"
]
},
{
"cell_type": "markdown",
"id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa",
"metadata": {},
"source": [
"The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) $ + N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
]
},
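{
"cell_type": "markdown",
"id": "7f3a2c1e-5b8d-4a6f-9c2e-1d4e8b7a6c50",
"metadata": {},
"source": [
"To make these quantities concrete, we can read $N$, $O$ and $P$ off the running example; a quick sanity check (the exact parameter count simply follows from the CNN defined above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e4b3d2f-6c9e-4b7a-8d3f-2e5f9c8b7d61",
"metadata": {},
"outputs": [],
"source": [
"N = x_train.shape[0]  # batch size of x1\n",
"O = fnet_single(params, x_train[0]).shape[0]  # model output size\n",
"P = sum(p.numel() for p in params)  # total number of parameters\n",
"print(N, O, P)"
]
},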
{
"cell_type": "markdown",
"id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa",
"metadata": {},
"source": [
"## Compute the NTK: method 2 (NTK-vector products)\n",
"\n",
"The next method we will discuss is a way to compute the NTK using NTK-vector products.\n",
"\n",
"This method reformulates NTK as a stack of NTK-vector products applied to columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n",
"\n",
"$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n",
"where $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix $I_O$.\n",
"\n",
"- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n",
"- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n",
"- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n",
"\n",
"This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n",
"\n",
"Let's code that up:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "dc4b49d7-3096-45d5-a7a1-7032309a2613",
"metadata": {},
"outputs": [],
"source": [
"def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n",
" def get_ntk(x1, x2):\n",
" def func_x1(params):\n",
" return func(params, x1)\n",
"\n",
" def func_x2(params):\n",
" return func(params, x2)\n",
"\n",
" output, vjp_fn = vjp(func_x1, params)\n",
"\n",
" def get_ntk_slice(vec):\n",
" # This computes vec @ J(x2).T\n",
" # `vec` is some unit vector (a single slice of the Identity matrix)\n",
" vjps = vjp_fn(vec)\n",
" # This computes J(X1) @ vjps\n",
" _, jvps = jvp(func_x2, (params,), vjps)\n",
" return jvps\n",
"\n",
" # Here's our identity matrix\n",
" basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)\n",
" return vmap(get_ntk_slice)(basis)\n",
" \n",
" # get_ntk(x1, x2) computes the NTK for a single data point x1, x2\n",
" # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n",
" # we actually wish to compute the NTK between every pair of data points\n",
" # between {x1} and {x2}. That's what the vmaps here do.\n",
" result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n",
" \n",
" if compute == 'full':\n",
" return result\n",
" if compute == 'trace':\n",
" return torch.einsum('NMKK->NM', result)\n",
" if compute == 'diagonal':\n",
" return torch.einsum('NMKK->NMK', result)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f750544f-9e48-47fe-9f9b-e1b8ae49b245",
"metadata": {},
"outputs": [],
"source": [
"result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n",
"result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n",
"assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)"
]
},
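{
"cell_type": "markdown",
"id": "9a5c4e3d-7d0f-4c8b-9e4a-3f6a0d9c8e72",
"metadata": {},
"source": [
"Both implementations agree. As a rough, hardware-dependent sanity check (not a rigorous benchmark), we can also compare their wall-clock times:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b6d5f4e-8e1a-4d9c-af5b-4a7b1e0d9f83",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def bench(f, *args):\n",
"    # Synchronize around the call so we measure the actual GPU work\n",
"    torch.cuda.synchronize()\n",
"    start = time.time()\n",
"    f(*args)\n",
"    torch.cuda.synchronize()\n",
"    return time.time() - start\n",
"\n",
"t_jac = bench(empirical_ntk_jacobian_contraction, fnet_single, params, x_test, x_train)\n",
"t_vps = bench(empirical_ntk_ntk_vps, fnet_single, params, x_test, x_train)\n",
"print(f'jacobian contraction: {t_jac:.3f}s, ntk-vector products: {t_vps:.3f}s')"
]
},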
{
"cell_type": "markdown",
"id": "84253466-971d-4475-999c-fe3de6bd25b5",
"metadata": {},
"source": [
"Our code for `empirical_ntk_ntk_vps` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n",
"\n",
"The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1, Jacobian contraction ($N^2 O$ instead of $N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$ term, where $P$ is the total number of model's parameters). Therefore, this method is preferable when $O P$ is large relative to $[FP]$, such as fully-connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}