mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	Compare commits
	
		
			284 Commits
		
	
	
		
			viable/str
			...
			revert-cpp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 2eacbe792a | |||
| 8110ce02a2 | |||
| 43c30f607e | |||
| 5ebf74a655 | |||
| acd936cc1a | |||
| a4a0378e6b | |||
| ac841267a1 | |||
| 0eacd934bc | |||
| 5016e7b2eb | |||
| 544b443ea1 | |||
| 3041ede082 | |||
| 34d6ef7022 | |||
| 110efe4df4 | |||
| e137cd0a10 | |||
| be28329710 | |||
| 85a7c745aa | |||
| 32fe4f681e | |||
| ebb2b2e894 | |||
| 13413b3b07 | |||
| 5d0b3e28dc | |||
| 9139368b64 | |||
| 02095cc09d | |||
| 65868156c6 | |||
| f93ea7dab1 | |||
| a77f5d9a00 | |||
| ff46d5a79b | |||
| f452edd782 | |||
| ea698e8bfc | |||
| 7f7a28046b | |||
| d8283a317a | |||
| e0ca3049c0 | |||
| 8417981c96 | |||
| 06e71c8558 | |||
| a76b59cc45 | |||
| 74336f8c77 | |||
| 236ce736a1 | |||
| 17bdb232e1 | |||
| add37bacda | |||
| 1425b40f29 | |||
| 8af9ed0824 | |||
| 7045aab143 | |||
| 7ae8aaf4c0 | |||
| f2450798cd | |||
| 46d17e8871 | |||
| dc011d3203 | |||
| e95920e3e6 | |||
| 5e769ff867 | |||
| 0ae3e30621 | |||
| 47f50cfd45 | |||
| a51f877287 | |||
| b44423bbb4 | |||
| 8e1e4ee8e0 | |||
| 1e836bc769 | |||
| 9a91486e45 | |||
| 92381a5aa7 | |||
| 2a5f87decf | |||
| 840d63c12d | |||
| 2ce894bb1d | |||
| 47ec1e9990 | |||
| 904abfc2ca | |||
| 7d16fcf2df | |||
| 483845a9c4 | |||
| 60bcb4ee88 | |||
| ee7434be82 | |||
| d049ed2cb1 | |||
| 9901d44418 | |||
| 6096c0fc74 | |||
| f6951cb8ea | |||
| 8887a33ede | |||
| 36a48e7e6d | |||
| c6a02eae5b | |||
| 6ecd6b23b6 | |||
| 3f69b4d9b4 | |||
| a04edcb27a | |||
| eb2bad5bb5 | |||
| a076b4d7ac | |||
| a988510c33 | |||
| 99e07c39ec | |||
| 610c09f8f4 | |||
| 61bad3c1ea | |||
| f89a7e9fe8 | |||
| f2c81635c8 | |||
| e214af6ae8 | |||
| 7ce723d21c | |||
| 4295a9a158 | |||
| 90d7be35e9 | |||
| 8d4e48831e | |||
| 90b30ebf7e | |||
| 173bcda436 | |||
| 6530bc70fb | |||
| 4c38887346 | |||
| 81fa4a204c | |||
| 4e6afa8c07 | |||
| 79aa88cc5d | |||
| fa4cb91846 | |||
| c58d0ad85d | |||
| 000f49551b | |||
| 9940e894ea | |||
| 27302a4932 | |||
| 507614ba43 | |||
| 86f9f1d0ab | |||
| 154e4d36e9 | |||
| a2b6afeac5 | |||
| 262830d86c | |||
| e4c01011c2 | |||
| a60d9e1f6d | |||
| f863550192 | |||
| 84b14f3a10 | |||
| 5121499f6b | |||
| 8f80892359 | |||
| cdb60e44eb | |||
| 25909d2629 | |||
| c7eee49525 | |||
| 621ba05107 | |||
| 39a70cead1 | |||
| d97f6550a2 | |||
| 516e58965a | |||
| b55b779ad3 | |||
| 74e53d0761 | |||
| 798a6d2be1 | |||
| b0e9c86971 | |||
| 661a56002f | |||
| c9bc00f016 | |||
| ec51b139e1 | |||
| eb83c3ca23 | |||
| 7924e3aacf | |||
| 78bcfcf870 | |||
| 1e2e7cb18b | |||
| 003601a70d | |||
| 1d58d5fe25 | |||
| de7fdfe41a | |||
| b31bad1b8f | |||
| 2efcf3ca98 | |||
| 761f946043 | |||
| 8aa465f18e | |||
| 0a5d68d92d | |||
| 42bd210fff | |||
| 1d13c314b3 | |||
| 0c9763a5a0 | |||
| 79a4a9c02e | |||
| 9d0b77f4cd | |||
| d486eee234 | |||
| cddd5f74ab | |||
| dfdb68e51f | |||
| 98c818320a | |||
| cc20b7ad72 | |||
| bc11a42b3f | |||
| 4fc06f2e0a | |||
| 82473c3d59 | |||
| b6a4236e5d | |||
| b04173be9b | |||
| 32ac38f85d | |||
| c9b49e506e | |||
| 6038e476e8 | |||
| 2c851c16e5 | |||
| 31584f2d91 | |||
| 0442125362 | |||
| fdcf402d82 | |||
| 13cda9b89e | |||
| fa6d911dda | |||
| 0db6bcc015 | |||
| 60ac039998 | |||
| 380d440d1c | |||
| 9038a30cee | |||
| 690c8c13b9 | |||
| 28ee6b62ed | |||
| 81577bdb3f | |||
| e67e3d95f3 | |||
| 27af8480ea | |||
| 6494cdc40c | |||
| ac7074efa2 | |||
| 263901cec4 | |||
| c12293dcbe | |||
| 5a4997dcae | |||
| 47f638eae7 | |||
| 882b834082 | |||
| b146ea411e | |||
| 8625ffbd45 | |||
| 0977cc4474 | |||
| d9a55faccc | |||
| 75b8295868 | |||
| defb6a80d8 | |||
| f8fccb1e48 | |||
| 5aac4cfce4 | |||
| baf91bbbfc | |||
| cbcb4f7768 | |||
| 2b93d5b450 | |||
| 6b7cd48e7e | |||
| bf5aa9e42e | |||
| b1eb6dede5 | |||
| 673060beae | |||
| 2e8e9a59a8 | |||
| fb277a5916 | |||
| 73fa0d0c63 | |||
| 36c21cc84e | |||
| 0b68814b44 | |||
| e64a814ae7 | |||
| 0b58d87aec | |||
| 757975ad50 | |||
| 291712026b | |||
| 3e77a2b478 | |||
| 82ef1b5db3 | |||
| 5f370f5c42 | |||
| 05b2e02cb4 | |||
| 12f742941d | |||
| 35180fafee | |||
| c746feb86a | |||
| c5f26db5bf | |||
| 18e99b6d45 | |||
| ab9e466928 | |||
| af4ba78543 | |||
| 282f39a4bc | |||
| a479769488 | |||
| 26c7375477 | |||
| d01f15152c | |||
| 4fae6968b1 | |||
| f9953e0f61 | |||
| 34ed7a8f0d | |||
| 2fde10d914 | |||
| 0a93295da0 | |||
| 4b898b51b9 | |||
| 550e3e6efb | |||
| 715449ca76 | |||
| 84d8d06fc3 | |||
| 60992d98b2 | |||
| 59e015e3a1 | |||
| 8904a5a7c9 | |||
| f5df9ca03a | |||
| 2998abd777 | |||
| e13580e41c | |||
| f3b8e15f20 | |||
| 5211f4c108 | |||
| ad9027b80d | |||
| a1005427bf | |||
| 35153d0846 | |||
| 7773a22cdb | |||
| 7cb467a169 | |||
| 12aac12b8d | |||
| 2b748d0a56 | |||
| 16745a882a | |||
| 8daef35cf1 | |||
| 51319ca090 | |||
| d311a3d1dc | |||
| 04adfe5ba9 | |||
| 4be1e3bf92 | |||
| e7592f4005 | |||
| d334c3649d | |||
| 9f82535c5a | |||
| 5b35fc8777 | |||
| 2f38eece7c | |||
| 830e789a55 | |||
| ad4dc52bf6 | |||
| dac9ed9790 | |||
| 1c7fe8f861 | |||
| 4e643422f6 | |||
| 3c3b278872 | |||
| 0bd12c1168 | |||
| ce8a7764e2 | |||
| d1269a0434 | |||
| c87cf1be32 | |||
| 2fc5e45a41 | |||
| f9022ba93b | |||
| ff8be889ad | |||
| 292454942e | |||
| 6c4412f72b | |||
| 78bf6186f2 | |||
| c40048472c | |||
| 3dfd0c7584 | |||
| e6ba4d0725 | |||
| bdf7cb9d9c | |||
| 6aed378958 | |||
| 8b3dc0d1b0 | |||
| 06773663b5 | |||
| 0bff65503c | |||
| 21131a2444 | |||
| 1009790ad8 | |||
| 410e6a4321 | |||
| 23c55c5b66 | |||
| 1290b077f2 | |||
| 9f9ab881b2 | |||
| f2bb22ff84 | |||
| 03f3f7899c | |||
| 771170807b | |||
| ffa90d46e6 | 
| @ -150,7 +150,7 @@ function install_130 { | |||||||
|   CUDNN_VERSION=9.13.0.50 |   CUDNN_VERSION=9.13.0.50 | ||||||
|   echo "Installing CUDA 13.0 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" |   echo "Installing CUDA 13.0 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" | ||||||
|   # install CUDA 13.0 in the same container |   # install CUDA 13.0 in the same container | ||||||
|   install_cuda 13.0.0 cuda_13.0.0_580.65.06_linux |   install_cuda 13.0.2 cuda_13.0.2_580.95.05_linux | ||||||
|  |  | ||||||
|   # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement |   # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement | ||||||
|   install_cudnn 13 $CUDNN_VERSION |   install_cudnn 13 $CUDNN_VERSION | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ pip_install \ | |||||||
|   transformers==4.36.2 |   transformers==4.36.2 | ||||||
|  |  | ||||||
| pip_install coloredlogs packaging | pip_install coloredlogs packaging | ||||||
| pip_install onnxruntime==1.23.0 | pip_install onnxruntime==1.23.1 | ||||||
| pip_install onnxscript==0.5.4 | pip_install onnxscript==0.5.4 | ||||||
|  |  | ||||||
| # Cache the transformers model to be used later by ONNX tests. We need to run the transformers | # Cache the transformers model to be used later by ONNX tests. We need to run the transformers | ||||||
|  | |||||||
| @ -334,12 +334,12 @@ sympy==1.13.3 | |||||||
| #Pinned versions: | #Pinned versions: | ||||||
| #test that import: | #test that import: | ||||||
|  |  | ||||||
| onnx==1.18.0 | onnx==1.19.1 | ||||||
| #Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal | #Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal | ||||||
| #Pinned versions: | #Pinned versions: | ||||||
| #test that import: | #test that import: | ||||||
|  |  | ||||||
| onnxscript==0.5.3 | onnxscript==0.5.4 | ||||||
| #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal | #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal | ||||||
| #Pinned versions: | #Pinned versions: | ||||||
| #test that import: | #test that import: | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ dependencies = [ | |||||||
|     "GitPython==3.1.45", |     "GitPython==3.1.45", | ||||||
|     "docker==7.1.0", |     "docker==7.1.0", | ||||||
|     "pytest==7.3.2", |     "pytest==7.3.2", | ||||||
|     "uv==0.8.6" |     "uv==0.9.5" | ||||||
| ] | ] | ||||||
|  |  | ||||||
| [tool.setuptools] | [tool.setuptools] | ||||||
|  | |||||||
| @ -163,8 +163,13 @@ if [[ "$(uname)" != Darwin ]]; then | |||||||
|   MEMORY_LIMIT_MAX_JOBS=12 |   MEMORY_LIMIT_MAX_JOBS=12 | ||||||
|   NUM_CPUS=$(( $(nproc) - 2 )) |   NUM_CPUS=$(( $(nproc) - 2 )) | ||||||
|  |  | ||||||
|   # Defaults here for **binary** linux builds so they can be changed in one place |   if [[ "$(uname)" == Linux ]]; then | ||||||
|   export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))} |     # Defaults here for **binary** linux builds so they can be changed in one place | ||||||
|  |     export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))} | ||||||
|  |   else | ||||||
|  |     # For other builds | ||||||
|  |     export MAX_JOBS=${NUM_CPUS} | ||||||
|  |   fi | ||||||
|  |  | ||||||
|   cat >>"$envfile" <<EOL |   cat >>"$envfile" <<EOL | ||||||
|   export MAX_JOBS="${MAX_JOBS}" |   export MAX_JOBS="${MAX_JOBS}" | ||||||
|  | |||||||
							
								
								
									
										359
									
								
								.claude/skills/docstring/SKILL.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										359
									
								
								.claude/skills/docstring/SKILL.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,359 @@ | |||||||
|  | --- | ||||||
|  | name: docstring | ||||||
|  | description: Write docstrings for PyTorch functions and methods following PyTorch conventions. Use when writing or updating docstrings in PyTorch code. | ||||||
|  | --- | ||||||
|  |  | ||||||
|  | # PyTorch Docstring Writing Guide | ||||||
|  |  | ||||||
|  | This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`. | ||||||
|  |  | ||||||
|  | ## General Principles | ||||||
|  |  | ||||||
|  | - Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes | ||||||
|  | - Follow **Sphinx/reStructuredText** (reST) format for documentation | ||||||
|  | - Be **concise but complete** - include all essential information | ||||||
|  | - Always include **examples** when possible | ||||||
|  | - Use **cross-references** to related functions/classes | ||||||
|  |  | ||||||
|  | ## Docstring Structure | ||||||
|  |  | ||||||
|  | ### 1. Function Signature (First Line) | ||||||
|  |  | ||||||
|  | Start with the function signature showing all parameters: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Notes:** | ||||||
|  | - Include the function name | ||||||
|  | - Show positional and keyword-only arguments (use `*` separator) | ||||||
|  | - Include default values | ||||||
|  | - Show return type annotation | ||||||
|  | - This line should NOT end with a period | ||||||
|  |  | ||||||
|  | ### 2. Brief Description | ||||||
|  |  | ||||||
|  | Provide a one-line description of what the function does: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor | ||||||
|  |  | ||||||
|  | Applies a 2D convolution over an input image composed of several input | ||||||
|  | planes. | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### 3. Mathematical Formulas (if applicable) | ||||||
|  |  | ||||||
|  | Use Sphinx math directives for mathematical expressions: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | .. math:: | ||||||
|  |     \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Or inline math: `:math:\`x^2\`` | ||||||
|  |  | ||||||
|  | ### 4. Cross-References | ||||||
|  |  | ||||||
|  | Link to related classes and functions using Sphinx roles: | ||||||
|  |  | ||||||
|  | - `:class:\`~torch.nn.ModuleName\`` - Link to a class | ||||||
|  | - `:func:\`torch.function_name\`` - Link to a function | ||||||
|  | - `:meth:\`~Tensor.method_name\`` - Link to a method | ||||||
|  | - `:attr:\`attribute_name\`` - Reference an attribute | ||||||
|  | - The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`) | ||||||
|  |  | ||||||
|  | **Example:** | ||||||
|  | ```python | ||||||
|  | See :class:`~torch.nn.Conv2d` for details and output shape. | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### 5. Notes and Warnings | ||||||
|  |  | ||||||
|  | Use admonitions for important information: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | .. note:: | ||||||
|  |     This function doesn't work directly with NLLLoss, | ||||||
|  |     which expects the Log to be computed between the Softmax and itself. | ||||||
|  |     Use log_softmax instead (it's faster and has better numerical properties). | ||||||
|  |  | ||||||
|  | .. warning:: | ||||||
|  |     :func:`new_tensor` always copies :attr:`data`. If you have a Tensor | ||||||
|  |     ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_` | ||||||
|  |     or :func:`torch.Tensor.detach`. | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### 6. Args Section | ||||||
|  |  | ||||||
|  | Document all parameters with type annotations and descriptions: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | Args: | ||||||
|  |     input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` | ||||||
|  |     weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)` | ||||||
|  |     bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None`` | ||||||
|  |     stride (int or tuple): the stride of the convolving kernel. Can be a single number or a | ||||||
|  |       tuple `(sH, sW)`. Default: 1 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Formatting rules:** | ||||||
|  | - Parameter name in **lowercase** | ||||||
|  | - Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters | ||||||
|  | - Description follows the type | ||||||
|  | - For optional parameters, include "Default: ``value``" at the end | ||||||
|  | - Use double backticks for inline code: ``` ``None`` ``` | ||||||
|  | - Indent continuation lines by 2 spaces | ||||||
|  |  | ||||||
|  | ### 7. Keyword Args Section (if applicable) | ||||||
|  |  | ||||||
|  | Sometimes keyword arguments are documented separately: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | Keyword args: | ||||||
|  |     dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. | ||||||
|  |         Default: if None, same :class:`torch.dtype` as this tensor. | ||||||
|  |     device (:class:`torch.device`, optional): the desired device of returned tensor. | ||||||
|  |         Default: if None, same :class:`torch.device` as this tensor. | ||||||
|  |     requires_grad (bool, optional): If autograd should record operations on the | ||||||
|  |         returned tensor. Default: ``False``. | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### 8. Returns Section (if needed) | ||||||
|  |  | ||||||
|  | Document the return value: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | Returns: | ||||||
|  |     Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. | ||||||
|  |         If ``hard=True``, the returned samples will be one-hot, otherwise they will | ||||||
|  |         be probability distributions that sum to 1 across `dim`. | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Or simply include it in the function signature line if obvious from context. | ||||||
|  |  | ||||||
|  | ### 9. Examples Section | ||||||
|  |  | ||||||
|  | Always include examples when possible: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | Examples:: | ||||||
|  |  | ||||||
|  |     >>> inputs = torch.randn(33, 16, 30) | ||||||
|  |     >>> filters = torch.randn(20, 16, 5) | ||||||
|  |     >>> F.conv1d(inputs, filters) | ||||||
|  |  | ||||||
|  |     >>> # With square kernels and equal stride | ||||||
|  |     >>> filters = torch.randn(8, 4, 3, 3) | ||||||
|  |     >>> inputs = torch.randn(1, 4, 5, 5) | ||||||
|  |     >>> F.conv2d(inputs, filters, padding=1) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Formatting rules:** | ||||||
|  | - Use `Examples::` with double colon | ||||||
|  | - Use `>>>` prompt for Python code | ||||||
|  | - Include comments with `#` when helpful | ||||||
|  | - Show actual output when it helps understanding (indent without `>>>`) | ||||||
|  |  | ||||||
|  | ### 10. External References | ||||||
|  |  | ||||||
|  | Link to papers or external documentation: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | .. _Link Name: | ||||||
|  |     https://arxiv.org/abs/1611.00712 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Reference them in text: ```See `Link Name`_``` | ||||||
|  |  | ||||||
|  | ## Method Types | ||||||
|  |  | ||||||
|  | ### Native Python Functions | ||||||
|  |  | ||||||
|  | For regular Python functions, use a standard docstring: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | def relu(input: Tensor, inplace: bool = False) -> Tensor: | ||||||
|  |     r"""relu(input, inplace=False) -> Tensor | ||||||
|  |  | ||||||
|  |     Applies the rectified linear unit function element-wise. See | ||||||
|  |     :class:`~torch.nn.ReLU` for more details. | ||||||
|  |     """ | ||||||
|  |     # implementation | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### C-Bound Functions (using add_docstr) | ||||||
|  |  | ||||||
|  | For C-bound functions, use `_add_docstr`: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | conv1d = _add_docstr( | ||||||
|  |     torch.conv1d, | ||||||
|  |     r""" | ||||||
|  | conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor | ||||||
|  |  | ||||||
|  | Applies a 1D convolution over an input signal composed of several input | ||||||
|  | planes. | ||||||
|  |  | ||||||
|  | See :class:`~torch.nn.Conv1d` for details and output shape. | ||||||
|  |  | ||||||
|  | Args: | ||||||
|  |     input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` | ||||||
|  |     weight: filters of shape :math:`(\text{out\_channels} , kW)` | ||||||
|  |     ... | ||||||
|  | """, | ||||||
|  | ) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### In-Place Variants | ||||||
|  |  | ||||||
|  | For in-place operations (ending with `_`), reference the original: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | add_docstr_all( | ||||||
|  |     "abs_", | ||||||
|  |     r""" | ||||||
|  | abs_() -> Tensor | ||||||
|  |  | ||||||
|  | In-place version of :meth:`~Tensor.abs` | ||||||
|  | """, | ||||||
|  | ) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Alias Functions | ||||||
|  |  | ||||||
|  | For aliases, simply reference the original: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | add_docstr_all( | ||||||
|  |     "absolute", | ||||||
|  |     r""" | ||||||
|  | absolute() -> Tensor | ||||||
|  |  | ||||||
|  | Alias for :func:`abs` | ||||||
|  | """, | ||||||
|  | ) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Common Patterns | ||||||
|  |  | ||||||
|  | ### Shape Documentation | ||||||
|  |  | ||||||
|  | Use LaTeX math notation for tensor shapes: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Reusable Argument Definitions | ||||||
|  |  | ||||||
|  | For commonly used arguments, define them once and reuse: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | common_args = parse_kwargs( | ||||||
|  |     """ | ||||||
|  |     dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. | ||||||
|  |         Default: if None, same as this tensor. | ||||||
|  | """ | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | # Then use with .format(): | ||||||
|  | r""" | ||||||
|  | ... | ||||||
|  |  | ||||||
|  | Keyword args: | ||||||
|  |     {dtype} | ||||||
|  |     {device} | ||||||
|  | """.format(**common_args) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Template Insertion | ||||||
|  |  | ||||||
|  | Insert reproducibility notes or other common text: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | r""" | ||||||
|  | {tf32_note} | ||||||
|  |  | ||||||
|  | {cudnn_reproducibility_note} | ||||||
|  | """.format(**reproducibility_notes, **tf32_notes) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Complete Example | ||||||
|  |  | ||||||
|  | Here's a complete example showing all elements: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | def gumbel_softmax( | ||||||
|  |     logits: Tensor, | ||||||
|  |     tau: float = 1, | ||||||
|  |     hard: bool = False, | ||||||
|  |     eps: float = 1e-10, | ||||||
|  |     dim: int = -1, | ||||||
|  | ) -> Tensor: | ||||||
|  |     r""" | ||||||
|  |     Sample from the Gumbel-Softmax distribution and optionally discretize. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         logits (Tensor): `[..., num_features]` unnormalized log probabilities | ||||||
|  |         tau (float): non-negative scalar temperature | ||||||
|  |         hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors, | ||||||
|  |               but will be differentiated as if it is the soft sample in autograd. Default: ``False`` | ||||||
|  |         dim (int): A dimension along which softmax will be computed. Default: -1 | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. | ||||||
|  |             If ``hard=True``, the returned samples will be one-hot, otherwise they will | ||||||
|  |             be probability distributions that sum to 1 across `dim`. | ||||||
|  |  | ||||||
|  |     .. note:: | ||||||
|  |         This function is here for legacy reasons, may be removed from nn.Functional in the future. | ||||||
|  |  | ||||||
|  |     Examples:: | ||||||
|  |         >>> logits = torch.randn(20, 32) | ||||||
|  |         >>> # Sample soft categorical using reparametrization trick: | ||||||
|  |         >>> F.gumbel_softmax(logits, tau=1, hard=False) | ||||||
|  |         >>> # Sample hard categorical using "Straight-through" trick: | ||||||
|  |         >>> F.gumbel_softmax(logits, tau=1, hard=True) | ||||||
|  |  | ||||||
|  |     .. _Link 1: | ||||||
|  |         https://arxiv.org/abs/1611.00712 | ||||||
|  |     """ | ||||||
|  |     # implementation | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Quick Checklist | ||||||
|  |  | ||||||
|  | When writing a PyTorch docstring, ensure: | ||||||
|  |  | ||||||
|  | - [ ] Use raw string (`r"""`) | ||||||
|  | - [ ] Include function signature on first line | ||||||
|  | - [ ] Provide brief description | ||||||
|  | - [ ] Document all parameters in Args section with types | ||||||
|  | - [ ] Include default values for optional parameters | ||||||
|  | - [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`) | ||||||
|  | - [ ] Add mathematical formulas if applicable | ||||||
|  | - [ ] Include at least one example in Examples section | ||||||
|  | - [ ] Add warnings/notes for important caveats | ||||||
|  | - [ ] Link to related module class with `:class:` | ||||||
|  | - [ ] Use proper math notation for tensor shapes | ||||||
|  | - [ ] Follow consistent formatting and indentation | ||||||
|  |  | ||||||
|  | ## Common Sphinx Roles Reference | ||||||
|  |  | ||||||
|  | - `:class:\`~torch.nn.Module\`` - Class reference | ||||||
|  | - `:func:\`torch.function\`` - Function reference | ||||||
|  | - `:meth:\`~Tensor.method\`` - Method reference | ||||||
|  | - `:attr:\`attribute\`` - Attribute reference | ||||||
|  | - `:math:\`equation\`` - Inline math | ||||||
|  | - `:ref:\`label\`` - Internal reference | ||||||
|  | - ``` ``code`` ``` - Inline code (use double backticks) | ||||||
|  |  | ||||||
|  | ## Additional Notes | ||||||
|  |  | ||||||
|  | - **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions | ||||||
|  | - **Line length**: Try to keep lines under 100 characters when possible | ||||||
|  | - **Periods**: End sentences with periods, but not the signature line | ||||||
|  | - **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ``` | ||||||
|  | - **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc. | ||||||
							
								
								
									
										385
									
								
								.claude/skills/skill-writer/SKILL.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										385
									
								
								.claude/skills/skill-writer/SKILL.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,385 @@ | |||||||
|  | --- | ||||||
|  | name: skill-writer | ||||||
|  | description: Guide users through creating Agent Skills for Claude Code. Use when the user wants to create, write, author, or design a new Skill, or needs help with SKILL.md files, frontmatter, or skill structure. | ||||||
|  | --- | ||||||
|  |  | ||||||
|  | # Skill Writer | ||||||
|  |  | ||||||
|  | This Skill helps you create well-structured Agent Skills for Claude Code that follow best practices and validation requirements. | ||||||
|  |  | ||||||
|  | ## When to use this Skill | ||||||
|  |  | ||||||
|  | Use this Skill when: | ||||||
|  | - Creating a new Agent Skill | ||||||
|  | - Writing or updating SKILL.md files | ||||||
|  | - Designing skill structure and frontmatter | ||||||
|  | - Troubleshooting skill discovery issues | ||||||
|  | - Converting existing prompts or workflows into Skills | ||||||
|  |  | ||||||
|  | ## Instructions | ||||||
|  |  | ||||||
|  | ### Step 1: Determine Skill scope | ||||||
|  |  | ||||||
|  | First, understand what the Skill should do: | ||||||
|  |  | ||||||
|  | 1. **Ask clarifying questions**: | ||||||
|  |    - What specific capability should this Skill provide? | ||||||
|  |    - When should Claude use this Skill? | ||||||
|  |    - What tools or resources does it need? | ||||||
|  |    - Is this for personal use or team sharing? | ||||||
|  |  | ||||||
|  | 2. **Keep it focused**: One Skill = one capability | ||||||
|  |    - Good: "PDF form filling", "Excel data analysis" | ||||||
|  |    - Too broad: "Document processing", "Data tools" | ||||||
|  |  | ||||||
|  | ### Step 2: Choose Skill location | ||||||
|  |  | ||||||
|  | Determine where to create the Skill: | ||||||
|  |  | ||||||
|  | **Personal Skills** (`~/.claude/skills/`): | ||||||
|  | - Individual workflows and preferences | ||||||
|  | - Experimental Skills | ||||||
|  | - Personal productivity tools | ||||||
|  |  | ||||||
|  | **Project Skills** (`.claude/skills/`): | ||||||
|  | - Team workflows and conventions | ||||||
|  | - Project-specific expertise | ||||||
|  | - Shared utilities (committed to git) | ||||||
|  |  | ||||||
|  | ### Step 3: Create Skill structure | ||||||
|  |  | ||||||
|  | Create the directory and files: | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | # Personal | ||||||
|  | mkdir -p ~/.claude/skills/skill-name | ||||||
|  |  | ||||||
|  | # Project | ||||||
|  | mkdir -p .claude/skills/skill-name | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | For multi-file Skills: | ||||||
|  | ``` | ||||||
|  | skill-name/ | ||||||
|  | ├── SKILL.md (required) | ||||||
|  | ├── reference.md (optional) | ||||||
|  | ├── examples.md (optional) | ||||||
|  | ├── scripts/ | ||||||
|  | │   └── helper.py (optional) | ||||||
|  | └── templates/ | ||||||
|  |     └── template.txt (optional) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Step 4: Write SKILL.md frontmatter | ||||||
|  |  | ||||||
|  | Create YAML frontmatter with required fields: | ||||||
|  |  | ||||||
|  | ```yaml | ||||||
|  | --- | ||||||
|  | name: skill-name | ||||||
|  | description: Brief description of what this does and when to use it | ||||||
|  | --- | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Field requirements**: | ||||||
|  |  | ||||||
|  | - **name**: | ||||||
|  |   - Lowercase letters, numbers, hyphens only | ||||||
|  |   - Max 64 characters | ||||||
|  |   - Must match directory name | ||||||
|  |   - Good: `pdf-processor`, `git-commit-helper` | ||||||
|  |   - Bad: `PDF_Processor`, `Git Commits!` | ||||||
|  |  | ||||||
|  | - **description**: | ||||||
|  |   - Max 1024 characters | ||||||
|  |   - Include BOTH what it does AND when to use it | ||||||
|  |   - Use specific trigger words users would say | ||||||
|  |   - Mention file types, operations, and context | ||||||
|  |  | ||||||
|  | **Optional frontmatter fields**: | ||||||
|  |  | ||||||
|  | - **allowed-tools**: Restrict tool access (comma-separated list) | ||||||
|  |   ```yaml | ||||||
|  |   allowed-tools: Read, Grep, Glob | ||||||
|  |   ``` | ||||||
|  |   Use for: | ||||||
|  |   - Read-only Skills | ||||||
|  |   - Security-sensitive workflows | ||||||
|  |   - Limited-scope operations | ||||||
|  |  | ||||||
|  | ### Step 5: Write effective descriptions | ||||||
|  |  | ||||||
|  | The description is critical for Claude to discover your Skill. | ||||||
|  |  | ||||||
|  | **Formula**: `[What it does] + [When to use it] + [Key triggers]` | ||||||
|  |  | ||||||
|  | **Examples**: | ||||||
|  |  | ||||||
|  | ✅ **Good**: | ||||||
|  | ```yaml | ||||||
|  | description: Extract text and tables from PDF files, fill forms, merge documents. Use when working with PDF files or when the user mentions PDFs, forms, or document extraction. | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ✅ **Good**: | ||||||
|  | ```yaml | ||||||
|  | description: Analyze Excel spreadsheets, create pivot tables, and generate charts. Use when working with Excel files, spreadsheets, or analyzing tabular data in .xlsx format. | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ❌ **Too vague**: | ||||||
|  | ```yaml | ||||||
|  | description: Helps with documents | ||||||
|  | description: For data analysis | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Tips**: | ||||||
|  | - Include specific file extensions (.pdf, .xlsx, .json) | ||||||
|  | - Mention common user phrases ("analyze", "extract", "generate") | ||||||
|  | - List concrete operations (not generic verbs) | ||||||
|  | - Add context clues ("Use when...", "For...") | ||||||
|  |  | ||||||
|  | ### Step 6: Structure the Skill content | ||||||
|  |  | ||||||
|  | Use clear Markdown sections: | ||||||
|  |  | ||||||
|  | ```markdown | ||||||
|  | # Skill Name | ||||||
|  |  | ||||||
|  | Brief overview of what this Skill does. | ||||||
|  |  | ||||||
|  | ## Quick start | ||||||
|  |  | ||||||
|  | Provide a simple example to get started immediately. | ||||||
|  |  | ||||||
|  | ## Instructions | ||||||
|  |  | ||||||
|  | Step-by-step guidance for Claude: | ||||||
|  | 1. First step with clear action | ||||||
|  | 2. Second step with expected outcome | ||||||
|  | 3. Handle edge cases | ||||||
|  |  | ||||||
|  | ## Examples | ||||||
|  |  | ||||||
|  | Show concrete usage examples with code or commands. | ||||||
|  |  | ||||||
|  | ## Best practices | ||||||
|  |  | ||||||
|  | - Key conventions to follow | ||||||
|  | - Common pitfalls to avoid | ||||||
|  | - When to use vs. not use | ||||||
|  |  | ||||||
|  | ## Requirements | ||||||
|  |  | ||||||
|  | List any dependencies or prerequisites: | ||||||
|  | ```bash | ||||||
|  | pip install package-name | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Advanced usage | ||||||
|  |  | ||||||
|  | For complex scenarios, see [reference.md](reference.md). | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Step 7: Add supporting files (optional) | ||||||
|  |  | ||||||
|  | Create additional files for progressive disclosure: | ||||||
|  |  | ||||||
|  | **reference.md**: Detailed API docs, advanced options | ||||||
|  | **examples.md**: Extended examples and use cases | ||||||
|  | **scripts/**: Helper scripts and utilities | ||||||
|  | **templates/**: File templates or boilerplate | ||||||
|  |  | ||||||
|  | Reference them from SKILL.md: | ||||||
|  | ```markdown | ||||||
|  | For advanced usage, see [reference.md](reference.md). | ||||||
|  |  | ||||||
|  | Run the helper script: | ||||||
|  | \`\`\`bash | ||||||
|  | python scripts/helper.py input.txt | ||||||
|  | \`\`\` | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Step 8: Validate the Skill | ||||||
|  |  | ||||||
|  | Check these requirements: | ||||||
|  |  | ||||||
|  | ✅ **File structure**: | ||||||
|  | - [ ] SKILL.md exists in correct location | ||||||
|  | - [ ] Directory name matches frontmatter `name` | ||||||
|  |  | ||||||
|  | ✅ **YAML frontmatter**: | ||||||
|  | - [ ] Opening `---` on line 1 | ||||||
|  | - [ ] Closing `---` before content | ||||||
|  | - [ ] Valid YAML (no tabs, correct indentation) | ||||||
|  | - [ ] `name` follows naming rules | ||||||
|  | - [ ] `description` is specific and < 1024 chars | ||||||
|  |  | ||||||
|  | ✅ **Content quality**: | ||||||
|  | - [ ] Clear instructions for Claude | ||||||
|  | - [ ] Concrete examples provided | ||||||
|  | - [ ] Edge cases handled | ||||||
|  | - [ ] Dependencies listed (if any) | ||||||
|  |  | ||||||
|  | ✅ **Testing**: | ||||||
|  | - [ ] Description matches user questions | ||||||
|  | - [ ] Skill activates on relevant queries | ||||||
|  | - [ ] Instructions are clear and actionable | ||||||
|  |  | ||||||
|  | ### Step 9: Test the Skill | ||||||
|  |  | ||||||
|  | 1. **Restart Claude Code** (if running) to load the Skill | ||||||
|  |  | ||||||
|  | 2. **Ask relevant questions** that match the description: | ||||||
|  |    ``` | ||||||
|  |    Can you help me extract text from this PDF? | ||||||
|  |    ``` | ||||||
|  |  | ||||||
|  | 3. **Verify activation**: Claude should use the Skill automatically | ||||||
|  |  | ||||||
|  | 4. **Check behavior**: Confirm Claude follows the instructions correctly | ||||||
|  |  | ||||||
|  | ### Step 10: Debug if needed | ||||||
|  |  | ||||||
|  | If Claude doesn't use the Skill: | ||||||
|  |  | ||||||
|  | 1. **Make description more specific**: | ||||||
|  |    - Add trigger words | ||||||
|  |    - Include file types | ||||||
|  |    - Mention common user phrases | ||||||
|  |  | ||||||
|  | 2. **Check file location**: | ||||||
|  |    ```bash | ||||||
|  |    ls ~/.claude/skills/skill-name/SKILL.md | ||||||
|  |    ls .claude/skills/skill-name/SKILL.md | ||||||
|  |    ``` | ||||||
|  |  | ||||||
|  | 3. **Validate YAML**: | ||||||
|  |    ```bash | ||||||
|  |    cat SKILL.md | head -n 10 | ||||||
|  |    ``` | ||||||
|  |  | ||||||
|  | 4. **Run debug mode**: | ||||||
|  |    ```bash | ||||||
|  |    claude --debug | ||||||
|  |    ``` | ||||||
|  |  | ||||||
|  | ## Common patterns | ||||||
|  |  | ||||||
|  | ### Read-only Skill | ||||||
|  |  | ||||||
|  | ```yaml | ||||||
|  | --- | ||||||
|  | name: code-reader | ||||||
|  | description: Read and analyze code without making changes. Use for code review, understanding codebases, or documentation. | ||||||
|  | allowed-tools: Read, Grep, Glob | ||||||
|  | --- | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Script-based Skill | ||||||
|  |  | ||||||
|  | ```yaml | ||||||
|  | --- | ||||||
|  | name: data-processor | ||||||
|  | description: Process CSV and JSON data files with Python scripts. Use when analyzing data files or transforming datasets. | ||||||
|  | --- | ||||||
|  |  | ||||||
|  | # Data Processor | ||||||
|  |  | ||||||
|  | ## Instructions | ||||||
|  |  | ||||||
|  | 1. Use the processing script: | ||||||
|  | \`\`\`bash | ||||||
|  | python scripts/process.py input.csv --output results.json | ||||||
|  | \`\`\` | ||||||
|  |  | ||||||
|  | 2. Validate output with: | ||||||
|  | \`\`\`bash | ||||||
|  | python scripts/validate.py results.json | ||||||
|  | \`\`\` | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Multi-file Skill with progressive disclosure | ||||||
|  |  | ||||||
|  | ```yaml | ||||||
|  | --- | ||||||
|  | name: api-designer | ||||||
|  | description: Design REST APIs following best practices. Use when creating API endpoints, designing routes, or planning API architecture. | ||||||
|  | --- | ||||||
|  |  | ||||||
|  | # API Designer | ||||||
|  |  | ||||||
|  | Quick start: See [examples.md](examples.md) | ||||||
|  |  | ||||||
|  | Detailed reference: See [reference.md](reference.md) | ||||||
|  |  | ||||||
|  | ## Instructions | ||||||
|  |  | ||||||
|  | 1. Gather requirements | ||||||
|  | 2. Design endpoints (see examples.md) | ||||||
|  | 3. Document with OpenAPI spec | ||||||
|  | 4. Review against best practices (see reference.md) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Best practices for Skill authors | ||||||
|  |  | ||||||
|  | 1. **One Skill, one purpose**: Don't create mega-Skills | ||||||
|  | 2. **Specific descriptions**: Include trigger words users will say | ||||||
|  | 3. **Clear instructions**: Write for Claude, not humans | ||||||
|  | 4. **Concrete examples**: Show real code, not pseudocode | ||||||
|  | 5. **List dependencies**: Mention required packages in description | ||||||
|  | 6. **Test with teammates**: Verify activation and clarity | ||||||
|  | 7. **Version your Skills**: Document changes in content | ||||||
|  | 8. **Use progressive disclosure**: Put advanced details in separate files | ||||||
|  |  | ||||||
|  | ## Validation checklist | ||||||
|  |  | ||||||
|  | Before finalizing a Skill, verify: | ||||||
|  |  | ||||||
|  | - [ ] Name is lowercase, hyphens only, max 64 chars | ||||||
|  | - [ ] Description is specific and < 1024 chars | ||||||
|  | - [ ] Description includes "what" and "when" | ||||||
|  | - [ ] YAML frontmatter is valid | ||||||
|  | - [ ] Instructions are step-by-step | ||||||
|  | - [ ] Examples are concrete and realistic | ||||||
|  | - [ ] Dependencies are documented | ||||||
|  | - [ ] File paths use forward slashes | ||||||
|  | - [ ] Skill activates on relevant queries | ||||||
|  | - [ ] Claude follows instructions correctly | ||||||
|  |  | ||||||
|  | ## Troubleshooting | ||||||
|  |  | ||||||
|  | **Skill doesn't activate**: | ||||||
|  | - Make description more specific with trigger words | ||||||
|  | - Include file types and operations in description | ||||||
|  | - Add "Use when..." clause with user phrases | ||||||
|  |  | ||||||
|  | **Multiple Skills conflict**: | ||||||
|  | - Make descriptions more distinct | ||||||
|  | - Use different trigger words | ||||||
|  | - Narrow the scope of each Skill | ||||||
|  |  | ||||||
|  | **Skill has errors**: | ||||||
|  | - Check YAML syntax (no tabs, proper indentation) | ||||||
|  | - Verify file paths (use forward slashes) | ||||||
|  | - Ensure scripts have execute permissions | ||||||
|  | - List all dependencies | ||||||
|  |  | ||||||
|  | ## Examples | ||||||
|  |  | ||||||
|  | See the documentation for complete examples: | ||||||
|  | - Simple single-file Skill (commit-helper) | ||||||
|  | - Skill with tool permissions (code-reviewer) | ||||||
|  | - Multi-file Skill (pdf-processing) | ||||||
|  |  | ||||||
|  | ## Output format | ||||||
|  |  | ||||||
|  | When creating a Skill, I will: | ||||||
|  |  | ||||||
|  | 1. Ask clarifying questions about scope and requirements | ||||||
|  | 2. Suggest a Skill name and location | ||||||
|  | 3. Create the SKILL.md file with proper frontmatter | ||||||
|  | 4. Include clear instructions and examples | ||||||
|  | 5. Add supporting files if needed | ||||||
|  | 6. Provide testing instructions | ||||||
|  | 7. Validate against all requirements | ||||||
|  |  | ||||||
|  | The result will be a complete, working Skill that follows all best practices and validation rules. | ||||||
							
								
								
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -124,3 +124,10 @@ runs: | |||||||
|       id: login-ecr |       id: login-ecr | ||||||
|       continue-on-error: true |       continue-on-error: true | ||||||
|       uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1 |       uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1 | ||||||
|  |  | ||||||
|  |     - name: Preserve github env variables for use in docker | ||||||
|  |       shell: bash | ||||||
|  |       run: | | ||||||
|  |         env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" | ||||||
|  |         env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" | ||||||
|  |         env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | |||||||
| faffd5cf673615583da6517275e361cb3dbc77e6 | 1752fe6809b74921644866275ab80244b96e80bc | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | |||||||
| 0fa6e3129e61143224663e1ec67980d12b7ec4eb | df6798dfb931ce7c7fe5bed2447cd1092a5981af | ||||||
|  | |||||||
							
								
								
									
										5
									
								
								.github/ci_configs/vllm/Dockerfile
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.github/ci_configs/vllm/Dockerfile
									
									
									
									
										vendored
									
									
								
							| @ -283,6 +283,9 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \ | |||||||
|         uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \ |         uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \ | ||||||
|     fi |     fi | ||||||
|  |  | ||||||
|  | RUN --mount=type=cache,target=/root/.cache/uv \ | ||||||
|  |     uv pip install --system --pre apache-tvm-ffi==0.1.0b15 | ||||||
|  |  | ||||||
| # Install the vllm wheel from previous stage | # Install the vllm wheel from previous stage | ||||||
| RUN --mount=type=cache,target=/root/.cache/uv \ | RUN --mount=type=cache,target=/root/.cache/uv \ | ||||||
|     uv pip install --system /wheels/vllm/*.whl --verbose |     uv pip install --system /wheels/vllm/*.whl --verbose | ||||||
| @ -295,6 +298,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \ | |||||||
| ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0' | ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0' | ||||||
| ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} | ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} | ||||||
|  |  | ||||||
|  | # TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip | ||||||
|  | # see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784 | ||||||
| ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" | ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" | ||||||
| ARG FLASHINFER_GIT_REF="v0.2.14.post1" | ARG FLASHINFER_GIT_REF="v0.2.14.post1" | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										9
									
								
								.github/label_to_label.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								.github/label_to_label.yml
									
									
									
									
										vendored
									
									
								
							| @ -15,6 +15,11 @@ | |||||||
|   - "module: reinplacing" |   - "module: reinplacing" | ||||||
|   then: |   then: | ||||||
|   - "module: pt2-dispatcher" |   - "module: pt2-dispatcher" | ||||||
|  | - any: | ||||||
|  |   - "vllm-compile" | ||||||
|  |   then: | ||||||
|  |   - "module: vllm" | ||||||
|  |   - "oncall: pt2" | ||||||
| - any: | - any: | ||||||
|   - "module: vmap" |   - "module: vmap" | ||||||
|   then: |   then: | ||||||
| @ -27,10 +32,6 @@ | |||||||
|   - "module: pt2 optimizer" |   - "module: pt2 optimizer" | ||||||
|   then: |   then: | ||||||
|   - "module: dynamo" |   - "module: dynamo" | ||||||
| - any: |  | ||||||
|   - "module: flex attention" |  | ||||||
|   then: |  | ||||||
|   - "module: higher order operators" |  | ||||||
| - any: | - any: | ||||||
|   - "module: aotinductor" |   - "module: aotinductor" | ||||||
|   then: |   then: | ||||||
|  | |||||||
							
								
								
									
										22
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										22
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							| @ -22,7 +22,7 @@ CUDA_ARCHES_FULL_VERSION = { | |||||||
|     "12.6": "12.6.3", |     "12.6": "12.6.3", | ||||||
|     "12.8": "12.8.1", |     "12.8": "12.8.1", | ||||||
|     "12.9": "12.9.1", |     "12.9": "12.9.1", | ||||||
|     "13.0": "13.0.0", |     "13.0": "13.0.2", | ||||||
| } | } | ||||||
| CUDA_ARCHES_CUDNN_VERSION = { | CUDA_ARCHES_CUDNN_VERSION = { | ||||||
|     "12.6": "9", |     "12.6": "9", | ||||||
| @ -96,21 +96,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { | |||||||
|         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'" |         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'" | ||||||
|     ), |     ), | ||||||
|     "13.0": ( |     "13.0": ( | ||||||
|         "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | " |         "nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | " | ||||||
|         "nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | " |         "nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | " | ||||||
|         "nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | " |         "nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | " | ||||||
|         "nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | " |         "nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | " | ||||||
|         "nvidia-cublas==13.0.0.19; platform_system == 'Linux' | " |         "nvidia-cublas==13.1.0.3; platform_system == 'Linux' | " | ||||||
|         "nvidia-cufft==12.0.0.15; platform_system == 'Linux' | " |         "nvidia-cufft==12.0.0.61; platform_system == 'Linux' | " | ||||||
|         "nvidia-curand==10.4.0.35; platform_system == 'Linux' | " |         "nvidia-curand==10.4.0.35; platform_system == 'Linux' | " | ||||||
|         "nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | " |         "nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | " | ||||||
|         "nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | " |         "nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | " | ||||||
|         "nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | " |         "nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | " | ||||||
|         "nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | " |         "nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | " | ||||||
|         "nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | " |         "nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | " | ||||||
|         "nvidia-nvtx==13.0.39; platform_system == 'Linux' | " |         "nvidia-nvtx==13.0.85; platform_system == 'Linux' | " | ||||||
|         "nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | " |         "nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | " | ||||||
|         "nvidia-cufile==1.15.0.42; platform_system == 'Linux'" |         "nvidia-cufile==1.15.1.6; platform_system == 'Linux'" | ||||||
|     ), |     ), | ||||||
|     "xpu": ( |     "xpu": ( | ||||||
|         "intel-cmplr-lib-rt==2025.2.1 | " |         "intel-cmplr-lib-rt==2025.2.1 | " | ||||||
|  | |||||||
| @ -79,9 +79,9 @@ jobs: | |||||||
|     runs-on: "windows-11-arm64-preview" |     runs-on: "windows-11-arm64-preview" | ||||||
|     {%- else %} |     {%- else %} | ||||||
|     {%- if branches == "nightly" %} |     {%- if branches == "nightly" %} | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     {%- else %} |     {%- else %} | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral" | ||||||
|     {%- endif %} |     {%- endif %} | ||||||
|     {%- endif %} |     {%- endif %} | ||||||
|     timeout-minutes: !{{ common.timeout_minutes_windows_binary }} |     timeout-minutes: !{{ common.timeout_minutes_windows_binary }} | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -270,7 +270,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_10-cuda-aarch64-13_0 |       build_name: manywheel-py3_10-cuda-aarch64-13_0 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -519,7 +519,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_11-cuda-aarch64-13_0 |       build_name: manywheel-py3_11-cuda-aarch64-13_0 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -768,7 +768,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_12-cuda-aarch64-13_0 |       build_name: manywheel-py3_12-cuda-aarch64-13_0 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -1017,7 +1017,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_13-cuda-aarch64-13_0 |       build_name: manywheel-py3_13-cuda-aarch64-13_0 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -1266,7 +1266,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_13t-cuda-aarch64-13_0 |       build_name: manywheel-py3_13t-cuda-aarch64-13_0 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -1515,7 +1515,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_14-cuda-aarch64-13_0 |       build_name: manywheel-py3_14-cuda-aarch64-13_0 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -1764,7 +1764,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_14t-cuda-aarch64-13_0 |       build_name: manywheel-py3_14t-cuda-aarch64-13_0 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -325,7 +325,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_10-cuda13_0 |       build_name: manywheel-py3_10-cuda13_0 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_10-cuda13_0-test:  # Testing |   manywheel-py3_10-cuda13_0-test:  # Testing | ||||||
| @ -991,7 +991,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_11-cuda13_0 |       build_name: manywheel-py3_11-cuda13_0 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_11-cuda13_0-test:  # Testing |   manywheel-py3_11-cuda13_0-test:  # Testing | ||||||
| @ -1657,7 +1657,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_12-cuda13_0 |       build_name: manywheel-py3_12-cuda13_0 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_12-cuda13_0-test:  # Testing |   manywheel-py3_12-cuda13_0-test:  # Testing | ||||||
| @ -2323,7 +2323,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_13-cuda13_0 |       build_name: manywheel-py3_13-cuda13_0 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_13-cuda13_0-test:  # Testing |   manywheel-py3_13-cuda13_0-test:  # Testing | ||||||
| @ -2989,7 +2989,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_13t-cuda13_0 |       build_name: manywheel-py3_13t-cuda13_0 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_13t-cuda13_0-test:  # Testing |   manywheel-py3_13t-cuda13_0-test:  # Testing | ||||||
| @ -3655,7 +3655,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_14-cuda13_0 |       build_name: manywheel-py3_14-cuda13_0 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_14-cuda13_0-test:  # Testing |   manywheel-py3_14-cuda13_0-test:  # Testing | ||||||
| @ -4321,7 +4321,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_14t-cuda13_0 |       build_name: manywheel-py3_14t-cuda13_0 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_14t-cuda13_0-test:  # Testing |   manywheel-py3_14t-cuda13_0-test:  # Testing | ||||||
|  | |||||||
							
								
								
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -44,7 +44,7 @@ jobs: | |||||||
|   libtorch-cpu-shared-with-deps-debug-build: |   libtorch-cpu-shared-with-deps-debug-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -291,7 +291,7 @@ jobs: | |||||||
|   libtorch-cuda12_6-shared-with-deps-debug-build: |   libtorch-cuda12_6-shared-with-deps-debug-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -541,7 +541,7 @@ jobs: | |||||||
|   libtorch-cuda12_8-shared-with-deps-debug-build: |   libtorch-cuda12_8-shared-with-deps-debug-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -791,7 +791,7 @@ jobs: | |||||||
|   libtorch-cuda13_0-shared-with-deps-debug-build: |   libtorch-cuda13_0-shared-with-deps-debug-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
|  | |||||||
							
								
								
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -44,7 +44,7 @@ jobs: | |||||||
|   libtorch-cpu-shared-with-deps-release-build: |   libtorch-cpu-shared-with-deps-release-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -291,7 +291,7 @@ jobs: | |||||||
|   libtorch-cuda12_6-shared-with-deps-release-build: |   libtorch-cuda12_6-shared-with-deps-release-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -541,7 +541,7 @@ jobs: | |||||||
|   libtorch-cuda12_8-shared-with-deps-release-build: |   libtorch-cuda12_8-shared-with-deps-release-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -791,7 +791,7 @@ jobs: | |||||||
|   libtorch-cuda13_0-shared-with-deps-release-build: |   libtorch-cuda13_0-shared-with-deps-release-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
|  | |||||||
							
								
								
									
										70
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										70
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -44,7 +44,7 @@ jobs: | |||||||
|   wheel-py3_10-cpu-build: |   wheel-py3_10-cpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -279,7 +279,7 @@ jobs: | |||||||
|   wheel-py3_10-cuda12_6-build: |   wheel-py3_10-cuda12_6-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -517,7 +517,7 @@ jobs: | |||||||
|   wheel-py3_10-cuda12_8-build: |   wheel-py3_10-cuda12_8-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -755,7 +755,7 @@ jobs: | |||||||
|   wheel-py3_10-cuda13_0-build: |   wheel-py3_10-cuda13_0-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -993,7 +993,7 @@ jobs: | |||||||
|   wheel-py3_10-xpu-build: |   wheel-py3_10-xpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -1229,7 +1229,7 @@ jobs: | |||||||
|   wheel-py3_11-cpu-build: |   wheel-py3_11-cpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -1464,7 +1464,7 @@ jobs: | |||||||
|   wheel-py3_11-cuda12_6-build: |   wheel-py3_11-cuda12_6-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -1702,7 +1702,7 @@ jobs: | |||||||
|   wheel-py3_11-cuda12_8-build: |   wheel-py3_11-cuda12_8-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -1940,7 +1940,7 @@ jobs: | |||||||
|   wheel-py3_11-cuda13_0-build: |   wheel-py3_11-cuda13_0-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -2178,7 +2178,7 @@ jobs: | |||||||
|   wheel-py3_11-xpu-build: |   wheel-py3_11-xpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -2414,7 +2414,7 @@ jobs: | |||||||
|   wheel-py3_12-cpu-build: |   wheel-py3_12-cpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -2649,7 +2649,7 @@ jobs: | |||||||
|   wheel-py3_12-cuda12_6-build: |   wheel-py3_12-cuda12_6-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -2887,7 +2887,7 @@ jobs: | |||||||
|   wheel-py3_12-cuda12_8-build: |   wheel-py3_12-cuda12_8-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -3125,7 +3125,7 @@ jobs: | |||||||
|   wheel-py3_12-cuda13_0-build: |   wheel-py3_12-cuda13_0-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -3363,7 +3363,7 @@ jobs: | |||||||
|   wheel-py3_12-xpu-build: |   wheel-py3_12-xpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -3599,7 +3599,7 @@ jobs: | |||||||
|   wheel-py3_13-cpu-build: |   wheel-py3_13-cpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -3834,7 +3834,7 @@ jobs: | |||||||
|   wheel-py3_13-cuda12_6-build: |   wheel-py3_13-cuda12_6-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -4072,7 +4072,7 @@ jobs: | |||||||
|   wheel-py3_13-cuda12_8-build: |   wheel-py3_13-cuda12_8-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -4310,7 +4310,7 @@ jobs: | |||||||
|   wheel-py3_13-cuda13_0-build: |   wheel-py3_13-cuda13_0-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -4548,7 +4548,7 @@ jobs: | |||||||
|   wheel-py3_13-xpu-build: |   wheel-py3_13-xpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -4784,7 +4784,7 @@ jobs: | |||||||
|   wheel-py3_13t-cpu-build: |   wheel-py3_13t-cpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -5019,7 +5019,7 @@ jobs: | |||||||
|   wheel-py3_13t-cuda12_6-build: |   wheel-py3_13t-cuda12_6-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -5257,7 +5257,7 @@ jobs: | |||||||
|   wheel-py3_13t-cuda12_8-build: |   wheel-py3_13t-cuda12_8-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -5495,7 +5495,7 @@ jobs: | |||||||
|   wheel-py3_13t-cuda13_0-build: |   wheel-py3_13t-cuda13_0-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -5733,7 +5733,7 @@ jobs: | |||||||
|   wheel-py3_13t-xpu-build: |   wheel-py3_13t-xpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -5969,7 +5969,7 @@ jobs: | |||||||
|   wheel-py3_14-cpu-build: |   wheel-py3_14-cpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -6204,7 +6204,7 @@ jobs: | |||||||
|   wheel-py3_14-cuda12_6-build: |   wheel-py3_14-cuda12_6-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -6442,7 +6442,7 @@ jobs: | |||||||
|   wheel-py3_14-cuda12_8-build: |   wheel-py3_14-cuda12_8-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -6680,7 +6680,7 @@ jobs: | |||||||
|   wheel-py3_14-cuda13_0-build: |   wheel-py3_14-cuda13_0-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -6918,7 +6918,7 @@ jobs: | |||||||
|   wheel-py3_14-xpu-build: |   wheel-py3_14-xpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -7154,7 +7154,7 @@ jobs: | |||||||
|   wheel-py3_14t-cpu-build: |   wheel-py3_14t-cpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -7389,7 +7389,7 @@ jobs: | |||||||
|   wheel-py3_14t-cuda12_6-build: |   wheel-py3_14t-cuda12_6-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -7627,7 +7627,7 @@ jobs: | |||||||
|   wheel-py3_14t-cuda12_8-build: |   wheel-py3_14t-cuda12_8-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -7865,7 +7865,7 @@ jobs: | |||||||
|   wheel-py3_14t-cuda13_0-build: |   wheel-py3_14t-cuda13_0-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
| @ -8103,7 +8103,7 @@ jobs: | |||||||
|   wheel-py3_14t-xpu-build: |   wheel-py3_14t-xpu-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||||
|     timeout-minutes: 360 |     timeout-minutes: 360 | ||||||
|     env: |     env: | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							| @ -88,7 +88,6 @@ jobs: | |||||||
|     with: |     with: | ||||||
|       build-environment: linux-jammy-rocm-py3_10 |       build-environment: linux-jammy-rocm-py3_10 | ||||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks |       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks | ||||||
|       sync-tag: rocm-build |  | ||||||
|       test-matrix: | |       test-matrix: | | ||||||
|         { include: [ |         { include: [ | ||||||
|           { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, |           { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, | ||||||
|  | |||||||
							
								
								
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							| @ -147,15 +147,16 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug |       build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug | ||||||
|       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 |       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 | ||||||
|  |       cuda-arch-list: 8.9 | ||||||
|       test-matrix: | |       test-matrix: | | ||||||
|         { include: [ |         { include: [ | ||||||
|           { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, |           { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||||
|           { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, |           { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||||
|           { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, |           { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||||
|           { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, |           { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||||
|           { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, |           { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||||
|           { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, |           { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||||
|           { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, |           { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||||
|         ]} |         ]} | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							| @ -347,7 +347,8 @@ jobs: | |||||||
|     uses: ./.github/workflows/_linux-build.yml |     uses: ./.github/workflows/_linux-build.yml | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|     with: |     with: | ||||||
|       sync-tag: linux-xpu-n-build |       # This should sync with the build in xpu.yml but xpu uses a larger runner | ||||||
|  |       # sync-tag: linux-xpu-n-build | ||||||
|       runner_prefix: ${{ needs.get-label-type.outputs.label-type }} |       runner_prefix: ${{ needs.get-label-type.outputs.label-type }} | ||||||
|       build-environment: linux-jammy-xpu-n-py3.10 |       build-environment: linux-jammy-xpu-n-py3.10 | ||||||
|       docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3 |       docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3 | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							| @ -45,7 +45,6 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build-environment: linux-noble-rocm-py3.12-mi300 |       build-environment: linux-noble-rocm-py3.12-mi300 | ||||||
|       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 |       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 | ||||||
|       sync-tag: rocm-build |  | ||||||
|       test-matrix: | |       test-matrix: | | ||||||
|         { include: [ |         { include: [ | ||||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, |           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							| @ -42,7 +42,6 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build-environment: linux-noble-rocm-py3.12-mi355 |       build-environment: linux-noble-rocm-py3.12-mi355 | ||||||
|       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 |       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 | ||||||
|       sync-tag: rocm-build |  | ||||||
|       test-matrix: | |       test-matrix: | | ||||||
|         { include: [ |         { include: [ | ||||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, |           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|  | |||||||
							
								
								
									
										12
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
								
							| @ -26,11 +26,23 @@ jobs: | |||||||
|       id-token: write |       id-token: write | ||||||
|       contents: read |       contents: read | ||||||
|  |  | ||||||
|  |   get-label-type: | ||||||
|  |     name: get-label-type | ||||||
|  |     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main | ||||||
|  |     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||||
|  |     with: | ||||||
|  |       triggering_actor: ${{ github.triggering_actor }} | ||||||
|  |       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} | ||||||
|  |       curr_branch: ${{ github.head_ref || github.ref_name }} | ||||||
|  |       curr_ref_type: ${{ github.ref_type }} | ||||||
|  |  | ||||||
|   linux-jammy-rocm-py3_10-build: |   linux-jammy-rocm-py3_10-build: | ||||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} |     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||||
|     name: linux-jammy-rocm-py3.10 |     name: linux-jammy-rocm-py3.10 | ||||||
|     uses: ./.github/workflows/_linux-build.yml |     uses: ./.github/workflows/_linux-build.yml | ||||||
|  |     needs: get-label-type | ||||||
|     with: |     with: | ||||||
|  |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build-environment: linux-jammy-rocm-py3.10 |       build-environment: linux-jammy-rocm-py3.10 | ||||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 |       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||||
|       sync-tag: rocm-build |       sync-tag: rocm-build | ||||||
|  | |||||||
							
								
								
									
										12
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							| @ -26,11 +26,23 @@ jobs: | |||||||
|       id-token: write |       id-token: write | ||||||
|       contents: read |       contents: read | ||||||
|  |  | ||||||
|  |   get-label-type: | ||||||
|  |     name: get-label-type | ||||||
|  |     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main | ||||||
|  |     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||||
|  |     with: | ||||||
|  |       triggering_actor: ${{ github.triggering_actor }} | ||||||
|  |       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} | ||||||
|  |       curr_branch: ${{ github.head_ref || github.ref_name }} | ||||||
|  |       curr_ref_type: ${{ github.ref_type }} | ||||||
|  |  | ||||||
|   linux-jammy-rocm-py3_10-build: |   linux-jammy-rocm-py3_10-build: | ||||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} |     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||||
|     name: linux-jammy-rocm-py3.10 |     name: linux-jammy-rocm-py3.10 | ||||||
|     uses: ./.github/workflows/_linux-build.yml |     uses: ./.github/workflows/_linux-build.yml | ||||||
|  |     needs: get-label-type | ||||||
|     with: |     with: | ||||||
|  |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build-environment: linux-jammy-rocm-py3.10 |       build-environment: linux-jammy-rocm-py3.10 | ||||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 |       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||||
|       sync-tag: rocm-build |       sync-tag: rocm-build | ||||||
|  | |||||||
							
								
								
									
										147
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										147
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							| @ -58,8 +58,10 @@ jobs: | |||||||
|           else |           else | ||||||
|             COMMIT_SHA="${{ github.sha }}" |             COMMIT_SHA="${{ github.sha }}" | ||||||
|           fi |           fi | ||||||
|           echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}" |           { | ||||||
|           echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}" |             echo "sha=${COMMIT_SHA}" | ||||||
|  |             echo "tag_name=trunk/${COMMIT_SHA}" | ||||||
|  |           } >> "${GITHUB_OUTPUT}" | ||||||
|  |  | ||||||
|       - name: Validate commit SHA |       - name: Validate commit SHA | ||||||
|         run: | |         run: | | ||||||
| @ -87,7 +89,7 @@ jobs: | |||||||
|             echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)" |             echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)" | ||||||
|           fi |           fi | ||||||
|  |  | ||||||
|       - name: Create and push tag with retry |       - name: Create and push tag(s) with retry | ||||||
|         id: check_tag |         id: check_tag | ||||||
|         env: |         env: | ||||||
|           TAG_NAME: ${{ steps.commit.outputs.tag_name }} |           TAG_NAME: ${{ steps.commit.outputs.tag_name }} | ||||||
| @ -112,14 +114,23 @@ jobs: | |||||||
|             return 1 |             return 1 | ||||||
|           } |           } | ||||||
|  |  | ||||||
|           # Exit early if tag already exists |           # Counters for summary reporting | ||||||
|           if check_tag_exists; then |           created_count=0 | ||||||
|             echo "✅ Tag already exists - no action needed" |           skipped_count=0 | ||||||
|             echo "exists=true" >> "${GITHUB_OUTPUT}" |           failed_count=0 | ||||||
|             exit 0 |  | ||||||
|           fi |  | ||||||
|  |  | ||||||
|           echo "Tag ${TAG_NAME} does not exist, proceeding with creation" |           # Always write outputs once on exit | ||||||
|  |           finish() { | ||||||
|  |             set +e | ||||||
|  |             if [ -n "${GITHUB_OUTPUT:-}" ]; then | ||||||
|  |               { | ||||||
|  |                 echo "created_count=${created_count}" | ||||||
|  |                 echo "skipped_count=${skipped_count}" | ||||||
|  |                 echo "failed_count=${failed_count}" | ||||||
|  |               } >> "${GITHUB_OUTPUT}" | ||||||
|  |             fi | ||||||
|  |           } | ||||||
|  |           trap finish EXIT | ||||||
|  |  | ||||||
|           # Retry configuration |           # Retry configuration | ||||||
|           MAX_RETRIES=5 |           MAX_RETRIES=5 | ||||||
| @ -194,31 +205,111 @@ jobs: | |||||||
|             } |             } | ||||||
|           } |           } | ||||||
|  |  | ||||||
|           # Execute with retry |           # New behavior for push events: enumerate commits in the push and tag each one. | ||||||
|           if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then |           # For workflow_dispatch, retain existing single-SHA behavior. | ||||||
|             echo "exists=false" >> "${GITHUB_OUTPUT}" |  | ||||||
|  |           # Always fetch tags once up front to improve idempotency in loops | ||||||
|  |           git fetch origin --tags --quiet || true | ||||||
|  |  | ||||||
|  |           if [ "${{ github.event_name }}" = "push" ]; then | ||||||
|  |             BEFORE_SHA="${{ github.event.before }}" | ||||||
|  |             AFTER_SHA="${{ github.sha }}"  # same as event.after | ||||||
|  |  | ||||||
|  |             # List commits introduced by this push (old..new), oldest first for stable ordering | ||||||
|  |             commits_file="$(mktemp)" | ||||||
|  |             git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}" | ||||||
|  |  | ||||||
|  |             if [ ! -s "${commits_file}" ]; then | ||||||
|  |               echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag." | ||||||
|  |               rm -f "${commits_file}" | ||||||
|  |               exit 0 | ||||||
|  |             fi | ||||||
|  |  | ||||||
|  |             commit_count="$(wc -l < "${commits_file}" | tr -d ' ')" | ||||||
|  |             echo "Found ${commit_count} commit(s) to tag for push:" | ||||||
|  |             while IFS= read -r sha; do | ||||||
|  |               printf '  %s\n' "${sha}" | ||||||
|  |             done < "${commits_file}" | ||||||
|  |  | ||||||
|  |             while IFS= read -r sha; do | ||||||
|  |               TAG_NAME="trunk/${sha}" | ||||||
|  |               COMMIT_SHA="${sha}" | ||||||
|  |  | ||||||
|  |               # If tag already exists locally or remotely, skip (idempotent) | ||||||
|  |               if check_tag_exists; then | ||||||
|  |                 echo "✅ Tag ${TAG_NAME} already exists - skipping" | ||||||
|  |                 skipped_count=$((skipped_count + 1)) | ||||||
|  |                 continue | ||||||
|  |               fi | ||||||
|  |  | ||||||
|  |               echo "Tag ${TAG_NAME} does not exist, proceeding with creation" | ||||||
|  |  | ||||||
|  |               if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then | ||||||
|  |                 created_count=$((created_count + 1)) | ||||||
|  |               else | ||||||
|  |                 echo "Tag creation failed after all retry attempts for ${TAG_NAME}" | ||||||
|  |                 failed_count=$((failed_count + 1)) | ||||||
|  |               fi | ||||||
|  |             done < "${commits_file}" | ||||||
|  |  | ||||||
|  |             rm -f "${commits_file}" | ||||||
|  |  | ||||||
|  |             if [ "${failed_count}" -gt 0 ]; then | ||||||
|  |               exit 1 | ||||||
|  |             fi | ||||||
|             exit 0 |             exit 0 | ||||||
|           else |           else | ||||||
|             echo "Tag creation failed after all retry attempts" |             # workflow_dispatch path (single SHA tagging preserved) | ||||||
|             exit 1 |  | ||||||
|  |             # Exit early if tag already exists | ||||||
|  |             if check_tag_exists; then | ||||||
|  |               echo "✅ Tag already exists - no action needed" | ||||||
|  |               skipped_count=1 | ||||||
|  |               exit 0 | ||||||
|  |             fi | ||||||
|  |  | ||||||
|  |             echo "Tag ${TAG_NAME} does not exist, proceeding with creation" | ||||||
|  |  | ||||||
|  |             if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then | ||||||
|  |               created_count=1 | ||||||
|  |               exit 0 | ||||||
|  |             else | ||||||
|  |               echo "Tag creation failed after all retry attempts" | ||||||
|  |               failed_count=1 | ||||||
|  |               exit 1 | ||||||
|  |             fi | ||||||
|           fi |           fi | ||||||
|  |  | ||||||
|       - name: Tag creation summary |       - name: Tag creation summary | ||||||
|         if: always() |         if: always() | ||||||
|         run: | |         run: | | ||||||
|           if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then |           if [ "${{ github.event_name }}" = "push" ]; then | ||||||
|             echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed" |             echo "Trigger: push on main" | ||||||
|           elif [ "${{ job.status }}" = "success" ]; then |             echo "Created: ${{ steps.check_tag.outputs.created_count }}" | ||||||
|             echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" |             echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}" | ||||||
|  |             echo "Failed: ${{ steps.check_tag.outputs.failed_count }}" | ||||||
|  |             if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then | ||||||
|  |               echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}" | ||||||
|  |             else | ||||||
|  |               echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}" | ||||||
|  |             fi | ||||||
|           else |           else | ||||||
|             echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" |             if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then | ||||||
|           fi |               if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then | ||||||
|  |                 echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed" | ||||||
|  |               else | ||||||
|  |                 echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||||
|  |               fi | ||||||
|  |             else | ||||||
|  |               echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||||
|  |             fi | ||||||
|  |  | ||||||
|           echo "" |             echo "" | ||||||
|           echo "Tag details:" |             echo "Tag details:" | ||||||
|           echo "  Name: ${{ steps.commit.outputs.tag_name }}" |             echo "  Name: ${{ steps.commit.outputs.tag_name }}" | ||||||
|           echo "  Commit: ${{ steps.commit.outputs.sha }}" |             echo "  Commit: ${{ steps.commit.outputs.sha }}" | ||||||
|           echo "  Trigger: ${{ github.event_name }}" |             echo "  Trigger: ${{ github.event_name }}" | ||||||
|           if [ -n "${{ github.event.inputs.commit_sha }}" ]; then |             if [ -n "${{ github.event.inputs.commit_sha }}" ]; then | ||||||
|             echo "  Manual commit: ${{ github.event.inputs.commit_sha }}" |               echo "  Manual commit: ${{ github.event.inputs.commit_sha }}" | ||||||
|  |             fi | ||||||
|           fi |           fi | ||||||
|  | |||||||
| @ -833,8 +833,7 @@ exclude_patterns = [ | |||||||
| command = [ | command = [ | ||||||
|     'python3', |     'python3', | ||||||
|     'tools/linter/adapters/grep_linter.py', |     'tools/linter/adapters/grep_linter.py', | ||||||
|     '--pattern=cudaSetDevice(', |     '--pattern=(cudaSetDevice|cudaGetDevice)\\(', | ||||||
|     '--pattern=cudaGetDevice(', |  | ||||||
|     '--linter-name=RAWCUDADEVICE', |     '--linter-name=RAWCUDADEVICE', | ||||||
|     '--error-name=raw CUDA API usage', |     '--error-name=raw CUDA API usage', | ||||||
|     """--error-description=\ |     """--error-description=\ | ||||||
| @ -1138,11 +1137,8 @@ command = [ | |||||||
| [[linter]] | [[linter]] | ||||||
| code = 'WORKFLOWSYNC' | code = 'WORKFLOWSYNC' | ||||||
| include_patterns = [ | include_patterns = [ | ||||||
|     '.github/workflows/pull.yml', |     '.github/workflows/*.yml', | ||||||
|     '.github/workflows/trunk.yml', |     '.github/workflows/*.yaml', | ||||||
|     '.github/workflows/periodic.yml', |  | ||||||
|     '.github/workflows/mac-mps.yml', |  | ||||||
|     '.github/workflows/slow.yml', |  | ||||||
| ] | ] | ||||||
| command = [ | command = [ | ||||||
|     'python3', |     'python3', | ||||||
|  | |||||||
| @ -31,9 +31,9 @@ Be careful when running untrusted models. This classification includes models cr | |||||||
|  |  | ||||||
| **Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing). | **Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing). | ||||||
|  |  | ||||||
| **Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) with `weights_only=True` is also secure to our knowledge even though it offers significantly larger surface of attack. Loading un-trusted checkpoint with `weights_only=False` MUST never be done. | **Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details. | ||||||
|  |  | ||||||
|  |  | ||||||
|  | Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs. | ||||||
|  |  | ||||||
| Important Note: The trustworthiness of a model is not binary. You must always determine the proper level of caution depending on the specific model and how it matches your use case and risk tolerance. | Important Note: The trustworthiness of a model is not binary. You must always determine the proper level of caution depending on the specific model and how it matches your use case and risk tolerance. | ||||||
|  |  | ||||||
|  | |||||||
| @ -38,7 +38,7 @@ set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT) | |||||||
|  |  | ||||||
| configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h") | configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h") | ||||||
| # TODO: Do not generate CUDAConfig.h for ROCm BUILDS | # TODO: Do not generate CUDAConfig.h for ROCm BUILDS | ||||||
| # At the moment, `jit_macors.h` include CUDAConfig.h for both CUDA and HIP builds | # At the moment, `jit_macros.h` include CUDAConfig.h for both CUDA and HIP builds | ||||||
| if(USE_CUDA OR USE_ROCM) | if(USE_CUDA OR USE_ROCM) | ||||||
|   configure_file(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h") |   configure_file(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h") | ||||||
| endif() | endif() | ||||||
| @ -289,14 +289,15 @@ IF(USE_FBGEMM_GENAI) | |||||||
|  |  | ||||||
|     set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON) |     set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON) | ||||||
|  |  | ||||||
|     set(fbgemm_genai_mx8mx8bf16_grouped |     set(fbgemm_genai_cuh | ||||||
|       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/" |       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/" | ||||||
|  |       "${FBGEMM_GENAI_SRCS}/" | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     target_include_directories(fbgemm_genai PRIVATE |     target_include_directories(fbgemm_genai PRIVATE | ||||||
|       ${FBGEMM_THIRD_PARTY}/cutlass/include |       ${FBGEMM_THIRD_PARTY}/cutlass/include | ||||||
|       ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include |       ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include | ||||||
|       ${fbgemm_genai_mx8mx8bf16_grouped} |       ${fbgemm_genai_cuh} | ||||||
|       ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp |       ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp | ||||||
|       ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h |       ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h | ||||||
|     ) |     ) | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ | |||||||
| #include <ATen/detail/MPSHooksInterface.h> | #include <ATen/detail/MPSHooksInterface.h> | ||||||
| #include <ATen/detail/MTIAHooksInterface.h> | #include <ATen/detail/MTIAHooksInterface.h> | ||||||
| #include <ATen/detail/PrivateUse1HooksInterface.h> | #include <ATen/detail/PrivateUse1HooksInterface.h> | ||||||
|  | #include <ATen/detail/XLAHooksInterface.h> | ||||||
| #include <ATen/detail/XPUHooksInterface.h> | #include <ATen/detail/XPUHooksInterface.h> | ||||||
| #include <c10/core/QEngine.h> | #include <c10/core/QEngine.h> | ||||||
| #include <c10/core/impl/DeviceGuardImplInterface.h> | #include <c10/core/impl/DeviceGuardImplInterface.h> | ||||||
| @ -88,6 +89,8 @@ class TORCH_API Context { | |||||||
|       return at::detail::getHIPHooks(); |       return at::detail::getHIPHooks(); | ||||||
|     } else if (opt_device_type == at::kHPU) { |     } else if (opt_device_type == at::kHPU) { | ||||||
|       return at::detail::getHPUHooks(); |       return at::detail::getHPUHooks(); | ||||||
|  |     } else if (opt_device_type == at::kXLA) { | ||||||
|  |       return at::detail::getXLAHooks(); | ||||||
|     } else { |     } else { | ||||||
|       TORCH_CHECK( |       TORCH_CHECK( | ||||||
|           false, |           false, | ||||||
| @ -196,7 +199,7 @@ class TORCH_API Context { | |||||||
|     return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU); |     return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU); | ||||||
|   } |   } | ||||||
|   static bool hasXLA() { |   static bool hasXLA() { | ||||||
|     return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA); |     return detail::getXLAHooks().hasXLA(); | ||||||
|   } |   } | ||||||
|   static bool hasXPU() { |   static bool hasXPU() { | ||||||
|     return detail::getXPUHooks().hasXPU(); |     return detail::getXPUHooks().hasXPU(); | ||||||
|  | |||||||
| @ -122,7 +122,7 @@ void FunctionalTensorWrapper::freeze_storage() const { | |||||||
| //          |   have their own storages, but backends like functorch      | | //          |   have their own storages, but backends like functorch      | | ||||||
| //         \/   are allowed to re-alias underneath the pass               \/ | //         \/   are allowed to re-alias underneath the pass               \/ | ||||||
| // . - - - - - - - - - - - - - .                             . - - - - - - - - - - - - - - - . | // . - - - - - - - - - - - - - .                             . - - - - - - - - - - - - - - - . | ||||||
| // |    underyling_storage     |                             |      underyling_storage       | | // |    underlying_storage     |                             |      underlying_storage       | | ||||||
| // . - - - - - - - - - - - - - .                             . - - - - - - - - - - - - - - - . | // . - - - - - - - - - - - - - .                             . - - - - - - - - - - - - - - - . | ||||||
| // | // | ||||||
| // This constructor is only used by view ops. | // This constructor is only used by view ops. | ||||||
|  | |||||||
| @ -1534,7 +1534,7 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) { | |||||||
|  |  | ||||||
|   // XLA and lazy tensors don't have storage, so they don't have an underlying data pointer. |   // XLA and lazy tensors don't have storage, so they don't have an underlying data pointer. | ||||||
|   // Nothing beyond this point is important for meta functions, so it's fine to exit early here. |   // Nothing beyond this point is important for meta functions, so it's fine to exit early here. | ||||||
|   // Extend the condition to MAIA tesnors as MAIA tensors also don't have storage. |   // Extend the condition to MAIA tensors as MAIA tensors also don't have storage. | ||||||
|   if (privateuse1_without_storage  || |   if (privateuse1_without_storage  || | ||||||
|       common_device_.type() == DeviceType::XLA  || |       common_device_.type() == DeviceType::XLA  || | ||||||
|       common_device_.type() == DeviceType::IPU  || |       common_device_.type() == DeviceType::IPU  || | ||||||
|  | |||||||
| @ -39,7 +39,7 @@ struct HostBlock { | |||||||
| }; | }; | ||||||
|  |  | ||||||
| template <typename B> | template <typename B> | ||||||
| struct alignas(64) FreeBlockList { | struct alignas(hardware_destructive_interference_size) FreeBlockList { | ||||||
|   std::mutex mutex_; |   std::mutex mutex_; | ||||||
|   std::deque<B*> list_; |   std::deque<B*> list_; | ||||||
| }; | }; | ||||||
| @ -94,11 +94,11 @@ struct PinnedReserveSegment { | |||||||
| struct TORCH_API HostStats { | struct TORCH_API HostStats { | ||||||
|   // COUNT: total allocations (active) |   // COUNT: total allocations (active) | ||||||
|   Stat active_requests; |   Stat active_requests; | ||||||
|   // SUM: bytes allocated/reserved by this memory alocator. (active) |   // SUM: bytes allocated/reserved by this memory allocator. (active) | ||||||
|   Stat active_bytes; |   Stat active_bytes; | ||||||
|   // COUNT: total allocations (active + free) |   // COUNT: total allocations (active + free) | ||||||
|   Stat allocations; |   Stat allocations; | ||||||
|   // SUM: bytes allocated/reserved by this memory alocator. This accounts |   // SUM: bytes allocated/reserved by this memory allocator. This accounts | ||||||
|   // for both free and in-use blocks. |   // for both free and in-use blocks. | ||||||
|   Stat allocated_bytes; |   Stat allocated_bytes; | ||||||
|  |  | ||||||
| @ -122,12 +122,12 @@ struct TORCH_API HostStats { | |||||||
| // Struct containing memory allocator summary statistics for host, as they | // Struct containing memory allocator summary statistics for host, as they | ||||||
| // are staged for reporting. This is a temporary struct that is used to | // are staged for reporting. This is a temporary struct that is used to | ||||||
| // avoid locking the allocator while collecting stats. | // avoid locking the allocator while collecting stats. | ||||||
| struct alignas(64) HostStatsStaged { | struct alignas(hardware_destructive_interference_size) HostStatsStaged { | ||||||
|   std::mutex timing_mutex_; |   std::mutex timing_mutex_; | ||||||
|   // COUNT: total allocations (active + free) |   // COUNT: total allocations (active + free) | ||||||
|   // LOCK: access to this stat is protected by the allocator's blocks_mutex_ |   // LOCK: access to this stat is protected by the allocator's blocks_mutex_ | ||||||
|   Stat allocations; |   Stat allocations; | ||||||
|   // SUM: bytes allocated/reserved by this memory alocator. This accounts |   // SUM: bytes allocated/reserved by this memory allocator. This accounts | ||||||
|   // for both free and in-use blocks. |   // for both free and in-use blocks. | ||||||
|   Stat allocated_bytes; |   Stat allocated_bytes; | ||||||
|   // COUNT: number of allocations per bucket (active) |   // COUNT: number of allocations per bucket (active) | ||||||
| @ -455,7 +455,7 @@ struct CachingHostAllocatorImpl { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   void resetAccumulatedStats() { |   void resetAccumulatedStats() { | ||||||
|     // Reseting accumulated memory stats requires concurrently holding both the |     // Resetting accumulated memory stats requires concurrently holding both the | ||||||
|     // free list mutexes and the blocks mutex. Previously, this was only done in |     // free list mutexes and the blocks mutex. Previously, this was only done in | ||||||
|     // empty_cache function. |     // empty_cache function. | ||||||
|     for (size_t i = 0; i < free_list_.size(); ++i) { |     for (size_t i = 0; i < free_list_.size(); ++i) { | ||||||
| @ -482,7 +482,7 @@ struct CachingHostAllocatorImpl { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   void resetPeakStats() { |   void resetPeakStats() { | ||||||
|     // Reseting peak memory stats requires concurrently holding both the |     // Resetting peak memory stats requires concurrently holding both the | ||||||
|     // free list mutexes and the blocks mutex. Previously, this was only done in |     // free list mutexes and the blocks mutex. Previously, this was only done in | ||||||
|     // empty_cache function. |     // empty_cache function. | ||||||
|     for (size_t i = 0; i < free_list_.size(); ++i) { |     for (size_t i = 0; i < free_list_.size(); ++i) { | ||||||
| @ -669,7 +669,7 @@ struct CachingHostAllocatorImpl { | |||||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); |     TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   alignas(64) std::mutex blocks_mutex_; |   alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_; | ||||||
|   ska::flat_hash_set<B*> blocks_; // block list |   ska::flat_hash_set<B*> blocks_; // block list | ||||||
|   ska::flat_hash_map<void*, B*> ptr_to_block_; |   ska::flat_hash_map<void*, B*> ptr_to_block_; | ||||||
|  |  | ||||||
| @ -677,17 +677,17 @@ struct CachingHostAllocatorImpl { | |||||||
|   // size. This allows us to quickly find a free block of the right size. |   // size. This allows us to quickly find a free block of the right size. | ||||||
|   // We use deque to store per size free list and guard the list with its own |   // We use deque to store per size free list and guard the list with its own | ||||||
|   // mutex. |   // mutex. | ||||||
|   alignas(64) std::vector<FreeBlockList<B>> free_list_ = |   alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ = | ||||||
|       std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX); |       std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX); | ||||||
|  |  | ||||||
|   alignas(64) std::mutex events_mutex_; |   alignas(hardware_destructive_interference_size) std::mutex events_mutex_; | ||||||
|   std::deque<std::pair<E, B*>> events_; // event queue paired with block |   std::deque<std::pair<E, B*>> events_; // event queue paired with block | ||||||
|  |  | ||||||
|   // Indicates whether the object is active. |   // Indicates whether the object is active. | ||||||
|   // Set to false in the destructor to signal background threads to stop. |   // Set to false in the destructor to signal background threads to stop. | ||||||
|   std::atomic<bool> active_{true}; |   std::atomic<bool> active_{true}; | ||||||
| protected: | protected: | ||||||
|   alignas(64) HostStatsStaged stats_; |   alignas(hardware_destructive_interference_size) HostStatsStaged stats_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct TORCH_API HostAllocator : public at::Allocator { | struct TORCH_API HostAllocator : public at::Allocator { | ||||||
|  | |||||||
| @ -59,9 +59,7 @@ struct TORCH_API Generator { | |||||||
|  |  | ||||||
|   explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl) |   explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl) | ||||||
|    : impl_(std::move(gen_impl)) { |    : impl_(std::move(gen_impl)) { | ||||||
|     if (impl_.get() == nullptr) { |     TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported"); | ||||||
|       throw std::runtime_error("GeneratorImpl with nullptr is not supported"); |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   bool operator==(const Generator& rhs) const { |   bool operator==(const Generator& rhs) const { | ||||||
|  | |||||||
| @ -111,9 +111,7 @@ class TORCH_API TensorBase { | |||||||
|   explicit TensorBase( |   explicit TensorBase( | ||||||
|       c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) |       c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) | ||||||
|       : impl_(std::move(tensor_impl)) { |       : impl_(std::move(tensor_impl)) { | ||||||
|     if (impl_.get() == nullptr) { |     TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported"); | ||||||
|       throw std::runtime_error("TensorImpl with nullptr is not supported"); |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
|   TensorBase(const TensorBase&) = default; |   TensorBase(const TensorBase&) = default; | ||||||
|   TensorBase(TensorBase&&) noexcept = default; |   TensorBase(TensorBase&&) noexcept = default; | ||||||
|  | |||||||
| @ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) { | |||||||
|   m.fallback(AUTOGRAD_FALLBACK); |   m.fallback(AUTOGRAD_FALLBACK); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) { | ||||||
|  |   m.fallback(AUTOGRAD_FALLBACK); | ||||||
|  | } | ||||||
|  |  | ||||||
| #undef AUTOGRAD_FALLBACK | #undef AUTOGRAD_FALLBACK | ||||||
|  |  | ||||||
| } // namespace | } // namespace | ||||||
|  | |||||||
| @ -148,7 +148,7 @@ struct TORCH_API ClassType : public NamedType { | |||||||
|  |  | ||||||
|   void checkNotExist(const std::string& name, const std::string& what) const; |   void checkNotExist(const std::string& name, const std::string& what) const; | ||||||
|  |  | ||||||
|   // Attributes are stored in a specific slot at runtime for effiency. |   // Attributes are stored in a specific slot at runtime for efficiency. | ||||||
|   // When emitting instructions we specify the slot so that attribute access is |   // When emitting instructions we specify the slot so that attribute access is | ||||||
|   // a constant lookup |   // a constant lookup | ||||||
|   std::optional<size_t> findAttributeSlot(const std::string& name) const { |   std::optional<size_t> findAttributeSlot(const std::string& name) const { | ||||||
| @ -412,7 +412,7 @@ struct TORCH_API ClassType : public NamedType { | |||||||
|   // Holds method attributes |   // Holds method attributes | ||||||
|   std::weak_ptr<CompilationUnit> compilation_unit_; |   std::weak_ptr<CompilationUnit> compilation_unit_; | ||||||
|  |  | ||||||
|   // Holds all atrributes, attribute details are found on ClassAttribute |   // Holds all attributes, attribute details are found on ClassAttribute | ||||||
|   std::vector<ClassAttribute> attributes_; |   std::vector<ClassAttribute> attributes_; | ||||||
|   // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef. |   // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef. | ||||||
|   // Never fill this without using the appropriate provideNewClassAttribute method |   // Never fill this without using the appropriate provideNewClassAttribute method | ||||||
|  | |||||||
| @ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker | |||||||
|  |  | ||||||
|   auto idx = getDispatchTableIndexForDispatchKey(dispatchKey); |   auto idx = getDispatchTableIndexForDispatchKey(dispatchKey); | ||||||
|   TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx); |   TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx); | ||||||
|  |   // NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time, | ||||||
|  |   // refer to https://github.com/pytorch/pytorch/issues/163979 for more informations. | ||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
|     !backendFallbackKernels_[idx].kernel.isValid(), |       dispatchKey == DispatchKey::AutogradPrivateUse1 || | ||||||
|     "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ", |           !backendFallbackKernels_[idx].kernel.isValid(), | ||||||
|     backendFallbackKernels_[idx].debug, ", new registration ", debug |       "Tried to register multiple backend fallbacks for the same dispatch key ", | ||||||
|   ); |       dispatchKey, | ||||||
|  |       "; previous registration ", | ||||||
|  |       backendFallbackKernels_[idx].debug, | ||||||
|  |       ", new registration ", | ||||||
|  |       debug); | ||||||
|   // NB: inferred function schema is always nullptr for fallbacks, as fallbacks |   // NB: inferred function schema is always nullptr for fallbacks, as fallbacks | ||||||
|   // cannot be unboxed |   // cannot be unboxed | ||||||
|   backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug)); |   backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug)); | ||||||
| @ -531,7 +537,7 @@ int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchK | |||||||
|  |  | ||||||
|   // Note: this records a sequence number for both Autograd keys, and for |   // Note: this records a sequence number for both Autograd keys, and for | ||||||
|   // non-Autograd keys where the dispatchKeySet still contains an autograd key. |   // non-Autograd keys where the dispatchKeySet still contains an autograd key. | ||||||
|   // This means that we might collect the same sequence nubmer two different |   // This means that we might collect the same sequence number two different | ||||||
|   // events if they all occurred above Autograd and still had the Autograd |   // events if they all occurred above Autograd and still had the Autograd | ||||||
|   // dispatch key in the dispatch key set. |   // dispatch key in the dispatch key set. | ||||||
|   // However, this usually doesn't happen: normally the first call will |   // However, this usually doesn't happen: normally the first call will | ||||||
|  | |||||||
| @ -585,7 +585,7 @@ class TORCH_API OperatorHandle { | |||||||
|  |  | ||||||
|   // We need to store this iterator in order to make |   // We need to store this iterator in order to make | ||||||
|   // Dispatcher::cleanup() fast -- it runs a lot on program |   // Dispatcher::cleanup() fast -- it runs a lot on program | ||||||
|   // termination (and presuambly library unloading). |   // termination (and presumably library unloading). | ||||||
|   std::list<Dispatcher::OperatorDef>::iterator operatorIterator_; |   std::list<Dispatcher::OperatorDef>::iterator operatorIterator_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | |||||||
| @ -365,7 +365,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab | |||||||
|   //          For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration |   //          For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration | ||||||
|   //          to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd]. |   //          to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd]. | ||||||
|   //          For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of |   //          For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of | ||||||
|   //          its backends and ask backend extender to request a decicated Autograd key for the backend. |   //          its backends and ask backend extender to request a dedicated Autograd key for the backend. | ||||||
|   //          See Note [Ambiguity in AutogradOther kernel] for more details. |   //          See Note [Ambiguity in AutogradOther kernel] for more details. | ||||||
|   //          A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't |   //          A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't | ||||||
|   //          cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available) |   //          cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available) | ||||||
|  | |||||||
| @ -261,7 +261,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) { | |||||||
|     // |     // | ||||||
|     // There are 2 cases |     // There are 2 cases | ||||||
|     // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'. |     // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'. | ||||||
|     // without the extra parenthesis, the c++ schem parser can not parse it. |     // without the extra parenthesis, the c++ scheme parser can not parse it. | ||||||
|     // 2. something like '-> ((str, str))'. Need extra parenthesis so the return |     // 2. something like '-> ((str, str))'. Need extra parenthesis so the return | ||||||
|     // type is a single tuple rather than two strings. |     // type is a single tuple rather than two strings. | ||||||
|     // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about |     // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about | ||||||
|  | |||||||
| @ -68,11 +68,7 @@ Symbol InternedStrings::_symbol(const std::string& s) { | |||||||
|     return it->second; |     return it->second; | ||||||
|  |  | ||||||
|   auto pos = s.find("::"); |   auto pos = s.find("::"); | ||||||
|   if (pos == std::string::npos) { |   TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s); | ||||||
|     std::stringstream ss; |  | ||||||
|     ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s; |  | ||||||
|     throw std::runtime_error(ss.str()); |  | ||||||
|   } |  | ||||||
|   Symbol ns = _symbol("namespaces::" + s.substr(0, pos)); |   Symbol ns = _symbol("namespaces::" + s.substr(0, pos)); | ||||||
|  |  | ||||||
|   Symbol sym(sym_to_info_.size()); |   Symbol sym(sym_to_info_.size()); | ||||||
| @ -121,12 +117,7 @@ std::string Symbol::domainString() const { | |||||||
| } | } | ||||||
|  |  | ||||||
| Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) { | Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) { | ||||||
|   if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) { |   TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'"); | ||||||
|     std::ostringstream ss; |  | ||||||
|     ss << "Symbol: domain string is expected to be prefixed with '" |  | ||||||
|        << domain_prefix() << "', e.g. 'org.pytorch.aten'"; |  | ||||||
|     throw std::runtime_error(ss.str()); |  | ||||||
|   } |  | ||||||
|   std::string qualString = d.substr(domain_prefix().size()) + "::" + s; |   std::string qualString = d.substr(domain_prefix().size()) + "::" + s; | ||||||
|   return fromQualString(qualString); |   return fromQualString(qualString); | ||||||
| } | } | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ | |||||||
| #include <ATen/core/jit_type.h> | #include <ATen/core/jit_type.h> | ||||||
| #include <ATen/core/stack.h> | #include <ATen/core/stack.h> | ||||||
| #include <ATen/core/type_factory.h> | #include <ATen/core/type_factory.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
| #include <c10/util/StringUtil.h> | #include <c10/util/StringUtil.h> | ||||||
| #include <c10/util/hash.h> | #include <c10/util/hash.h> | ||||||
| #include <c10/util/irange.h> | #include <c10/util/irange.h> | ||||||
| @ -412,7 +413,7 @@ size_t IValue::hash(const IValue& v) { | |||||||
|     case Tag::Enum: |     case Tag::Enum: | ||||||
|     case Tag::Stream: |     case Tag::Stream: | ||||||
|     case Tag::Uninitialized: |     case Tag::Uninitialized: | ||||||
|       throw std::runtime_error( |       TORCH_CHECK(false, | ||||||
|           "unhashable type: '" + v.type()->repr_str() + "'"); |           "unhashable type: '" + v.type()->repr_str() + "'"); | ||||||
|   } |   } | ||||||
|   // the above switch should be exhaustive |   // the above switch should be exhaustive | ||||||
|  | |||||||
| @ -1176,7 +1176,7 @@ struct TORCH_API IValue final { | |||||||
|   using HashIdentityIValueMap = |   using HashIdentityIValueMap = | ||||||
|       std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>; |       std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>; | ||||||
|  |  | ||||||
|   // Chechs if this and rhs has a subvalues in common. |   // Checks if this and rhs has a subvalues in common. | ||||||
|   // [t1,t2] and [t2, t3] returns true. |   // [t1,t2] and [t2, t3] returns true. | ||||||
|   bool overlaps(const IValue& rhs) const; |   bool overlaps(const IValue& rhs) const; | ||||||
|  |  | ||||||
|  | |||||||
| @ -1501,7 +1501,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { | |||||||
|   // However, the CompilationUnit holds ownership of the type's graphs, so |   // However, the CompilationUnit holds ownership of the type's graphs, so | ||||||
|   // inserting a constant object into a Graph would create a reference cycle if |   // inserting a constant object into a Graph would create a reference cycle if | ||||||
|   // that constant object held a shared_ptr to its CU. For these objects we |   // that constant object held a shared_ptr to its CU. For these objects we | ||||||
|   // instatiate them with non-owning references to its CU |   // instantiate them with non-owning references to its CU | ||||||
|   Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) { |   Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) { | ||||||
|     slots_.resize(numSlots); |     slots_.resize(numSlots); | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -8,6 +8,7 @@ | |||||||
| #include <ATen/core/type_factory.h> | #include <ATen/core/type_factory.h> | ||||||
| #include <ATen/core/qualified_name.h> | #include <ATen/core/qualified_name.h> | ||||||
| #include <c10/util/TypeList.h> | #include <c10/util/TypeList.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
| #include <optional> | #include <optional> | ||||||
| #include <c10/core/SymFloat.h> | #include <c10/core/SymFloat.h> | ||||||
| #include <c10/core/SymBool.h> | #include <c10/core/SymBool.h> | ||||||
| @ -116,10 +117,8 @@ struct SingleElementType : public SharedType { | |||||||
|  |  | ||||||
|  protected: |  protected: | ||||||
|   SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) { |   SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) { | ||||||
|     if (!this->elem) { |     TORCH_CHECK(this->elem, c10::str( | ||||||
|       throw std::runtime_error(c10::str( |  | ||||||
|             "Can not create ", typeKindToString(Kind), " with None type")); |             "Can not create ", typeKindToString(Kind), " with None type")); | ||||||
|     } |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  private: |  private: | ||||||
| @ -374,7 +373,7 @@ struct TORCH_API SymbolicShape { | |||||||
|   // Unranked shape constructor. |   // Unranked shape constructor. | ||||||
|   SymbolicShape() : dims_(std::nullopt) {} |   SymbolicShape() : dims_(std::nullopt) {} | ||||||
|  |  | ||||||
|   // Known rank but unknown dimentions. |   // Known rank but unknown dimensions. | ||||||
|   SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) { |   SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) { | ||||||
|     if(!rank) { |     if(!rank) { | ||||||
|       return; |       return; | ||||||
| @ -416,16 +415,12 @@ struct TORCH_API SymbolicShape { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   ShapeSymbol operator[](size_t i) const { |   ShapeSymbol operator[](size_t i) const { | ||||||
|     if (!dims_) { |     TORCH_CHECK(dims_, "Rank isn't fixed"); | ||||||
|       throw std::runtime_error("Rank isn't fixed"); |  | ||||||
|     } |  | ||||||
|     return (*dims_).at(i); |     return (*dims_).at(i); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   ShapeSymbol at(size_t i) const { |   ShapeSymbol at(size_t i) const { | ||||||
|     if (!dims_) { |     TORCH_CHECK(dims_, "Rank isn't fixed"); | ||||||
|       throw std::runtime_error("Rank isn't fixed"); |  | ||||||
|     } |  | ||||||
|     return (*dims_).at(i); |     return (*dims_).at(i); | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @ -520,9 +515,7 @@ struct VaryingShape { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   const std::optional<T> &operator[](size_t i) const { |   const std::optional<T> &operator[](size_t i) const { | ||||||
|     if (!dims_) { |     TORCH_CHECK(dims_, "Rank isn't fixed"); | ||||||
|       throw std::runtime_error("Rank isn't fixed"); |  | ||||||
|     } |  | ||||||
|     return (*dims_).at(i); |     return (*dims_).at(i); | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @ -891,9 +884,9 @@ struct TORCH_API ListType | |||||||
|  |  | ||||||
|   // global singleton |   // global singleton | ||||||
|   // Given an inner type T and an identifier, |   // Given an inner type T and an identifier, | ||||||
|   // this function wil return the global singleton type pointer |   // this function will return the global singleton type pointer | ||||||
|   // the type List<T>. |   // the type List<T>. | ||||||
|   // The extra "identifier" argument is needed beccause we have multiple container types |   // The extra "identifier" argument is needed because we have multiple container types | ||||||
|   // that all re-use this function (List<T>, array<T, N>, etc.) |   // that all re-use this function (List<T>, array<T, N>, etc.) | ||||||
|   static TypePtr get(const std::string& identifier, TypePtr inner); |   static TypePtr get(const std::string& identifier, TypePtr inner); | ||||||
|  |  | ||||||
| @ -957,9 +950,7 @@ struct TORCH_API DictType : public SharedType { | |||||||
|  |  | ||||||
|   TypePtr createWithContained( |   TypePtr createWithContained( | ||||||
|       std::vector<TypePtr> contained_types) const override { |       std::vector<TypePtr> contained_types) const override { | ||||||
|     if (contained_types.size() != 2) { |     TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types"); | ||||||
|       throw std::runtime_error("Expected 2 contained types"); |  | ||||||
|     } |  | ||||||
|     return create(std::move(contained_types.at(0)), std::move(contained_types.at(1))); |     return create(std::move(contained_types.at(0)), std::move(contained_types.at(1))); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | |||||||
| @ -185,11 +185,11 @@ struct TORCH_API Type { | |||||||
|         : repr_(nullptr) {} |         : repr_(nullptr) {} | ||||||
|  |  | ||||||
|     /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p) |     /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p) | ||||||
|         : repr_(p) {} |         : repr_(makeSingletonSharedPtr(p.get())) {} | ||||||
|  |  | ||||||
|     template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true> |     template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true> | ||||||
|     /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p) |     /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p) | ||||||
|         : repr_(SingletonTypePtr<T>(p.get())) {} |         : repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {} | ||||||
|  |  | ||||||
|  |  | ||||||
|     // We need to support construction from T* for pybind. The problem |     // We need to support construction from T* for pybind. The problem | ||||||
| @ -202,8 +202,8 @@ struct TORCH_API Type { | |||||||
|     // Case 2: if T is exactly Type, we need to do a dynamic_cast to |     // Case 2: if T is exactly Type, we need to do a dynamic_cast to | ||||||
|     // check if it's a SharedType and do the right thing. |     // check if it's a SharedType and do the right thing. | ||||||
|     // |     // | ||||||
|     // Case 3: Otherwise, T is not a SharedType. (debug-check this |     // Case 3: Otherwise, T is not a SharedType. Use a singleton | ||||||
|     // assumption!) Use a singleton pointer. |     // pointer. | ||||||
|  |  | ||||||
|     template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true> |     template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true> | ||||||
|     /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {} |     /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {} | ||||||
| @ -211,15 +211,15 @@ struct TORCH_API Type { | |||||||
|     template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true> |     template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true> | ||||||
|     /* implicit */ SingletonOrSharedTypePtr(T* p) { |     /* implicit */ SingletonOrSharedTypePtr(T* p) { | ||||||
|       if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) { |       if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) { | ||||||
|         repr_ = Repr(shared_p->shared_from_this()); |         repr_ = shared_p->shared_from_this(); | ||||||
|       } else { |       } else { | ||||||
|         repr_ = Repr(p); |         repr_ = makeSingletonSharedPtr(p); | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true> |     template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true> | ||||||
|     /* implicit */ SingletonOrSharedTypePtr(T* p) |     /* implicit */ SingletonOrSharedTypePtr(T* p) | ||||||
|         : repr_(p) { |         : repr_(makeSingletonSharedPtr(p)) { | ||||||
|       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr); |       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr); | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @ -230,19 +230,19 @@ struct TORCH_API Type { | |||||||
|     ~SingletonOrSharedTypePtr() = default; |     ~SingletonOrSharedTypePtr() = default; | ||||||
|  |  | ||||||
|     T* get() const { |     T* get() const { | ||||||
|       return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first); |       return repr_.get(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     operator bool() const { |     operator bool() const { | ||||||
|       return repr_.isNonNull(); |       return repr_ != nullptr; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     bool operator==(std::nullptr_t) const { |     bool operator==(std::nullptr_t) const { | ||||||
|       return !repr_.isNonNull(); |       return repr_ == nullptr; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     bool operator!=(std::nullptr_t) const { |     bool operator!=(std::nullptr_t) const { | ||||||
|       return repr_.isNonNull(); |       return repr_ != nullptr; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true> |     template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true> | ||||||
| @ -255,138 +255,14 @@ struct TORCH_API Type { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|   private: |   private: | ||||||
|     // NOTE: SharedPtrWrapper exists to work around a baffling bug in |     // Use shared_ptr's aliasing constructor to create a non-owning pointer | ||||||
|     // nvcc; see comment in destroy() below. |     // to a singleton. The lifetime is tied to the null shared_ptr, so there's | ||||||
|     struct SharedPtrWrapper { |     // no reference counting overhead for the singleton itself. | ||||||
|       SharedPtrWrapper(std::shared_ptr<T> &&x) |     static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) { | ||||||
|           : repr_(std::move(x)) {} |       return std::shared_ptr<T>(std::shared_ptr<T>(), ptr); | ||||||
|       std::shared_ptr<T> repr_; |     } | ||||||
|     }; |  | ||||||
|     union Repr { |  | ||||||
|       Repr() : Repr(nullptr) {} |  | ||||||
|  |  | ||||||
|       explicit Repr(std::shared_ptr<T> x) |     std::shared_ptr<T> repr_; | ||||||
|           : shared_(std::move(x)) {} |  | ||||||
|  |  | ||||||
|       explicit Repr(std::nullptr_t) |  | ||||||
|           : singletonRepr_(nullptr) {} |  | ||||||
|  |  | ||||||
|       explicit Repr(SingletonTypePtr<T> p) |  | ||||||
|           : singletonRepr_(p.get()) {} |  | ||||||
|  |  | ||||||
|       ~Repr() { |  | ||||||
|         destroy(); |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       // NOTE: the only non-UB way to access our null state is through |  | ||||||
|       // rawRepr(), because our copy operation doesn't preserve which |  | ||||||
|       // union member is active for null pointers. |  | ||||||
|       Repr(const Repr& rhs) { |  | ||||||
|         if (rhs.isSharedAndNonNull()) { |  | ||||||
|           new (&shared_) SharedPtrWrapper(rhs.shared_); |  | ||||||
|         } else { |  | ||||||
|           singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first); |  | ||||||
|           TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); |  | ||||||
|           singletonRepr_.unused_ = nullptr; |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       Repr(Repr&& rhs) noexcept { |  | ||||||
|         if (rhs.isSharedAndNonNull()) { |  | ||||||
|           new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); |  | ||||||
|         } else { |  | ||||||
|           singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first); |  | ||||||
|           TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); |  | ||||||
|           singletonRepr_.unused_ = nullptr; |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       Repr& operator=(const Repr& rhs) { |  | ||||||
|         if (&rhs == this) { |  | ||||||
|           return *this; |  | ||||||
|         } |  | ||||||
|         if (rhs.isSharedAndNonNull()) { |  | ||||||
|           if (isSharedAndNonNull()) { |  | ||||||
|             shared_ = rhs.shared_; |  | ||||||
|           } else { |  | ||||||
|             new (&shared_) SharedPtrWrapper(rhs.shared_); |  | ||||||
|           } |  | ||||||
|         } else { |  | ||||||
|           if (isSharedAndNonNull()) { |  | ||||||
|             destroy(); |  | ||||||
|           } |  | ||||||
|           singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first); |  | ||||||
|           TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); |  | ||||||
|           singletonRepr_.unused_ = nullptr; |  | ||||||
|         } |  | ||||||
|         return *this; |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       Repr& operator=(Repr&& rhs) noexcept { |  | ||||||
|         if (&rhs == this) { |  | ||||||
|           return *this; |  | ||||||
|         } |  | ||||||
|         if (rhs.isSharedAndNonNull()) { |  | ||||||
|           if (isSharedAndNonNull()) { |  | ||||||
|             shared_ = std::move(rhs.shared_); |  | ||||||
|           } else { |  | ||||||
|             new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); |  | ||||||
|           } |  | ||||||
|         } else { |  | ||||||
|           if (isSharedAndNonNull()) { |  | ||||||
|             destroy(); |  | ||||||
|           } |  | ||||||
|           singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first); |  | ||||||
|           TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); |  | ||||||
|           singletonRepr_.unused_ = nullptr; |  | ||||||
|         } |  | ||||||
|         return *this; |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       SharedPtrWrapper shared_; |  | ||||||
|  |  | ||||||
|       struct SingletonRepr { |  | ||||||
|         explicit SingletonRepr(T* s) : singleton_(s) {} |  | ||||||
|         T* singleton_; |  | ||||||
|         void* unused_ = nullptr; |  | ||||||
|       } singletonRepr_; |  | ||||||
|       struct RawRepr { |  | ||||||
|         void* first; |  | ||||||
|         void* nullIfSingleton_; |  | ||||||
|       }; |  | ||||||
|  |  | ||||||
|       // It is UB to read the singleton part of Repr if it was |  | ||||||
|       // constructed as a shared_ptr and vice versa, but memcpying out |  | ||||||
|       // the representation is always OK, so here's an accessor to obey |  | ||||||
|       // the letter of the law. |  | ||||||
|       RawRepr rawRepr() const { |  | ||||||
|         RawRepr repr{}; |  | ||||||
|         memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr)); |  | ||||||
|         return repr; |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       bool isNonNull() const { |  | ||||||
|         auto repr = rawRepr(); |  | ||||||
|         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr); |  | ||||||
|         return repr.first != nullptr; |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       bool isSharedAndNonNull() const { |  | ||||||
|         return rawRepr().nullIfSingleton_ != nullptr; |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|      private: |  | ||||||
|       void destroy() { |  | ||||||
|         if (isSharedAndNonNull()) { |  | ||||||
|           // Without SharedPtrWrapper, this line would read |  | ||||||
|           // `shared_.~shared_ptr()` and nvcc would complain with |  | ||||||
|           // "error: expected primary-expression before '>' token" |  | ||||||
|           // referring to the "t" in "shared_ptr". SharedPtrWrapper |  | ||||||
|           // exists to work around this compiler bug. |  | ||||||
|           shared_.~SharedPtrWrapper(); |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|     } repr_; |  | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   using TypePtr = SingletonOrSharedTypePtr<Type>; |   using TypePtr = SingletonOrSharedTypePtr<Type>; | ||||||
|  | |||||||
| @ -21,7 +21,7 @@ namespace c10 { | |||||||
|  |  | ||||||
| namespace detail { | namespace detail { | ||||||
| // The first argument of the schema might be of type DispatchKeySet, in which case we remove it. | // The first argument of the schema might be of type DispatchKeySet, in which case we remove it. | ||||||
| // We do this because every argument in a function schema is expected to be convertable | // We do this because every argument in a function schema is expected to be convertible | ||||||
| // to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of. | // to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of. | ||||||
| // See Note [Plumbing Keys Through The Dispatcher] | // See Note [Plumbing Keys Through The Dispatcher] | ||||||
| template<class KernelFunctor> | template<class KernelFunctor> | ||||||
|  | |||||||
| @ -251,7 +251,7 @@ TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnbox | |||||||
|   callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA)); |   callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA)); | ||||||
|   EXPECT_TRUE(called_kernel_cpu); |   EXPECT_TRUE(called_kernel_cpu); | ||||||
|  |  | ||||||
|   // Ensure that disptach key from tensor is not used here. |   // Ensure that dispatch key from tensor is not used here. | ||||||
|   called_kernel_cpu = false; |   called_kernel_cpu = false; | ||||||
|   expectThrows<c10::Error>([&] { |   expectThrows<c10::Error>([&] { | ||||||
|     callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU)); |     callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU)); | ||||||
|  | |||||||
| @ -172,7 +172,7 @@ VaryingShape<Stride> TensorType::computeStrideProps( | |||||||
|   // The logic below follows what TensorIterator uses in its logic: |   // The logic below follows what TensorIterator uses in its logic: | ||||||
|   //   1. Fast_set_up is the short-cut to identify a. channels_last and |   //   1. Fast_set_up is the short-cut to identify a. channels_last and | ||||||
|   //      b. contiguous format, which is what we have in the below logic. |   //      b. contiguous format, which is what we have in the below logic. | ||||||
|   //   2. In more generla cases, it does best effort to preserve permutatoin. |   //   2. In more general cases, it does best effort to preserve permutatoin. | ||||||
|   if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) { |   if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) { | ||||||
|     // case 1.a. short cut channels last |     // case 1.a. short cut channels last | ||||||
|     std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2); |     std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2); | ||||||
|  | |||||||
| @ -8,6 +8,7 @@ | |||||||
| #include <ATen/core/jit_type.h> | #include <ATen/core/jit_type.h> | ||||||
| #include <c10/macros/Macros.h> | #include <c10/macros/Macros.h> | ||||||
| #include <c10/util/env.h> | #include <c10/util/env.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
| #include <c10/util/flat_hash_map.h> | #include <c10/util/flat_hash_map.h> | ||||||
| #include <c10/util/irange.h> | #include <c10/util/irange.h> | ||||||
| #include <array> | #include <array> | ||||||
| @ -826,9 +827,7 @@ TupleType::TupleType( | |||||||
|     : NamedType(TypeKind::TupleType, std::move(name)), |     : NamedType(TypeKind::TupleType, std::move(name)), | ||||||
|       elements_(std::move(elements)), |       elements_(std::move(elements)), | ||||||
|       has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) { |       has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) { | ||||||
|         if (!v) { |         TORCH_CHECK(v, "Can not create tuple with None type"); | ||||||
|           throw std::runtime_error("Can not create tuple with None type"); |  | ||||||
|         } |  | ||||||
|         return v->hasFreeVariables(); |         return v->hasFreeVariables(); | ||||||
|       })), schema_(std::move(schema)) { |       })), schema_(std::move(schema)) { | ||||||
|  |  | ||||||
|  | |||||||
| @ -104,71 +104,6 @@ class Vectorized<float> { | |||||||
|     } |     } | ||||||
|     return b; |     return b; | ||||||
|   } |   } | ||||||
|   // Implementation is picked from |  | ||||||
|   // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105 |  | ||||||
|   inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const { |  | ||||||
|     const auto c1 = |  | ||||||
|         svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f |  | ||||||
|     const auto c2 = |  | ||||||
|         svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f |  | ||||||
|     const auto c3 = |  | ||||||
|         svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f |  | ||||||
|     const auto c4 = |  | ||||||
|         svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f |  | ||||||
|     const auto c5 = |  | ||||||
|         svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f |  | ||||||
|     const auto shift = svreinterpret_f32_u32( |  | ||||||
|         svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f |  | ||||||
|     const auto inv_ln2 = svreinterpret_f32_u32( |  | ||||||
|         svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f |  | ||||||
|     const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32( |  | ||||||
|         0xbf317200)); // -ln(2) from bits  -1 to -19: -0x1.62e400p-1f |  | ||||||
|     const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32( |  | ||||||
|         0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f |  | ||||||
|     const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity()); |  | ||||||
|     const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5) |  | ||||||
|     const auto zero = svdup_n_f32(0.f); |  | ||||||
|     const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125) |  | ||||||
|     // Range reduction: |  | ||||||
|     //   e^x = 2^n * e^r |  | ||||||
|     // where: |  | ||||||
|     //   n = floor(x / ln(2)) |  | ||||||
|     //   r = x - n * ln(2) |  | ||||||
|     // |  | ||||||
|     // By adding x / ln(2) with 2^23 + 127 (shift): |  | ||||||
|     //   * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 |  | ||||||
|     //   forces decimal part |  | ||||||
|     //     of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. |  | ||||||
|     //     n) + 127 will occupy the whole fraction part of z in FP32 format. |  | ||||||
|     //     Subtracting 2^23 + 127 (shift) from z will result in the integer part |  | ||||||
|     //     of x / ln(2) (i.e. n) because the decimal part has been pushed out |  | ||||||
|     //     and lost. |  | ||||||
|     //   * The addition of 127 makes the FP32 fraction part of z ready to be |  | ||||||
|     //   used as the exponent |  | ||||||
|     //     in FP32 format. Left shifting z by 23 bits will result in 2^n. |  | ||||||
|     const auto z = svmla_f32_z(pg, shift, x, inv_ln2); |  | ||||||
|     const auto n = svsub_f32_z(pg, z, shift); |  | ||||||
|     const auto scale = svreinterpret_f32_u32( |  | ||||||
|         svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n |  | ||||||
|     // The calculation of n * ln(2) is done using 2 steps to achieve accuracy |  | ||||||
|     // beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in |  | ||||||
|     // term of accuracy and performance. |  | ||||||
|     const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi); |  | ||||||
|     const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo); |  | ||||||
|     // Compute the truncated Taylor series of e^r. |  | ||||||
|     //   poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) |  | ||||||
|     const auto r2 = svmul_f32_z(pg, r, r); |  | ||||||
|     const auto p1 = svmul_f32_z(pg, c1, r); |  | ||||||
|     const auto p23 = svmla_f32_z(pg, c2, c3, r); |  | ||||||
|     const auto p45 = svmla_f32_z(pg, c4, c5, r); |  | ||||||
|     const auto p2345 = svmla_f32_z(pg, p23, p45, r2); |  | ||||||
|     const auto p12345 = svmla_f32_z(pg, p1, p2345, r2); |  | ||||||
|     auto poly = svmla_f32_z(pg, scale, p12345, scale); |  | ||||||
|     // Handle underflow and overflow. |  | ||||||
|     poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly); |  | ||||||
|     poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly); |  | ||||||
|     return poly; |  | ||||||
|   } |  | ||||||
|   static Vectorized<float> loadu(const void* ptr, int64_t count = size()) { |   static Vectorized<float> loadu(const void* ptr, int64_t count = size()) { | ||||||
|     if (count == size()) |     if (count == size()) | ||||||
|       return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr)); |       return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr)); | ||||||
| @ -313,11 +248,41 @@ class Vectorized<float> { | |||||||
|     return USE_SLEEF( |     return USE_SLEEF( | ||||||
|         Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1)); |         Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1)); | ||||||
|   } |   } | ||||||
|  |   // Implementation copied from Arm Optimized Routines: | ||||||
|  |   // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c | ||||||
|   Vectorized<float> exp_u20() const { |   Vectorized<float> exp_u20() const { | ||||||
|     return exp(); |     // special case to handle special inputs that are too large or too small | ||||||
|  |     // i.e. where there's at least one element x, s.t. |x| >= 87.3... | ||||||
|  |     svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f); | ||||||
|  |     if (svptest_any(svptrue_b32(), is_special_case)) { | ||||||
|  |       return exp(); | ||||||
|  |     } | ||||||
|  |     const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f); | ||||||
|  |     const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f); | ||||||
|  |     const svfloat32_t c1 = svdup_n_f32(0.5f); | ||||||
|  |     const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f); | ||||||
|  |  | ||||||
|  |     const float shift = 0x1.803f8p17f; | ||||||
|  |  | ||||||
|  |     /* n = round(x/(ln2/N)).  */ | ||||||
|  |     svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift); | ||||||
|  |     svfloat32_t n = svsub_x(svptrue_b32(), z, shift); | ||||||
|  |  | ||||||
|  |     /* r = x - n*ln2/N.  */ | ||||||
|  |     svfloat32_t r = values; | ||||||
|  |     r = svmls_x(svptrue_b32(), r, n, ln2_hi); | ||||||
|  |     r = svmls_x(svptrue_b32(), r, n, ln2_lo); | ||||||
|  |  | ||||||
|  |     /* scale = 2^(n/N).  */ | ||||||
|  |     svfloat32_t scale = svexpa(svreinterpret_u32(z)); | ||||||
|  |  | ||||||
|  |     /* poly(r) = exp(r) - 1 ~= r + 0.5 r^2.  */ | ||||||
|  |     svfloat32_t r2 = svmul_x(svptrue_b32(), r, r); | ||||||
|  |     svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1); | ||||||
|  |     return svmla_x(svptrue_b32(), scale, scale, poly); | ||||||
|   } |   } | ||||||
|   Vectorized<float> fexp_u20() const { |   Vectorized<float> fexp_u20() const { | ||||||
|     return exp(); |     return exp_u20(); | ||||||
|   } |   } | ||||||
|   Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF( |   Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF( | ||||||
|       { return Vectorized<float>(Sleef_fmodfx_sve(values, q)); }, |       { return Vectorized<float>(Sleef_fmodfx_sve(values, q)); }, | ||||||
| @ -453,9 +418,11 @@ class Vectorized<float> { | |||||||
|         ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH); |         ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH); | ||||||
|  |  | ||||||
|     // Step 2: Calculate exp(2 * x), where x is the clamped value. |     // Step 2: Calculate exp(2 * x), where x is the clamped value. | ||||||
|     // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of |     // svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of | ||||||
|     // the result. |     // the result (via Vectorized<float>, then auto-converts back to | ||||||
|     svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x)); |     // svfloat32_t). | ||||||
|  |     svfloat32_t exp2x = | ||||||
|  |         Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20(); | ||||||
|  |  | ||||||
|     // Step 3: Calculate the numerator of the tanh function, which is exp(2x) |     // Step 3: Calculate the numerator of the tanh function, which is exp(2x) | ||||||
|     // - 1. |     // - 1. | ||||||
|  | |||||||
| @ -6,9 +6,11 @@ | |||||||
| #ifdef __aarch64__ | #ifdef __aarch64__ | ||||||
| #if !defined(CPU_CAPABILITY_SVE) | #if !defined(CPU_CAPABILITY_SVE) | ||||||
| #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h> | #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h> | ||||||
|  | #include <ATen/cpu/vec/vec128/vec128_double_neon.h> | ||||||
| #include <ATen/cpu/vec/vec128/vec128_float_neon.h> | #include <ATen/cpu/vec/vec128/vec128_float_neon.h> | ||||||
| #include <ATen/cpu/vec/vec128/vec128_half_neon.h> | #include <ATen/cpu/vec/vec128/vec128_half_neon.h> | ||||||
| #include <ATen/cpu/vec/vec128/vec128_int_aarch64.h> | #include <ATen/cpu/vec/vec128/vec128_int_aarch64.h> | ||||||
|  | #include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h> | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #include <ATen/cpu/vec/vec128/vec128_convert.h> | #include <ATen/cpu/vec/vec128/vec128_convert.h> | ||||||
|  | |||||||
| @ -354,9 +354,47 @@ class Vectorized<c10::BFloat16> : public Vectorized16< | |||||||
|  |  | ||||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs) |   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs) | ||||||
|   Vectorized frac() const; |   Vectorized frac() const; | ||||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) |  | ||||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc) |   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc) | ||||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt) |   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt) | ||||||
|  |  | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  |   Vectorized<c10::BFloat16> neg() const { | ||||||
|  |     return -values; | ||||||
|  |   } | ||||||
|  |   Vectorized<c10::BFloat16> reciprocal() const { | ||||||
|  |     return 1.0f / values; | ||||||
|  |   } | ||||||
|  |   Vectorized<c10::BFloat16> operator==( | ||||||
|  |       const Vectorized<c10::BFloat16>& other) const { | ||||||
|  |     return values == other.values; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<c10::BFloat16> operator!=( | ||||||
|  |       const Vectorized<c10::BFloat16>& other) const { | ||||||
|  |     return values != other.values; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<c10::BFloat16> operator<( | ||||||
|  |       const Vectorized<c10::BFloat16>& other) const { | ||||||
|  |     return values < other.values; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<c10::BFloat16> operator<=( | ||||||
|  |       const Vectorized<c10::BFloat16>& other) const { | ||||||
|  |     return values <= other.values; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<c10::BFloat16> operator>( | ||||||
|  |       const Vectorized<c10::BFloat16>& other) const { | ||||||
|  |     return values > other.values; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<c10::BFloat16> operator>=( | ||||||
|  |       const Vectorized<c10::BFloat16>& other) const { | ||||||
|  |     return values >= other.values; | ||||||
|  |   } | ||||||
|  | #else | ||||||
|  |   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) | ||||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) |   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) | ||||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==) |   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==) | ||||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=) |   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=) | ||||||
| @ -364,6 +402,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16< | |||||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=) |   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=) | ||||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>) |   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>) | ||||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=) |   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=) | ||||||
|  | #endif | ||||||
|  |  | ||||||
| #undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD | #undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD | ||||||
| #undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD | #undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD | ||||||
| @ -412,28 +451,52 @@ template <> | |||||||
| Vectorized<c10::BFloat16> inline operator+( | Vectorized<c10::BFloat16> inline operator+( | ||||||
|     const Vectorized<c10::BFloat16>& a, |     const Vectorized<c10::BFloat16>& a, | ||||||
|     const Vectorized<c10::BFloat16>& b) { |     const Vectorized<c10::BFloat16>& b) { | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  |   bfloat16x8_t x = a; | ||||||
|  |   bfloat16x8_t y = b; | ||||||
|  |   return x + y; | ||||||
|  | #else | ||||||
|   return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b); |   return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b); | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| Vectorized<c10::BFloat16> inline operator-( | Vectorized<c10::BFloat16> inline operator-( | ||||||
|     const Vectorized<c10::BFloat16>& a, |     const Vectorized<c10::BFloat16>& a, | ||||||
|     const Vectorized<c10::BFloat16>& b) { |     const Vectorized<c10::BFloat16>& b) { | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  |   bfloat16x8_t x = a; | ||||||
|  |   bfloat16x8_t y = b; | ||||||
|  |   return x - y; | ||||||
|  | #else | ||||||
|   return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b); |   return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b); | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| Vectorized<c10::BFloat16> inline operator*( | Vectorized<c10::BFloat16> inline operator*( | ||||||
|     const Vectorized<c10::BFloat16>& a, |     const Vectorized<c10::BFloat16>& a, | ||||||
|     const Vectorized<c10::BFloat16>& b) { |     const Vectorized<c10::BFloat16>& b) { | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  |   bfloat16x8_t x = a; | ||||||
|  |   bfloat16x8_t y = b; | ||||||
|  |   return x * y; | ||||||
|  | #else | ||||||
|   return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b); |   return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b); | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| Vectorized<c10::BFloat16> inline operator/( | Vectorized<c10::BFloat16> inline operator/( | ||||||
|     const Vectorized<c10::BFloat16>& a, |     const Vectorized<c10::BFloat16>& a, | ||||||
|     const Vectorized<c10::BFloat16>& b) { |     const Vectorized<c10::BFloat16>& b) { | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  |   bfloat16x8_t x = a; | ||||||
|  |   bfloat16x8_t y = b; | ||||||
|  |   return x / y; | ||||||
|  | #else | ||||||
|   return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b); |   return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b); | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| // frac. Implement this here so we can use subtraction | // frac. Implement this here so we can use subtraction | ||||||
| @ -544,12 +607,19 @@ Vectorized<c10::BFloat16> inline fmadd( | |||||||
|     const Vectorized<c10::BFloat16>& a, |     const Vectorized<c10::BFloat16>& a, | ||||||
|     const Vectorized<c10::BFloat16>& b, |     const Vectorized<c10::BFloat16>& b, | ||||||
|     const Vectorized<c10::BFloat16>& c) { |     const Vectorized<c10::BFloat16>& c) { | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  |   bfloat16x8_t x = a; | ||||||
|  |   bfloat16x8_t y = b; | ||||||
|  |   bfloat16x8_t z = c; | ||||||
|  |   return x * y + z; | ||||||
|  | #else | ||||||
|   // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16!  Also, |   // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16!  Also, | ||||||
|   // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered |   // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered | ||||||
|   // elements, not the bottom and top half, so they don't seem |   // elements, not the bottom and top half, so they don't seem | ||||||
|   // particularly useful here. Ideally we would include dot product in |   // particularly useful here. Ideally we would include dot product in | ||||||
|   // the Vectorized interface... |   // the Vectorized interface... | ||||||
|   return a * b + c; |   return a * b + c; | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| @ -557,8 +627,15 @@ Vectorized<c10::BFloat16> inline fnmadd( | |||||||
|     const Vectorized<c10::BFloat16>& a, |     const Vectorized<c10::BFloat16>& a, | ||||||
|     const Vectorized<c10::BFloat16>& b, |     const Vectorized<c10::BFloat16>& b, | ||||||
|     const Vectorized<c10::BFloat16>& c) { |     const Vectorized<c10::BFloat16>& c) { | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  |   bfloat16x8_t x = a; | ||||||
|  |   bfloat16x8_t y = b; | ||||||
|  |   bfloat16x8_t z = c; | ||||||
|  |   return (-x) * y + z; | ||||||
|  | #else | ||||||
|   // See NOTE [BF16 FMA] above. |   // See NOTE [BF16 FMA] above. | ||||||
|   return -a * b + c; |   return -a * b + c; | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| @ -566,8 +643,15 @@ Vectorized<c10::BFloat16> inline fmsub( | |||||||
|     const Vectorized<c10::BFloat16>& a, |     const Vectorized<c10::BFloat16>& a, | ||||||
|     const Vectorized<c10::BFloat16>& b, |     const Vectorized<c10::BFloat16>& b, | ||||||
|     const Vectorized<c10::BFloat16>& c) { |     const Vectorized<c10::BFloat16>& c) { | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  |   bfloat16x8_t x = a; | ||||||
|  |   bfloat16x8_t y = b; | ||||||
|  |   bfloat16x8_t z = c; | ||||||
|  |   return x * y - z; | ||||||
|  | #else | ||||||
|   // See NOTE [BF16 FMA] above. |   // See NOTE [BF16 FMA] above. | ||||||
|   return a * b - c; |   return a * b - c; | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| @ -575,8 +659,15 @@ Vectorized<c10::BFloat16> inline fnmsub( | |||||||
|     const Vectorized<c10::BFloat16>& a, |     const Vectorized<c10::BFloat16>& a, | ||||||
|     const Vectorized<c10::BFloat16>& b, |     const Vectorized<c10::BFloat16>& b, | ||||||
|     const Vectorized<c10::BFloat16>& c) { |     const Vectorized<c10::BFloat16>& c) { | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  |   bfloat16x8_t x = a; | ||||||
|  |   bfloat16x8_t y = b; | ||||||
|  |   bfloat16x8_t z = c; | ||||||
|  |   return (-x) * y - z; | ||||||
|  | #else | ||||||
|   // See NOTE [BF16 FMA] above. |   // See NOTE [BF16 FMA] above. | ||||||
|   return -a * b - c; |   return -a * b - c; | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| #endif // !defined(C10_MOBILE) && defined(__aarch64__) | #endif // !defined(C10_MOBILE) && defined(__aarch64__) | ||||||
|  | |||||||
| @ -5,6 +5,129 @@ | |||||||
| namespace at::vec { | namespace at::vec { | ||||||
| inline namespace CPU_CAPABILITY { | inline namespace CPU_CAPABILITY { | ||||||
| #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) | #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) | ||||||
|  |  | ||||||
|  | // Enable auto-vectorization for GCC-13+ and clang-17+ | ||||||
|  | // GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 | ||||||
|  | #if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17)) | ||||||
|  |  | ||||||
|  | template <typename from_type, typename to_type> | ||||||
|  | inline void convertImpl( | ||||||
|  |     const from_type* __restrict src, | ||||||
|  |     to_type* __restrict dst, | ||||||
|  |     int64_t n) { | ||||||
|  |   uint64_t len = static_cast<uint64_t>(n); | ||||||
|  |   for (uint64_t i = 0; i < len; i++) { | ||||||
|  |     dst[i] = static_cast<to_type>(src[i]); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #define CONVERT_TEMPLATE(from_type, to_type)                           \ | ||||||
|  |   template <>                                                          \ | ||||||
|  |   inline void convert(const from_type* src, to_type* dst, int64_t n) { \ | ||||||
|  |     return convertImpl<from_type, to_type>(src, dst, n);               \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | CONVERT_TEMPLATE(uint8_t, uint8_t) | ||||||
|  | CONVERT_TEMPLATE(uint8_t, int8_t) | ||||||
|  | CONVERT_TEMPLATE(uint8_t, int16_t) | ||||||
|  | CONVERT_TEMPLATE(uint8_t, int32_t) | ||||||
|  | CONVERT_TEMPLATE(uint8_t, int64_t) | ||||||
|  | CONVERT_TEMPLATE(uint8_t, float) | ||||||
|  | CONVERT_TEMPLATE(uint8_t, double) | ||||||
|  | CONVERT_TEMPLATE(int8_t, uint8_t) | ||||||
|  | CONVERT_TEMPLATE(int8_t, int8_t) | ||||||
|  | CONVERT_TEMPLATE(int8_t, int16_t) | ||||||
|  | CONVERT_TEMPLATE(int8_t, int32_t) | ||||||
|  | CONVERT_TEMPLATE(int8_t, int64_t) | ||||||
|  | CONVERT_TEMPLATE(int8_t, float) | ||||||
|  | CONVERT_TEMPLATE(int8_t, double) | ||||||
|  | CONVERT_TEMPLATE(int16_t, uint8_t) | ||||||
|  | CONVERT_TEMPLATE(int16_t, int8_t) | ||||||
|  | CONVERT_TEMPLATE(int16_t, int16_t) | ||||||
|  | CONVERT_TEMPLATE(int16_t, int32_t) | ||||||
|  | CONVERT_TEMPLATE(int16_t, int64_t) | ||||||
|  | CONVERT_TEMPLATE(int16_t, float) | ||||||
|  | CONVERT_TEMPLATE(int16_t, double) | ||||||
|  | CONVERT_TEMPLATE(int32_t, uint8_t) | ||||||
|  | CONVERT_TEMPLATE(int32_t, int8_t) | ||||||
|  | CONVERT_TEMPLATE(int32_t, int16_t) | ||||||
|  | CONVERT_TEMPLATE(int32_t, int32_t) | ||||||
|  | CONVERT_TEMPLATE(int32_t, int64_t) | ||||||
|  | CONVERT_TEMPLATE(int32_t, float) | ||||||
|  | CONVERT_TEMPLATE(int32_t, double) | ||||||
|  | CONVERT_TEMPLATE(int64_t, uint8_t) | ||||||
|  | CONVERT_TEMPLATE(int64_t, int8_t) | ||||||
|  | CONVERT_TEMPLATE(int64_t, int16_t) | ||||||
|  | CONVERT_TEMPLATE(int64_t, int32_t) | ||||||
|  | CONVERT_TEMPLATE(int64_t, int64_t) | ||||||
|  | CONVERT_TEMPLATE(int64_t, float) | ||||||
|  | CONVERT_TEMPLATE(int64_t, double) | ||||||
|  | CONVERT_TEMPLATE(float, uint8_t) | ||||||
|  | CONVERT_TEMPLATE(float, int8_t) | ||||||
|  | CONVERT_TEMPLATE(float, int16_t) | ||||||
|  | CONVERT_TEMPLATE(float, int32_t) | ||||||
|  | CONVERT_TEMPLATE(float, int64_t) | ||||||
|  | CONVERT_TEMPLATE(float, float) | ||||||
|  | CONVERT_TEMPLATE(float, double) | ||||||
|  | CONVERT_TEMPLATE(double, uint8_t) | ||||||
|  | CONVERT_TEMPLATE(double, int8_t) | ||||||
|  | CONVERT_TEMPLATE(double, int16_t) | ||||||
|  | CONVERT_TEMPLATE(double, int32_t) | ||||||
|  | CONVERT_TEMPLATE(double, int64_t) | ||||||
|  | CONVERT_TEMPLATE(double, float) | ||||||
|  | CONVERT_TEMPLATE(double, double) | ||||||
|  | #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||||
|  |  | ||||||
|  | #define CONVERT_FROM_FP16_TEMPLATE(to_type)                            \ | ||||||
|  |   template <>                                                          \ | ||||||
|  |   inline void convert(const at::Half* src, to_type* dst, int64_t n) {  \ | ||||||
|  |     const float16_t* srcPtr = reinterpret_cast<const float16_t*>(src); \ | ||||||
|  |     return convertImpl<float16_t, to_type>(srcPtr, dst, n);            \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | #define CONVERT_TO_FP16_TEMPLATE(from_type)                             \ | ||||||
|  |   template <>                                                           \ | ||||||
|  |   inline void convert(const from_type* src, at::Half* dst, int64_t n) { \ | ||||||
|  |     float16_t* dstPtr = reinterpret_cast<float16_t*>(dst);              \ | ||||||
|  |     return convertImpl<from_type, float16_t>(src, dstPtr, n);           \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | CONVERT_FROM_FP16_TEMPLATE(uint8_t) | ||||||
|  | CONVERT_FROM_FP16_TEMPLATE(int8_t) | ||||||
|  | CONVERT_FROM_FP16_TEMPLATE(int16_t) | ||||||
|  | CONVERT_FROM_FP16_TEMPLATE(int32_t) | ||||||
|  | CONVERT_FROM_FP16_TEMPLATE(int64_t) | ||||||
|  | CONVERT_FROM_FP16_TEMPLATE(float16_t) | ||||||
|  | CONVERT_FROM_FP16_TEMPLATE(float) | ||||||
|  | CONVERT_FROM_FP16_TEMPLATE(double) | ||||||
|  | CONVERT_TO_FP16_TEMPLATE(uint8_t) | ||||||
|  | CONVERT_TO_FP16_TEMPLATE(int8_t) | ||||||
|  | CONVERT_TO_FP16_TEMPLATE(int16_t) | ||||||
|  | CONVERT_TO_FP16_TEMPLATE(int32_t) | ||||||
|  | CONVERT_TO_FP16_TEMPLATE(int64_t) | ||||||
|  | CONVERT_TO_FP16_TEMPLATE(float) | ||||||
|  | CONVERT_TO_FP16_TEMPLATE(double) | ||||||
|  | #endif | ||||||
|  | #ifdef __ARM_FEATURE_BF16 | ||||||
|  | CONVERT_TEMPLATE(bfloat16_t, uint8_t) | ||||||
|  | CONVERT_TEMPLATE(bfloat16_t, int8_t) | ||||||
|  | CONVERT_TEMPLATE(bfloat16_t, int16_t) | ||||||
|  | CONVERT_TEMPLATE(bfloat16_t, int32_t) | ||||||
|  | CONVERT_TEMPLATE(bfloat16_t, int64_t) | ||||||
|  | CONVERT_TEMPLATE(bfloat16_t, bfloat16_t) | ||||||
|  | CONVERT_TEMPLATE(bfloat16_t, float) | ||||||
|  | CONVERT_TEMPLATE(bfloat16_t, double) | ||||||
|  | CONVERT_TEMPLATE(uint8_t, bfloat16_t) | ||||||
|  | CONVERT_TEMPLATE(int8_t, bfloat16_t) | ||||||
|  | CONVERT_TEMPLATE(int16_t, bfloat16_t) | ||||||
|  | CONVERT_TEMPLATE(int32_t, bfloat16_t) | ||||||
|  | CONVERT_TEMPLATE(int64_t, bfloat16_t) | ||||||
|  | CONVERT_TEMPLATE(float, bfloat16_t) | ||||||
|  | CONVERT_TEMPLATE(double, bfloat16_t) | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #endif | ||||||
|  |  | ||||||
| template <typename src_t> | template <typename src_t> | ||||||
| struct VecConvert< | struct VecConvert< | ||||||
|     float, |     float, | ||||||
|  | |||||||
							
								
								
									
										586
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_double_neon.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										586
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_double_neon.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,586 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/cpu/vec/intrinsics.h> | ||||||
|  | #include <ATen/cpu/vec/vec_base.h> | ||||||
|  | #include <c10/macros/Macros.h> | ||||||
|  | #include <c10/util/irange.h> | ||||||
|  | #include <cmath> | ||||||
|  |  | ||||||
|  | namespace at::vec { | ||||||
|  | // Note [CPU_CAPABILITY namespace] | ||||||
|  | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  | // This header, and all of its subheaders, will be compiled with | ||||||
|  | // different architecture flags for each supported set of vector | ||||||
|  | // intrinsics. So we need to make sure they aren't inadvertently | ||||||
|  | // linked together. We do this by declaring objects in an `inline | ||||||
|  | // namespace` which changes the name mangling, but can still be | ||||||
|  | // accessed as `at::vec`. | ||||||
|  | inline namespace CPU_CAPABILITY { | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | struct is_vec_specialized_for<double> : std::bool_constant<true> {}; | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | class Vectorized<double> { | ||||||
|  |  private: | ||||||
|  |   float64x2_t values; | ||||||
|  |  | ||||||
|  |  public: | ||||||
|  |   using value_type = double; | ||||||
|  |   using size_type = int; | ||||||
|  |   static constexpr size_type size() { | ||||||
|  |     return 2; | ||||||
|  |   } | ||||||
|  |   Vectorized() { | ||||||
|  |     values = vdupq_n_f64(0.0); | ||||||
|  |   } | ||||||
|  |   Vectorized(float64x2_t v) : values(v) {} | ||||||
|  |   Vectorized(double val) { | ||||||
|  |     values = vdupq_n_f64(val); | ||||||
|  |   } | ||||||
|  |   template < | ||||||
|  |       typename... Args, | ||||||
|  |       typename = std::enable_if_t<(sizeof...(Args) == size())>> | ||||||
|  |   Vectorized(Args... vals) { | ||||||
|  |     __at_align__ double buffer[size()] = {vals...}; | ||||||
|  |     values = vld1q_f64(buffer); | ||||||
|  |   } | ||||||
|  |   operator float64x2_t() const { | ||||||
|  |     return values; | ||||||
|  |   } | ||||||
|  |   template <int64_t mask> | ||||||
|  |   static Vectorized<double> blend( | ||||||
|  |       const Vectorized<double>& a, | ||||||
|  |       const Vectorized<double>& b) { | ||||||
|  |     // Build an array of flags: each bit of element is 1 if the corresponding | ||||||
|  |     // bit in 'mask' is set, 0 otherwise. | ||||||
|  |     uint64x2_t maskArray = { | ||||||
|  |         (mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0, | ||||||
|  |         (mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0}; | ||||||
|  |     // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |     return vbslq_f64(maskArray, b.values, a.values); | ||||||
|  |   } | ||||||
|  |   static Vectorized<double> blendv( | ||||||
|  |       const Vectorized<double>& a, | ||||||
|  |       const Vectorized<double>& b, | ||||||
|  |       const Vectorized<double>& mask_) { | ||||||
|  |     return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values); | ||||||
|  |   } | ||||||
|  |   template <typename step_t> | ||||||
|  |   static Vectorized<double> arange( | ||||||
|  |       double base = 0., | ||||||
|  |       step_t step = static_cast<step_t>(1)) { | ||||||
|  |     return {base, base + static_cast<double>(step)}; | ||||||
|  |   } | ||||||
|  |   static inline Vectorized<double> set( | ||||||
|  |       const Vectorized<double>& a, | ||||||
|  |       const Vectorized<double>& b, | ||||||
|  |       int64_t count = size()) { | ||||||
|  |     if (count == 0) { | ||||||
|  |       return a; | ||||||
|  |     } else if (count >= 2) { | ||||||
|  |       return b; | ||||||
|  |     } else { | ||||||
|  |       float64x2_t c = {b.values[0], a.values[1]}; | ||||||
|  |       return c; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   static Vectorized<double> loadu(const void* ptr, int64_t count = size()) { | ||||||
|  |     if (count == size()) { | ||||||
|  |       return vld1q_f64(reinterpret_cast<const double*>(ptr)); | ||||||
|  |     } else if (count == 1) { | ||||||
|  |       float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr)); | ||||||
|  |       float64x1_t z = {0.0}; | ||||||
|  |       return vcombine_f64(x, z); | ||||||
|  |     } else { | ||||||
|  |       return vdupq_n_f64(0.0); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   void store(void* ptr, int64_t count = size()) const { | ||||||
|  |     if (count == size()) { | ||||||
|  |       vst1q_f64(reinterpret_cast<double*>(ptr), values); | ||||||
|  |     } else if (count == 1) { | ||||||
|  |       vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values)); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   const double& operator[](int idx) const = delete; | ||||||
|  |   double& operator[](int idx) = delete; | ||||||
|  |   int64_t zero_mask() const { | ||||||
|  |     // returns an integer mask where all zero elements are translated to 1-bit | ||||||
|  |     // and others are translated to 0-bit | ||||||
|  |     uint64x2_t cmpReg = vceqzq_f64(values); | ||||||
|  |     uint64x2_t mask = {1, 2}; | ||||||
|  |     uint64x2_t res = vandq_u64(cmpReg, mask); | ||||||
|  |     return res[0] | res[1]; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> isnan() const { | ||||||
|  |     // NaN check | ||||||
|  |     return vreinterpretq_f64_u32( | ||||||
|  |         vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values)))); | ||||||
|  |   } | ||||||
|  |   bool has_inf_nan() const { | ||||||
|  |     Vectorized<double> x = vsubq_f64(values, values); | ||||||
|  |     float64x2_t r = x.isnan(); | ||||||
|  |     uint64x2_t u = vreinterpretq_u64_f64(r); | ||||||
|  |     return u[0] | u[1]; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> map(double (*f)(double)) const { | ||||||
|  |     float64x2_t result; | ||||||
|  |     result[0] = f(values[0]); | ||||||
|  |     result[1] = f(values[1]); | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> map2( | ||||||
|  |       const Vectorized<double>& second, | ||||||
|  |       double (*const f)(double, double)) const { | ||||||
|  |     float64x2_t result; | ||||||
|  |     result[0] = f(values[0], second.values[0]); | ||||||
|  |     result[1] = f(values[1], second.values[1]); | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> abs() const { | ||||||
|  |     return vabsq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> angle() const { | ||||||
|  |     auto zero = Vectorized<double>(0.0); | ||||||
|  |     auto pi = Vectorized<double>(c10::pi<double>); | ||||||
|  |     auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values))); | ||||||
|  |     return blendv(tmp, *this, isnan()); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> real() const { | ||||||
|  |     return *this; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> imag() const { | ||||||
|  |     return Vectorized<double>(0.0); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> conj() const { | ||||||
|  |     return *this; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> acos() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> acosh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> asin() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> asinh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> atan() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> atanh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF( | ||||||
|  |       { return Vectorized<double>(Sleef_atan2d2_u10(values, b)); }, | ||||||
|  |       { | ||||||
|  |         __at_align__ double tmp[size()]; | ||||||
|  |         __at_align__ double tmp_b[size()]; | ||||||
|  |         store(tmp); | ||||||
|  |         b.store(tmp_b); | ||||||
|  |         for (int64_t i = 0; i < size(); i++) { | ||||||
|  |           tmp[i] = std::atan2(tmp[i], tmp_b[i]); | ||||||
|  |         } | ||||||
|  |         return loadu(tmp); | ||||||
|  |       })} Vectorized<double> copysign(const Vectorized<double>& sign) const { | ||||||
|  |       USE_SLEEF( | ||||||
|  |           { return Vectorized<double>(Sleef_copysignd2(values, sign)); }, | ||||||
|  |           { | ||||||
|  |             __at_align__ double tmp[size()]; | ||||||
|  |             __at_align__ double tmp_sign[size()]; | ||||||
|  |             store(tmp); | ||||||
|  |             sign.store(tmp_sign); | ||||||
|  |             for (int64_t i = 0; i < size(); i++) { | ||||||
|  |               tmp[i] = std::copysign(tmp[i], tmp_sign[i]); | ||||||
|  |             } | ||||||
|  |             return loadu(tmp); | ||||||
|  |           })} Vectorized<double> erf() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> erfc() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> exp() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> exp2() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> expm1() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF( | ||||||
|  |       { return Vectorized<double>(Sleef_fmodd2(values, q)); }, | ||||||
|  |       { | ||||||
|  |         __at_align__ double tmp[size()]; | ||||||
|  |         __at_align__ double tmp_q[size()]; | ||||||
|  |         store(tmp); | ||||||
|  |         q.store(tmp_q); | ||||||
|  |         for (int64_t i = 0; i < size(); i++) { | ||||||
|  |           tmp[i] = std::fmod(tmp[i], tmp_q[i]); | ||||||
|  |         } | ||||||
|  |         return loadu(tmp); | ||||||
|  |       })} Vectorized<double> hypot(const Vectorized<double>& b) const { | ||||||
|  |       USE_SLEEF( | ||||||
|  |           { return Vectorized<double>(Sleef_hypotd2_u05(values, b)); }, | ||||||
|  |           { | ||||||
|  |             __at_align__ double tmp[size()]; | ||||||
|  |             __at_align__ double tmp_b[size()]; | ||||||
|  |             store(tmp); | ||||||
|  |             b.store(tmp_b); | ||||||
|  |             for (int64_t i = 0; i < size(); i++) { | ||||||
|  |               tmp[i] = std::hypot(tmp[i], tmp_b[i]); | ||||||
|  |             } | ||||||
|  |             return loadu(tmp); | ||||||
|  |           })} Vectorized<double> i0() const { | ||||||
|  |     return map(calc_i0); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF( | ||||||
|  |       { return Vectorized<double>(Sleef_nextafterd2(values, b)); }, | ||||||
|  |       { | ||||||
|  |         __at_align__ double tmp[size()]; | ||||||
|  |         __at_align__ double tmp_b[size()]; | ||||||
|  |         store(tmp); | ||||||
|  |         b.store(tmp_b); | ||||||
|  |         for (int64_t i = 0; i < size(); ++i) { | ||||||
|  |           tmp[i] = std::nextafter(tmp[i], tmp_b[i]); | ||||||
|  |         } | ||||||
|  |         return loadu(tmp); | ||||||
|  |       })} Vectorized<double> log() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_logd2_u10(values)), map(std::log)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> log2() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> log10() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> log1p() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> frac() const; | ||||||
|  |   Vectorized<double> sin() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> sinh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> cos() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> cosh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF( | ||||||
|  |       { return Vectorized<double>(Sleef_powd2_u10(values, b)); }, | ||||||
|  |       { | ||||||
|  |         __at_align__ double tmp[size()]; | ||||||
|  |         __at_align__ double tmp_b[size()]; | ||||||
|  |         store(tmp); | ||||||
|  |         b.store(tmp_b); | ||||||
|  |         for (int64_t i = 0; i < size(); i++) { | ||||||
|  |           tmp[i] = std::pow(tmp[i], tmp_b[i]); | ||||||
|  |         } | ||||||
|  |         return loadu(tmp); | ||||||
|  |       })} // Comparison using the _CMP_**_OQ predicate. | ||||||
|  |           //   `O`: get false if an operand is NaN | ||||||
|  |           //   `Q`: do not raise if an operand is NaN | ||||||
|  |   Vectorized<double> tan() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> tanh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> lgamma() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> erfinv() const { | ||||||
|  |     return map(calc_erfinv); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> exp_u20() const { | ||||||
|  |     return exp(); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> fexp_u20() const { | ||||||
|  |     return exp(); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> i0e() const { | ||||||
|  |     return map(calc_i0e); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> digamma() const { | ||||||
|  |     return map(calc_digamma); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> igamma(const Vectorized<double>& x) const { | ||||||
|  |     __at_align__ double tmp[size()]; | ||||||
|  |     __at_align__ double tmp_x[size()]; | ||||||
|  |     store(tmp); | ||||||
|  |     x.store(tmp_x); | ||||||
|  |     for (int64_t i = 0; i < size(); i++) { | ||||||
|  |       tmp[i] = calc_igamma(tmp[i], tmp_x[i]); | ||||||
|  |     } | ||||||
|  |     return loadu(tmp); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> igammac(const Vectorized<double>& x) const { | ||||||
|  |     __at_align__ double tmp[size()]; | ||||||
|  |     __at_align__ double tmp_x[size()]; | ||||||
|  |     store(tmp); | ||||||
|  |     x.store(tmp_x); | ||||||
|  |     for (int64_t i = 0; i < size(); i++) { | ||||||
|  |       tmp[i] = calc_igammac(tmp[i], tmp_x[i]); | ||||||
|  |     } | ||||||
|  |     return loadu(tmp); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> ceil() const { | ||||||
|  |     return vrndpq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> floor() const { | ||||||
|  |     return vrndmq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> neg() const { | ||||||
|  |     return vnegq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> round() const { | ||||||
|  |     return vrndiq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> trunc() const { | ||||||
|  |     return vrndq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> sqrt() const { | ||||||
|  |     return vsqrtq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> reciprocal() const { | ||||||
|  |     return vdivq_f64(vdupq_n_f64(1.0), values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> rsqrt() const { | ||||||
|  |     return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values)); | ||||||
|  |   } | ||||||
|  |   double reduce_add() const { | ||||||
|  |     return vaddvq_f64(values); | ||||||
|  |   } | ||||||
|  |   double reduce_max() const { | ||||||
|  |     return vmaxvq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> operator==(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vceqq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator!=(const Vectorized<double>& other) const { | ||||||
|  |     float64x2_t r0 = vreinterpretq_f64_u32( | ||||||
|  |         vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values)))); | ||||||
|  |     return Vectorized<double>(r0); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator<(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vcltq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator<=(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vcleq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator>(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vcgtq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator>=(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vcgeq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> eq(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> ne(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> gt(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> ge(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> lt(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> le(const Vectorized<double>& other) const; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator+( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vaddq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator-( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vsubq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator*( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vmulq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator/( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vdivq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // frac. Implement this here so we can use subtraction | ||||||
|  | Vectorized<double> inline Vectorized<double>::frac() const { | ||||||
|  |   return *this - this->trunc(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if | ||||||
|  | // either input is a NaN. | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline maximum( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vmaxq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if | ||||||
|  | // either input is a NaN. | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline minimum( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vminq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline clamp( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& min, | ||||||
|  |     const Vectorized<double>& max) { | ||||||
|  |   return vminq_f64(max, vmaxq_f64(min, a)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline clamp_max( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& max) { | ||||||
|  |   return vminq_f64(max, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline clamp_min( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& min) { | ||||||
|  |   return vmaxq_f64(min, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator&( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vreinterpretq_f64_u64( | ||||||
|  |       vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator|( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vreinterpretq_f64_u64( | ||||||
|  |       vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator^( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vreinterpretq_f64_u64( | ||||||
|  |       veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::eq( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this == other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::ne( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this != other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::gt( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this > other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::ge( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this >= other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::lt( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this < other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::le( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this <= other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline fmadd( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b, | ||||||
|  |     const Vectorized<double>& c) { | ||||||
|  |   return vfmaq_f64(c, a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline fnmadd( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b, | ||||||
|  |     const Vectorized<double>& c) { | ||||||
|  |   return vfmsq_f64(c, a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline fmsub( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b, | ||||||
|  |     const Vectorized<double>& c) { | ||||||
|  |   return vfmaq_f64(vnegq_f64(c), a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline fnmsub( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b, | ||||||
|  |     const Vectorized<double>& c) { | ||||||
|  |   return vfmsq_f64(vnegq_f64(c), a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace CPU_CAPABILITY | ||||||
|  | } // namespace at::vec | ||||||
| @ -307,11 +307,49 @@ class Vectorized<float> { | |||||||
|   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp) |   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp) | ||||||
|   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2) |   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2) | ||||||
|   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) |   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) | ||||||
|  |   // Implementation copied from Arm Optimized Routine | ||||||
|  |   // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c | ||||||
|   Vectorized<float> exp_u20() const { |   Vectorized<float> exp_u20() const { | ||||||
|     return exp(); |     // bail out to sleef if it's a special case: | ||||||
|  |     // i.e. there's an input s.t. |input| > 87.3.... | ||||||
|  |     const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); | ||||||
|  |     uint32x4_t cmp = vcagtq_f32(values, special_bound); | ||||||
|  |     if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) { | ||||||
|  |       return exp(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); | ||||||
|  |     const float ln2_hi = 0x1.62e4p-1f; | ||||||
|  |     const float ln2_lo = 0x1.7f7d1cp-20f; | ||||||
|  |     const float c0 = 0x1.0e4020p-7f; | ||||||
|  |     const float c2 = 0x1.555e66p-3f; | ||||||
|  |     const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; | ||||||
|  |  | ||||||
|  |     const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); | ||||||
|  |     const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); | ||||||
|  |     const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); | ||||||
|  |     const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); | ||||||
|  |  | ||||||
|  |     /* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)] | ||||||
|  |       x = ln2*n + r, with r in [-ln2/2, ln2/2].  */ | ||||||
|  |  | ||||||
|  |     float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); | ||||||
|  |     float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); | ||||||
|  |     r = vfmsq_laneq_f32(r, n, ln2_c02, 1); | ||||||
|  |     uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); | ||||||
|  |     float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); | ||||||
|  |  | ||||||
|  |     float32x4_t r2 = vmulq_f32(r, r); | ||||||
|  |     float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); | ||||||
|  |     float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); | ||||||
|  |     q = vfmaq_f32(q, p, r2); | ||||||
|  |     p = vmulq_f32(c4, r); | ||||||
|  |     float32x4_t poly = vfmaq_f32(p, q, r2); | ||||||
|  |  | ||||||
|  |     return vfmaq_f32(scale, poly, scale); | ||||||
|   } |   } | ||||||
|   Vectorized<float> fexp_u20() const { |   Vectorized<float> fexp_u20() const { | ||||||
|     return exp(); |     return exp_u20(); | ||||||
|   } |   } | ||||||
|   DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( |   DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( | ||||||
|       fmod, |       fmod, | ||||||
| @ -540,42 +578,6 @@ inline Vectorized<float> Vectorized<float>::le( | |||||||
|   return (*this <= other) & Vectorized<float>(1.0f); |   return (*this <= other) & Vectorized<float>(1.0f); | ||||||
| } | } | ||||||
|  |  | ||||||
| template <> |  | ||||||
| inline void convert(const float* src, int32_t* dst, int64_t n) { |  | ||||||
|   int64_t i; |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (i = 0; i <= (n - Vectorized<float>::size()); |  | ||||||
|        i += Vectorized<float>::size()) { |  | ||||||
|     vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i))); |  | ||||||
|   } |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (; i < n; i++) { |  | ||||||
|     dst[i] = static_cast<int32_t>(src[i]); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template <> |  | ||||||
| inline void convert(const int32_t* src, float* dst, int64_t n) { |  | ||||||
|   int64_t i; |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (i = 0; i <= (n - Vectorized<float>::size()); |  | ||||||
|        i += Vectorized<float>::size()) { |  | ||||||
|     vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i))); |  | ||||||
|   } |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (; i < n; i++) { |  | ||||||
|     dst[i] = static_cast<float>(src[i]); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| Vectorized<float> inline fmadd( | Vectorized<float> inline fmadd( | ||||||
|     const Vectorized<float>& a, |     const Vectorized<float>& a, | ||||||
| @ -632,8 +634,7 @@ inline Vectorized<float> Vectorized<float>::erf() const { | |||||||
|   // - exp(- x * x) |   // - exp(- x * x) | ||||||
|   auto pow_2 = (*this) * (*this); |   auto pow_2 = (*this) * (*this); | ||||||
|   auto neg_pow_2 = pow_2 ^ neg_zero_vec; |   auto neg_pow_2 = pow_2 ^ neg_zero_vec; | ||||||
|   auto tmp4 = neg_pow_2.map( |   auto tmp4 = neg_pow_2.exp(); | ||||||
|       std::exp); // This can be swapped for a faster implementation of exp. |  | ||||||
|   auto tmp5 = tmp4 ^ neg_zero_vec; |   auto tmp5 = tmp4 ^ neg_zero_vec; | ||||||
|   // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) |   // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) | ||||||
|   auto tmp6 = t * tmp5; |   auto tmp6 = t * tmp5; | ||||||
|  | |||||||
| @ -234,7 +234,7 @@ class Vectorized<c10::Half> : public Vectorized16< | |||||||
|         vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift); |         vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift); | ||||||
|     return vaddvq_u16(bits_vec); |     return vaddvq_u16(bits_vec); | ||||||
| #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||||
|     // use known working implmentation. |     // use known working implementation. | ||||||
|     __at_align__ value_type tmp[size()]; |     __at_align__ value_type tmp[size()]; | ||||||
|     store(tmp); |     store(tmp); | ||||||
|     int mask = 0; |     int mask = 0; | ||||||
| @ -569,46 +569,6 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le( | |||||||
|   return (*this <= other) & Vectorized<c10::Half>(1); |   return (*this <= other) & Vectorized<c10::Half>(1); | ||||||
| } | } | ||||||
|  |  | ||||||
| // These are global functions, so the defaults in vec_base.h should |  | ||||||
| // work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available. |  | ||||||
| #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |  | ||||||
| template <> |  | ||||||
| inline void convert(const float16_t* src, int16_t* dst, int64_t n) { |  | ||||||
|   int64_t i; |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (i = 0; i <= (n - Vectorized<c10::Half>::size()); |  | ||||||
|        i += Vectorized<c10::Half>::size()) { |  | ||||||
|     vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i))); |  | ||||||
|   } |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (; i < n; i++) { |  | ||||||
|     dst[i] = static_cast<int16_t>(src[i]); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template <> |  | ||||||
| inline void convert(const int16_t* src, float16_t* dst, int64_t n) { |  | ||||||
|   int64_t i; |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (i = 0; i <= (n - Vectorized<c10::Half>::size()); |  | ||||||
|        i += Vectorized<c10::Half>::size()) { |  | ||||||
|     vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i))); |  | ||||||
|   } |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (; i < n; i++) { |  | ||||||
|     dst[i] = static_cast<float16_t>(src[i]); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |  | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| Vectorized<c10::Half> inline fmadd( | Vectorized<c10::Half> inline fmadd( | ||||||
|     const Vectorized<c10::Half>& a, |     const Vectorized<c10::Half>& a, | ||||||
|  | |||||||
							
								
								
									
										378
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_uint_aarch64.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										378
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_uint_aarch64.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,378 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/cpu/vec/intrinsics.h> | ||||||
|  | #include <ATen/cpu/vec/vec_base.h> | ||||||
|  | #include <c10/macros/Macros.h> | ||||||
|  | #include <c10/util/irange.h> | ||||||
|  |  | ||||||
|  | namespace at::vec { | ||||||
|  | // Note [CPU_CAPABILITY namespace] | ||||||
|  | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  | // This header, and all of its subheaders, will be compiled with | ||||||
|  | // different architecture flags for each supported set of vector | ||||||
|  | // intrinsics. So we need to make sure they aren't inadvertently | ||||||
|  | // linked together. We do this by declaring objects in an `inline | ||||||
|  | // namespace` which changes the name mangling, but can still be | ||||||
|  | // accessed as `at::vec`. | ||||||
|  | inline namespace CPU_CAPABILITY { | ||||||
|  |  | ||||||
|  | #define VEC_UINT_NEON_TEMPLATE(vl, bit)                                       \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \ | ||||||
|  |                                                                               \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   class Vectorized<uint##bit##_t> {                                           \ | ||||||
|  |     using neon_type = uint##bit##x##vl##_t;                                   \ | ||||||
|  |                                                                               \ | ||||||
|  |    private:                                                                   \ | ||||||
|  |     neon_type values;                                                         \ | ||||||
|  |                                                                               \ | ||||||
|  |    public:                                                                    \ | ||||||
|  |     using value_type = uint##bit##_t;                                         \ | ||||||
|  |     using size_type = int;                                                    \ | ||||||
|  |     static constexpr size_type size() {                                       \ | ||||||
|  |       return vl;                                                              \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized() {                                                            \ | ||||||
|  |       values = vdupq_n_u##bit(0);                                             \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized(neon_type v) : values(v) {}                                    \ | ||||||
|  |     Vectorized(uint##bit##_t val);                                            \ | ||||||
|  |     template <                                                                \ | ||||||
|  |         typename... Args,                                                     \ | ||||||
|  |         typename = std::enable_if_t<(sizeof...(Args) == size())>>             \ | ||||||
|  |     Vectorized(Args... vals) {                                                \ | ||||||
|  |       __at_align__ uint##bit##_t buffer[size()] = {vals...};                  \ | ||||||
|  |       values = vld1q_u##bit(buffer);                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     operator neon_type() const {                                              \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     static Vectorized<uint##bit##_t> loadu(                                   \ | ||||||
|  |         const void* ptr,                                                      \ | ||||||
|  |         uint64_t count = size());                                             \ | ||||||
|  |     void store(void* ptr, uint64_t count = size()) const;                     \ | ||||||
|  |     template <uint64_t mask>                                                  \ | ||||||
|  |     static Vectorized<uint##bit##_t> blend(                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& a,                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& b);                                  \ | ||||||
|  |     static Vectorized<uint##bit##_t> blendv(                                  \ | ||||||
|  |         const Vectorized<uint##bit##_t>& a,                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& b,                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& mask_) {                             \ | ||||||
|  |       return vbslq_u##bit(mask_.values, b, a);                                \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     template <typename step_t>                                                \ | ||||||
|  |     static Vectorized<uint##bit##_t> arange(                                  \ | ||||||
|  |         value_type base = 0,                                                  \ | ||||||
|  |         step_t step = static_cast<step_t>(1));                                \ | ||||||
|  |     static Vectorized<uint##bit##_t> set(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& a,                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& b,                                   \ | ||||||
|  |         uint64_t count = size());                                             \ | ||||||
|  |     const uint##bit##_t& operator[](uint idx) const = delete;                 \ | ||||||
|  |     uint##bit##_t& operator[](uint idx) = delete;                             \ | ||||||
|  |     Vectorized<uint##bit##_t> abs() const {                                   \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> real() const {                                  \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> imag() const {                                  \ | ||||||
|  |       return vdupq_n_u##bit(0);                                               \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> conj() const {                                  \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> neg() const {                                   \ | ||||||
|  |       return vreinterpretq_u##bit##_s##bit(                                   \ | ||||||
|  |           vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values)));               \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     uint##bit##_t reduce_add() const {                                        \ | ||||||
|  |       return vaddvq_u##bit(values);                                           \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     uint##bit##_t reduce_max() const;                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator==(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vceqq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator!=(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> operator<(                                      \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vcltq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator<=(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vcleq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator>(                                      \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vcgtq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator>=(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vcgeq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> eq(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> ne(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> gt(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> ge(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> lt(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> le(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |   };                                                                          \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator+(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return vaddq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator-(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return vsubq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator&(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return vandq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator|(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return vorrq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator^(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return veorq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this == other) & Vectorized<uint##bit##_t>(1);                   \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this != other) & Vectorized<uint##bit##_t>(1);                   \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this > other) & Vectorized<uint##bit##_t>(1);                    \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this >= other) & Vectorized<uint##bit##_t>(1);                   \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this < other) & Vectorized<uint##bit##_t>(1);                    \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this <= other) & Vectorized<uint##bit##_t>(1);                   \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | VEC_UINT_NEON_TEMPLATE(16, 8) | ||||||
|  |  | ||||||
|  | inline uint8_t Vectorized<uint8_t>::reduce_max() const { | ||||||
|  |   return vmaxvq_u8(values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline operator*( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   return vmulq_u8(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) { | ||||||
|  |   return vmvnq_u8(a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=( | ||||||
|  |     const Vectorized<uint8_t>& other) const { | ||||||
|  |   return ~(*this == other); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline minimum( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   return vminq_u8(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline maximum( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   return vmaxq_u8(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <uint64_t mask> | ||||||
|  | Vectorized<uint8_t> Vectorized<uint8_t>::blend( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||||
|  |   // in 'mask' is set, 0 otherwise. | ||||||
|  |   uint8x16_t maskArray = { | ||||||
|  |       (mask & 1LL) ? 0xFF : 0, | ||||||
|  |       (mask & 2LL) ? 0xFF : 0, | ||||||
|  |       (mask & 4LL) ? 0xFF : 0, | ||||||
|  |       (mask & 8LL) ? 0xFF : 0, | ||||||
|  |       (mask & 16LL) ? 0xFF : 0, | ||||||
|  |       (mask & 32LL) ? 0xFF : 0, | ||||||
|  |       (mask & 64LL) ? 0xFF : 0, | ||||||
|  |       (mask & 128LL) ? 0xFF : 0, | ||||||
|  |       (mask & 256LL) ? 0xFF : 0, | ||||||
|  |       (mask & 512LL) ? 0xFF : 0, | ||||||
|  |       (mask & 1024LL) ? 0xFF : 0, | ||||||
|  |       (mask & 2048LL) ? 0xFF : 0, | ||||||
|  |       (mask & 4096LL) ? 0xFF : 0, | ||||||
|  |       (mask & 8192LL) ? 0xFF : 0, | ||||||
|  |       (mask & 16384LL) ? 0xFF : 0, | ||||||
|  |       (mask & 32768LL) ? 0xFF : 0}; | ||||||
|  |   // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |   return vbslq_u8(maskArray, b.values, a.values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #define VEC_UINT_NEON_OPS(vl, bit)                                             \ | ||||||
|  |   inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) {            \ | ||||||
|  |     values = vdupq_n_u##bit(val);                                              \ | ||||||
|  |   }                                                                            \ | ||||||
|  |   inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu(           \ | ||||||
|  |       const void* ptr, uint64_t count) {                                       \ | ||||||
|  |     if (count == size()) {                                                     \ | ||||||
|  |       return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr));        \ | ||||||
|  |     } else {                                                                   \ | ||||||
|  |       __at_align__ uint##bit##_t tmp_values[size()];                           \ | ||||||
|  |       for (const auto i : c10::irange(size())) {                               \ | ||||||
|  |         tmp_values[i] = 0;                                                     \ | ||||||
|  |       }                                                                        \ | ||||||
|  |       std::memcpy(                                                             \ | ||||||
|  |           tmp_values,                                                          \ | ||||||
|  |           reinterpret_cast<const uint##bit##_t*>(ptr),                         \ | ||||||
|  |           count * sizeof(uint##bit##_t));                                      \ | ||||||
|  |       return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \ | ||||||
|  |     }                                                                          \ | ||||||
|  |   }                                                                            \ | ||||||
|  |   inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count)      \ | ||||||
|  |       const {                                                                  \ | ||||||
|  |     if (count == size()) {                                                     \ | ||||||
|  |       vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values);             \ | ||||||
|  |     } else {                                                                   \ | ||||||
|  |       uint##bit##_t tmp_values[size()];                                        \ | ||||||
|  |       vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values);      \ | ||||||
|  |       std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t));             \ | ||||||
|  |     }                                                                          \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | VEC_UINT_NEON_OPS(16, 8) | ||||||
|  |  | ||||||
|  | template <typename step_t> | ||||||
|  | inline Vectorized<uint8_t> Vectorized<uint8_t>::arange( | ||||||
|  |     uint8_t base, | ||||||
|  |     step_t step) { | ||||||
|  |   const Vectorized<uint8_t> base_vec(base); | ||||||
|  |   const Vectorized<uint8_t> step_vec(step); | ||||||
|  |   const uint8x16_t step_sizes = { | ||||||
|  |       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; | ||||||
|  |   return vmlaq_u8(base_vec, step_sizes, step_vec); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline operator>>( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   uint8x16_t x = a; | ||||||
|  |   uint8x16_t bound = vdupq_n_u8(8); | ||||||
|  |   uint8x16_t z = vminq_u8(b, bound); | ||||||
|  |   return x >> z; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline operator<<( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   uint8x16_t bound = vdupq_n_u8(8); | ||||||
|  |   uint8x16_t z = vminq_u8(b, bound); | ||||||
|  |   return vshlq_u8(a, vreinterpretq_s8_u8(z)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<uint8_t> Vectorized<uint8_t>::set( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b, | ||||||
|  |     uint64_t count) { | ||||||
|  |   if (count == 0) { | ||||||
|  |     return a; | ||||||
|  |   } else if (count >= 16) { | ||||||
|  |     return b; | ||||||
|  |   } else { | ||||||
|  |     // Build an array of flags: each bit of element is 1 if the corresponding | ||||||
|  |     // bit in 'mask' is set, 0 otherwise. | ||||||
|  |     uint8x16_t maskArray = { | ||||||
|  |         static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0), | ||||||
|  |         0}; | ||||||
|  |  | ||||||
|  |     // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |     return vbslq_u8(maskArray, b.values, a.values); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline operator/( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   uint8x16_t x = a; | ||||||
|  |   uint8x16_t y = b; | ||||||
|  |   return x / y; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline clamp( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& min, | ||||||
|  |     const Vectorized<uint8_t>& max) { | ||||||
|  |   return minimum(max, maximum(min, a)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline clamp_max( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& max) { | ||||||
|  |   return minimum(max, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline clamp_min( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& min) { | ||||||
|  |   return maximum(min, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace CPU_CAPABILITY | ||||||
|  | } // namespace at::vec | ||||||
| @ -1740,7 +1740,7 @@ Vectorized<int16_t> inline shift_256_16( | |||||||
|  |  | ||||||
|   // Control masks for shuffle operation, treating 256 bits as an |   // Control masks for shuffle operation, treating 256 bits as an | ||||||
|   // array of 16-bit elements, and considering pairs of neighboring |   // array of 16-bit elements, and considering pairs of neighboring | ||||||
|   // elements.  Specifially, a mask named "ctl_M_N" (M,N in [0,1], and |   // elements.  Specifically, a mask named "ctl_M_N" (M,N in [0,1], and | ||||||
|   // M!=N) is set so that shuffle will move element with index M from |   // M!=N) is set so that shuffle will move element with index M from | ||||||
|   // input pair into element with index N in output pair, and element |   // input pair into element with index N in output pair, and element | ||||||
|   // with index M in output pair will be set to all 0s. |   // with index M in output pair will be set to all 0s. | ||||||
| @ -1875,7 +1875,7 @@ Vectorized<T> inline shift_256_8( | |||||||
|  |  | ||||||
|   // Control masks for shuffle operation, treating 256 bits as an |   // Control masks for shuffle operation, treating 256 bits as an | ||||||
|   // array of 8-bit elements, and considering quadruples of |   // array of 8-bit elements, and considering quadruples of | ||||||
|   // neighboring elements.  Specifially, a mask named "ctl_M_N" (M,N |   // neighboring elements.  Specifically, a mask named "ctl_M_N" (M,N | ||||||
|   // in [0,1,2,3], and M!=N) is set so that shuffle will move element |   // in [0,1,2,3], and M!=N) is set so that shuffle will move element | ||||||
|   // with index M from input quadruple into element with index N in |   // with index M from input quadruple into element with index N in | ||||||
|   // output quadruple, and other elements in output quadruple will be |   // output quadruple, and other elements in output quadruple will be | ||||||
|  | |||||||
| @ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | |||||||
|  |  | ||||||
| std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||||
|     at::vec::Vectorized<uint8_t> src) { |     at::vec::Vectorized<uint8_t> src) { | ||||||
|   auto u8x8 = vld1_u8(src.operator const uint8_t*()); |   auto u8x8 = vget_low_u8(src); | ||||||
|   auto u16x8 = vmovl_u8(u8x8); |   auto u16x8 = vmovl_u8(u8x8); | ||||||
|   auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); |   auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); | ||||||
|   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); |   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); | ||||||
| @ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float( | |||||||
|  |  | ||||||
| Vectorized<float> inline convert_int8_half_register_to_float( | Vectorized<float> inline convert_int8_half_register_to_float( | ||||||
|     at::vec::Vectorized<uint8_t> src) { |     at::vec::Vectorized<uint8_t> src) { | ||||||
|   auto u8x8 = vld1_u8(src.operator const uint8_t*()); |   auto u8x8 = vget_low_u8(src); | ||||||
|   auto u16x8 = vmovl_u8(u8x8); |   auto u16x8 = vmovl_u8(u8x8); | ||||||
|   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); |   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); | ||||||
|  |  | ||||||
|  | |||||||
| @ -143,7 +143,7 @@ class Vectorized<double> { | |||||||
|       const Vectorized<double>& a, |       const Vectorized<double>& a, | ||||||
|       const Vectorized<double>& b, |       const Vectorized<double>& b, | ||||||
|       const Vectorized<double>& mask) { |       const Vectorized<double>& mask) { | ||||||
|     // the mask used here returned by comparision of vec256 |     // the mask used here returned by comparison of vec256 | ||||||
|  |  | ||||||
|     return { |     return { | ||||||
|         vec_sel(a._vec0, b._vec0, mask._vecb0), |         vec_sel(a._vec0, b._vec0, mask._vecb0), | ||||||
|  | |||||||
| @ -142,7 +142,7 @@ class Vectorized<float> { | |||||||
|       const Vectorized<float>& a, |       const Vectorized<float>& a, | ||||||
|       const Vectorized<float>& b, |       const Vectorized<float>& b, | ||||||
|       const Vectorized<float>& mask) { |       const Vectorized<float>& mask) { | ||||||
|     // the mask used here returned by comparision of vec256 |     // the mask used here returned by comparison of vec256 | ||||||
|     // assuming this we can use the same mask directly with vec_sel |     // assuming this we can use the same mask directly with vec_sel | ||||||
|     return { |     return { | ||||||
|         vec_sel(a._vec0, b._vec0, mask._vecb0), |         vec_sel(a._vec0, b._vec0, mask._vecb0), | ||||||
|  | |||||||
| @ -202,7 +202,7 @@ class Vectorized<int16_t> { | |||||||
|       const Vectorized<int16_t>& a, |       const Vectorized<int16_t>& a, | ||||||
|       const Vectorized<int16_t>& b, |       const Vectorized<int16_t>& b, | ||||||
|       const Vectorized<int16_t>& mask) { |       const Vectorized<int16_t>& mask) { | ||||||
|     // the mask used here returned by comparision of vec256 |     // the mask used here returned by comparison of vec256 | ||||||
|     // assuming this we can use the same mask directly with vec_sel |     // assuming this we can use the same mask directly with vec_sel | ||||||
|     // warning intel style mask will not work properly |     // warning intel style mask will not work properly | ||||||
|     return { |     return { | ||||||
|  | |||||||
| @ -155,7 +155,7 @@ class Vectorized<int32_t> { | |||||||
|       const Vectorized<int32_t>& a, |       const Vectorized<int32_t>& a, | ||||||
|       const Vectorized<int32_t>& b, |       const Vectorized<int32_t>& b, | ||||||
|       const Vectorized<int32_t>& mask) { |       const Vectorized<int32_t>& mask) { | ||||||
|     // the mask used here returned by comparision of vec256 |     // the mask used here returned by comparison of vec256 | ||||||
|     // assuming this we can use the same mask directly with vec_sel |     // assuming this we can use the same mask directly with vec_sel | ||||||
|     // warning intel style mask will not work properly |     // warning intel style mask will not work properly | ||||||
|     return { |     return { | ||||||
|  | |||||||
| @ -119,7 +119,7 @@ class Vectorized<int64_t> { | |||||||
|       const Vectorized<int64_t>& a, |       const Vectorized<int64_t>& a, | ||||||
|       const Vectorized<int64_t>& b, |       const Vectorized<int64_t>& b, | ||||||
|       const Vectorized<int64_t>& mask) { |       const Vectorized<int64_t>& mask) { | ||||||
|     // the mask used here returned by comparision of vec256 |     // the mask used here returned by comparison of vec256 | ||||||
|  |  | ||||||
|     return { |     return { | ||||||
|         vec_sel(a._vec0, b._vec0, mask._vecb0), |         vec_sel(a._vec0, b._vec0, mask._vecb0), | ||||||
|  | |||||||
| @ -397,7 +397,7 @@ inline Vectorized<bool> operator&&( | |||||||
|   const __m512i* other_ = reinterpret_cast<const __m512i*>(other.as_bytes()); |   const __m512i* other_ = reinterpret_cast<const __m512i*>(other.as_bytes()); | ||||||
|   __m512i out = _mm512_and_si512(*self_, *other_); |   __m512i out = _mm512_and_si512(*self_, *other_); | ||||||
|   Vectorized<bool> ret; |   Vectorized<bool> ret; | ||||||
|   // We do not have a constructer that takes __m512i, so we need to memcpy |   // We do not have a constructor that takes __m512i, so we need to memcpy | ||||||
|   std::memcpy(ret, &out, ret.size() * sizeof(bool)); |   std::memcpy(ret, &out, ret.size() * sizeof(bool)); | ||||||
|   return ret; |   return ret; | ||||||
| } | } | ||||||
|  | |||||||
| @ -1852,7 +1852,7 @@ Vectorized<T> inline shift_512_8( | |||||||
|  |  | ||||||
|   // Control masks for shuffle operation, treating 512 bits as an |   // Control masks for shuffle operation, treating 512 bits as an | ||||||
|   // array of 8-bit elements, and considering pairs of neighboring |   // array of 8-bit elements, and considering pairs of neighboring | ||||||
|   // elements.  Specifially, a mask named "ctl_M_N" (M,N in [0,1], and |   // elements.  Specifically, a mask named "ctl_M_N" (M,N in [0,1], and | ||||||
|   // M!=N) is set so that shuffle will move element with index M from |   // M!=N) is set so that shuffle will move element with index M from | ||||||
|   // input pair into element with index N in output pair, and element |   // input pair into element with index N in output pair, and element | ||||||
|   // with index M in output pair will be set to all 0s. |   // with index M in output pair will be set to all 0s. | ||||||
|  | |||||||
| @ -634,7 +634,7 @@ struct Vectorized { | |||||||
|   } |   } | ||||||
|   Vectorized<T> neg() const { |   Vectorized<T> neg() const { | ||||||
|     // NB: the trailing return type is needed because we need to coerce the |     // NB: the trailing return type is needed because we need to coerce the | ||||||
|     // return value back to T in the case of unary operator- incuring a |     // return value back to T in the case of unary operator- incurring a | ||||||
|     // promotion |     // promotion | ||||||
|     return map([](T x) -> T { return -x; }); |     return map([](T x) -> T { return -x; }); | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -1958,7 +1958,7 @@ void scaled_gemm( | |||||||
|     ScalarType result_dtype, |     ScalarType result_dtype, | ||||||
|     bool use_fast_accum, |     bool use_fast_accum, | ||||||
|     const std::optional<Tensor>& alpha) { |     const std::optional<Tensor>& alpha) { | ||||||
|   // Note: see `cublasCommonArgs` for various non-intuitive manupulations |   // Note: see `cublasCommonArgs` for various non-intuitive manipulations | ||||||
|   // of input arguments to this function. |   // of input arguments to this function. | ||||||
|   const auto computeType = CUBLAS_COMPUTE_32F; |   const auto computeType = CUBLAS_COMPUTE_32F; | ||||||
|   const auto scaleType = CUDA_R_32F; |   const auto scaleType = CUDA_R_32F; | ||||||
|  | |||||||
| @ -2,10 +2,10 @@ | |||||||
|  |  | ||||||
| #include <ATen/cuda/ATenCUDAGeneral.h> | #include <ATen/cuda/ATenCUDAGeneral.h> | ||||||
| #include <ATen/cuda/CUDAContext.h> | #include <ATen/cuda/CUDAContext.h> | ||||||
| #include <c10/core/impl/GPUTrace.h> |  | ||||||
| #include <c10/cuda/CUDAStream.h> |  | ||||||
| #include <c10/cuda/CUDAGuard.h> |  | ||||||
| #include <ATen/cuda/Exceptions.h> | #include <ATen/cuda/Exceptions.h> | ||||||
|  | #include <c10/core/impl/GPUTrace.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  | #include <c10/cuda/CUDAStream.h> | ||||||
| #include <c10/util/Exception.h> | #include <c10/util/Exception.h> | ||||||
|  |  | ||||||
| #include <cuda_runtime_api.h> | #include <cuda_runtime_api.h> | ||||||
| @ -246,4 +246,79 @@ private: | |||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | // EventPool - Thread-safe pool of CUDA events to avoid expensive cudaEventCreate | ||||||
|  | // calls. cudaEventCreate when concurrently invoked from multiple threads can be | ||||||
|  | // very expensive (especially on certain device/driver combinations). | ||||||
|  | using CUDAEventPtr = | ||||||
|  |     std::unique_ptr<CUDAEvent, std::function<void(CUDAEvent*)>>; | ||||||
|  |  | ||||||
|  | class EventPool { | ||||||
|  |  public: | ||||||
|  |   EventPool() : pools_(at::cuda::device_count()) {} | ||||||
|  |  | ||||||
|  |   CUDAEventPtr get(const DeviceIndex device) { | ||||||
|  |     // If the device is invalid, return a default event and no pooling | ||||||
|  |     if (device < 0 || device >= (DeviceIndex)pools_.size()) { | ||||||
|  |       auto deleter = [](CUDAEvent* event) { | ||||||
|  |         delete event; | ||||||
|  |       }; | ||||||
|  |       return CUDAEventPtr( | ||||||
|  |         std::make_unique<CUDAEvent>(cudaEventDisableTiming).release(), deleter); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     auto& pool = pools_[device]; | ||||||
|  |  | ||||||
|  |     // Create a destructor that returns the event to the appropriate device pool | ||||||
|  |     auto destructor = [&pool](CUDAEvent* event) noexcept { | ||||||
|  |       if (event != nullptr) { | ||||||
|  |         std::lock_guard<std::mutex> lock(pool.mutex_); | ||||||
|  |         pool.event_pool_.emplace_back(event); | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     { | ||||||
|  |       std::lock_guard<std::mutex> lock(pool.mutex_); | ||||||
|  |       if (!pool.event_pool_.empty()) { | ||||||
|  |         auto event = std::move(pool.event_pool_.back()); | ||||||
|  |         pool.event_pool_.pop_back(); | ||||||
|  |         return CUDAEventPtr(event.release(), destructor); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     return CUDAEventPtr( | ||||||
|  |         std::make_unique<CUDAEvent>(cudaEventDisableTiming).release(), | ||||||
|  |         destructor); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void empty_cache() { | ||||||
|  |     for (auto& pool : pools_) { | ||||||
|  |       std::lock_guard<std::mutex> lock(pool.mutex_); | ||||||
|  |       pool.event_pool_.clear(); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void init_num_events(const size_t num_events) { | ||||||
|  |     for (DeviceIndex device_idx = 0; device_idx < at::cuda::device_count(); ++device_idx) { | ||||||
|  |         CUDAGuard device_guard(device_idx); | ||||||
|  |         std::vector<CUDAEventPtr> temp_events; | ||||||
|  |         temp_events.reserve(num_events); | ||||||
|  |         for (size_t i = 0; i < num_events; ++i) { | ||||||
|  |           auto event = get(device_idx); | ||||||
|  |           // Record the event to ensure it's properly initialized | ||||||
|  |           event->record(); | ||||||
|  |           temp_events.emplace_back(std::move(event)); | ||||||
|  |         } | ||||||
|  |         // Events will be returned to pool when temp_events is destroyed | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |  private: | ||||||
|  |   struct alignas(64) PerDevicePool { | ||||||
|  |     alignas(64) std::mutex mutex_; | ||||||
|  |     std::vector<std::unique_ptr<CUDAEvent>> event_pool_; | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   std::vector<PerDevicePool> pools_; | ||||||
|  | }; | ||||||
|  |  | ||||||
| } // namespace at::cuda | } // namespace at::cuda | ||||||
|  | |||||||
| @ -168,11 +168,9 @@ void CUDAGraph::instantiate() { | |||||||
|   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597 |   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597 | ||||||
|   // cudaGraphInstantiateWithFlags |   // cudaGraphInstantiateWithFlags | ||||||
|   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233 |   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233 | ||||||
| #if !defined(USE_ROCM) || ROCM_VERSION >= 60200 |  | ||||||
|   int version = 0; |   int version = 0; | ||||||
|   AT_CUDA_CHECK(cudaDriverGetVersion(&version)); |   AT_CUDA_CHECK(cudaDriverGetVersion(&version)); | ||||||
|   if (version < 11040) { |   if (version < 11040) { | ||||||
| #endif |  | ||||||
|     // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people, |     // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people, | ||||||
|     // who prefer not to report error message through these arguments moving forward |     // who prefer not to report error message through these arguments moving forward | ||||||
|     // (they prefer return value, or errors on api calls internal to the capture) |     // (they prefer return value, or errors on api calls internal to the capture) | ||||||
| @ -183,13 +181,11 @@ void CUDAGraph::instantiate() { | |||||||
| #endif | #endif | ||||||
| //Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory. | //Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory. | ||||||
| //It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch. | //It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch. | ||||||
| #if !defined(USE_ROCM) || ROCM_VERSION >= 60200 |  | ||||||
|   } else { |   } else { | ||||||
|     AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_, |     AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_, | ||||||
|                                                 graph_, |                                                 graph_, | ||||||
|                                                 cudaGraphInstantiateFlagAutoFreeOnLaunch)); |                                                 cudaGraphInstantiateFlagAutoFreeOnLaunch)); | ||||||
|   } |   } | ||||||
| #endif |  | ||||||
|   has_graph_exec_ = true; |   has_graph_exec_ = true; | ||||||
| } | } | ||||||
|  |  | ||||||
| @ -311,7 +307,7 @@ CUDAGraph::~CUDAGraph() { | |||||||
| // There are recent HIP changes where hipGraphExecDestroy doesn't immediately free memory. | // There are recent HIP changes where hipGraphExecDestroy doesn't immediately free memory. | ||||||
| // They wait for next sync point in order to free the memory, this is to ensure that all | // They wait for next sync point in order to free the memory, this is to ensure that all | ||||||
| // hipGraphLaunch are finished before we release any memory. This feature was enabled in rocm6.2. | // hipGraphLaunch are finished before we release any memory. This feature was enabled in rocm6.2. | ||||||
| // We need to ensure all async opreations finish before deleting the object. | // We need to ensure all async operations finish before deleting the object. | ||||||
| #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) | #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) | ||||||
|   if (capture_dev_ != UNDEFINED_DEVICE) // check if capture_dev_ contains the real device id |   if (capture_dev_ != UNDEFINED_DEVICE) // check if capture_dev_ contains the real device id | ||||||
|   { |   { | ||||||
|  | |||||||
							
								
								
									
										192
									
								
								aten/src/ATen/cuda/CUDAGreenContext.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										192
									
								
								aten/src/ATen/cuda/CUDAGreenContext.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,192 @@ | |||||||
|  | #include <ATen/cuda/CUDAGreenContext.h> | ||||||
|  |  | ||||||
|  | namespace at::cuda { | ||||||
|  |   GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     int driver_version; | ||||||
|  |     C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         driver_version >= 12080, "cuda driver too old to use green context!"); | ||||||
|  |     CUcontext pctx = nullptr; | ||||||
|  |     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); | ||||||
|  |     if (C10_UNLIKELY(!pctx)) { | ||||||
|  |       TORCH_WARN( | ||||||
|  |           "Attempted to create a green context but" | ||||||
|  |           " there was no primary context! Creating a primary context..."); | ||||||
|  |  | ||||||
|  |       cudaFree(0); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     CUdevice device; | ||||||
|  |     device_id_ = device_id; | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); | ||||||
|  |  | ||||||
|  |     // Get device resources | ||||||
|  |     CUdevResource device_resource; | ||||||
|  |     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( | ||||||
|  |         device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); | ||||||
|  |  | ||||||
|  |     // Split resources | ||||||
|  |     std::vector<CUdevResource> result(1); | ||||||
|  |     auto result_data = result.data(); | ||||||
|  |     unsigned int nb_groups = 1; | ||||||
|  |     CUdevResource remaining; | ||||||
|  |  | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( | ||||||
|  |             result_data, | ||||||
|  |             &nb_groups, | ||||||
|  |             &device_resource, | ||||||
|  |             &remaining, | ||||||
|  |             0, // default flags | ||||||
|  |             num_sms)); | ||||||
|  |  | ||||||
|  |     TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); | ||||||
|  |  | ||||||
|  |     // Generate resource descriptor | ||||||
|  |     CUdevResourceDesc desc; | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( | ||||||
|  |             &desc, result_data, 1)); | ||||||
|  |  | ||||||
|  |     // Create green context | ||||||
|  |     // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: | ||||||
|  |     // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html | ||||||
|  |     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( | ||||||
|  |         &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||||||
|  |  | ||||||
|  |     // Convert to regular context | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); | ||||||
|  |     TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   std::unique_ptr<GreenContext> GreenContext::create( | ||||||
|  |       uint32_t num_sms, | ||||||
|  |       std::optional<uint32_t> device_id) { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     if (!device_id.has_value()) { | ||||||
|  |       device_id = at::cuda::current_device(); | ||||||
|  |     } | ||||||
|  |     return std::make_unique<GreenContext>(device_id.value(), num_sms); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Implement move operations | ||||||
|  |   GreenContext::GreenContext(GreenContext&& other) noexcept{ | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     device_id_ = std::exchange(other.device_id_, -1); | ||||||
|  |     green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||||
|  |     context_ = std::exchange(other.context_, nullptr); | ||||||
|  |     parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{ | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     if (this != &other) { | ||||||
|  |       // Clean up current resources | ||||||
|  |       if (green_ctx_) { | ||||||
|  |         CUcontext current = nullptr; | ||||||
|  |         C10_CUDA_DRIVER_CHECK( | ||||||
|  |             c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||||
|  |         if (current == context_) { | ||||||
|  |           TORCH_CHECK( | ||||||
|  |               false, | ||||||
|  |               "attempting to overwrite current green ctx " | ||||||
|  |               "when it is active!"); | ||||||
|  |         } | ||||||
|  |         C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Take ownership of other's resources | ||||||
|  |       device_id_ = std::exchange(other.device_id_, -1); | ||||||
|  |       green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||||
|  |       context_ = std::exchange(other.context_, nullptr); | ||||||
|  |       parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||||
|  |     } | ||||||
|  |     return *this; | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   GreenContext::~GreenContext() noexcept{ | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Get the underlying CUDA context | ||||||
|  |   CUcontext GreenContext::getContext() const { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     return context_; | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Get the underlying green context | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |   CUgreenCtx GreenContext::getGreenContext() const { | ||||||
|  |     return green_ctx_; | ||||||
|  |   } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |   // Make this context current | ||||||
|  |   void GreenContext::setContext() { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     auto current_stream = c10::cuda::getCurrentCUDAStream(); | ||||||
|  |     parent_stream_ = current_stream.stream(); | ||||||
|  |  | ||||||
|  |     at::cuda::CUDAEvent ev; | ||||||
|  |     ev.record(current_stream); | ||||||
|  |  | ||||||
|  |     CUcontext current = nullptr; | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||||
|  |     if (!current) { | ||||||
|  |       C10_CUDA_DRIVER_CHECK( | ||||||
|  |           c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_)); | ||||||
|  |     } else { | ||||||
|  |       C10_CUDA_DRIVER_CHECK( | ||||||
|  |           c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_)); | ||||||
|  |     } | ||||||
|  |     // currently hardcodes the new green context to use the default stream | ||||||
|  |     // TODO(eqy): consider creating a new stream if e.g., it allows interop | ||||||
|  |     // with CUDA Graph captures etc. | ||||||
|  |     auto default_stream = c10::cuda::getDefaultCUDAStream(); | ||||||
|  |     ev.block(default_stream); | ||||||
|  |     c10::cuda::setCurrentCUDAStream(default_stream); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void GreenContext::popContext() { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     // see above note about stream being hardcoded to the default stream | ||||||
|  |     at::cuda::CUDAEvent ev; | ||||||
|  |     ev.record(c10::cuda::getCurrentCUDAStream()); | ||||||
|  |     CUcontext popped; | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped)); | ||||||
|  |     TORCH_INTERNAL_ASSERT( | ||||||
|  |         popped == context_, "expected popped context to be the current ctx"); | ||||||
|  |     ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_)); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  | } // namespace at::cuda | ||||||
							
								
								
									
										53
									
								
								aten/src/ATen/cuda/CUDAGreenContext.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								aten/src/ATen/cuda/CUDAGreenContext.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | |||||||
|  | #pragma once | ||||||
|  | #include <ATen/cuda/CUDAEvent.h> | ||||||
|  |  | ||||||
|  | #if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) | ||||||
|  | #include <c10/cuda/driver_api.h> | ||||||
|  | #include <cuda.h> | ||||||
|  | #include <memory> | ||||||
|  | #include <stdexcept> | ||||||
|  | #include <vector> | ||||||
|  | #define CUDA_HAS_GREEN_CONTEXT 1 | ||||||
|  | #else | ||||||
|  | #define CUDA_HAS_GREEN_CONTEXT 0 | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | namespace at::cuda { | ||||||
|  |  | ||||||
|  | class TORCH_CUDA_CPP_API GreenContext { | ||||||
|  |  public: | ||||||
|  |   GreenContext(uint32_t device_id, uint32_t num_sms); | ||||||
|  |  | ||||||
|  |   static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id); | ||||||
|  |  | ||||||
|  |   // Delete copy constructor and assignment | ||||||
|  |   GreenContext(const GreenContext&) = delete; | ||||||
|  |   GreenContext& operator=(const GreenContext&) = delete; | ||||||
|  |  | ||||||
|  |   // Implement move operations | ||||||
|  |   GreenContext(GreenContext&& other) noexcept; | ||||||
|  |   GreenContext& operator=(GreenContext&& other) noexcept; | ||||||
|  |   ~GreenContext() noexcept; | ||||||
|  |  | ||||||
|  |   // Get the underlying CUDA context | ||||||
|  |   CUcontext getContext() const; | ||||||
|  |  | ||||||
|  |   // Get the underlying green context | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |   CUgreenCtx getGreenContext() const; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |   // Make this context current | ||||||
|  |   void setContext(); | ||||||
|  |  | ||||||
|  |   void popContext(); | ||||||
|  |  | ||||||
|  |  private: | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |   int32_t device_id_ = -1; | ||||||
|  |   CUgreenCtx green_ctx_ = nullptr; | ||||||
|  |   CUcontext context_ = nullptr; | ||||||
|  |   cudaStream_t parent_stream_ = nullptr; | ||||||
|  | #endif | ||||||
|  | }; | ||||||
|  | } // namespace at::cuda | ||||||
							
								
								
									
										270
									
								
								aten/src/ATen/cuda/CUDAScaledBlas.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										270
									
								
								aten/src/ATen/cuda/CUDAScaledBlas.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,270 @@ | |||||||
|  | #include <cstdint> | ||||||
|  | #include <c10/util/typeid.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #include <c10/util/SmallVector.h> | ||||||
|  | #include <c10/core/Scalar.h> | ||||||
|  | #include <c10/core/ScalarType.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||||
|  | #include <ATen/core/Tensor.h> | ||||||
|  | #include <ATen/core/NamedTensor.h> | ||||||
|  | #include <ATen/Dispatch.h> | ||||||
|  | #include <ATen/ExpandUtils.h> | ||||||
|  | #include <ATen/OpMathType.h> | ||||||
|  | #include <ATen/TensorUtils.h> | ||||||
|  | #include <ATen/cuda/CUDABlas.h> | ||||||
|  | #include <ATen/cuda/tunable/Tunable.h> | ||||||
|  | #include <ATen/cuda/tunable/TunableGemm.h> | ||||||
|  | #include <ATen/native/Resize.h> | ||||||
|  | #include <c10/util/MaybeOwned.h> | ||||||
|  | #include <ATen/native/GroupedMMUtils.h> | ||||||
|  | #include <ATen/native/cuda/RowwiseScaledMM.h> | ||||||
|  | #include <ATen/native/cuda/ScaledGroupMM.h> | ||||||
|  | #include <ATen/native/cuda/GroupMM.h> | ||||||
|  | #include <ATen/ceil_div.h> | ||||||
|  |  | ||||||
|  | #ifdef USE_FBGEMM_GENAI | ||||||
|  | #include <fbgemm_gpu/torch_ops.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
|  | #include <ATen/Functions.h> | ||||||
|  | #include <ATen/NativeFunctions.h> | ||||||
|  | #else | ||||||
|  | #include <ATen/ops/_addmm_activation_native.h> | ||||||
|  | #include <ATen/ops/_efficientzerotensor.h> | ||||||
|  | #include <ATen/ops/_scaled_mm_native.h> | ||||||
|  | #include <ATen/ops/_unsafe_view_native.h> | ||||||
|  | #include <ATen/ops/abs.h> | ||||||
|  | #include <ATen/ops/addmm_native.h> | ||||||
|  | #include <ATen/ops/addmv_native.h> | ||||||
|  | #include <ATen/ops/baddbmm_native.h> | ||||||
|  | #include <ATen/ops/bmm_native.h> | ||||||
|  | #include <ATen/ops/copy_native.h> | ||||||
|  | #include <ATen/ops/dot_native.h> | ||||||
|  | #include <ATen/ops/empty.h> | ||||||
|  | #include <ATen/ops/empty_strided.h> | ||||||
|  | #include <ATen/ops/gelu.h> | ||||||
|  | #include <ATen/ops/max.h> | ||||||
|  | #include <ATen/ops/mm_native.h> | ||||||
|  | #include <ATen/ops/mul.h> | ||||||
|  | #include <ATen/ops/relu.h> | ||||||
|  | #include <ATen/ops/ones.h> | ||||||
|  | #include <ATen/ops/scalar_tensor_native.h> | ||||||
|  | #include <ATen/ops/vdot_native.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | using at::blas::ScalingType; | ||||||
|  | using at::blas::SwizzleType; | ||||||
|  |  | ||||||
|  | namespace at::cuda::scaled { | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp8, | ||||||
|  |  * Each needs a single scale, {Tensorwise (float)} | ||||||
|  |  */ | ||||||
|  | bool check_tensorwise_recipe(c10::ScalarType type_a, | ||||||
|  |                              std::vector<ScalingType>& recipe_a, | ||||||
|  |                              ArrayRef<Tensor>& scales_a, | ||||||
|  |                              c10::ScalarType type_b, | ||||||
|  |                              std::vector<ScalingType>& recipe_b, | ||||||
|  |                              ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp8 | ||||||
|  |   if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scale each, {Tensorwise, float} | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |   // Need {Blockwise_1x32, e8m0} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::TensorWise) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::TensorWise) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp8, | ||||||
|  |  * Each needs scales, {Rowwise (float)} | ||||||
|  |  */ | ||||||
|  | bool check_rowwise_recipe(c10::ScalarType type_a, | ||||||
|  |                              std::vector<ScalingType>& recipe_a, | ||||||
|  |                              ArrayRef<Tensor>& scales_a, | ||||||
|  |                              c10::ScalarType type_b, | ||||||
|  |                              std::vector<ScalingType>& recipe_b, | ||||||
|  |                              ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp8 | ||||||
|  |   if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scale each, {Tensorwise, float} | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {RowWise, dp32} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::RowWise) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::RowWise) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Two-level scaling, canonical NVFP4 | ||||||
|  |  * Both inputs must be fp4 | ||||||
|  |  * A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)} | ||||||
|  |  */ | ||||||
|  | bool check_nvfp4_recipe(c10::ScalarType type_a, | ||||||
|  |                         std::vector<ScalingType>& recipe_a, | ||||||
|  |                         ArrayRef<Tensor>& scales_a, | ||||||
|  |                         c10::ScalarType type_b, | ||||||
|  |                         std::vector<ScalingType>& recipe_b, | ||||||
|  |                         ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp4 | ||||||
|  |   if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 2 scales, 2 recipes for each input | ||||||
|  |   if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]} | ||||||
|  |   if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Single-level scaling, what PyT currently understands | ||||||
|  |  * Both inputs must be fp4 | ||||||
|  |  * A, B need 1 scale, {Blockwise_1x16 (e4m3)} | ||||||
|  |  */ | ||||||
|  | bool check_nvfp4_recipe_single_scale | ||||||
|  |                        (c10::ScalarType type_a, | ||||||
|  |                         std::vector<ScalingType>& recipe_a, | ||||||
|  |                         ArrayRef<Tensor>& scales_a, | ||||||
|  |                         c10::ScalarType type_b, | ||||||
|  |                         std::vector<ScalingType>& recipe_b, | ||||||
|  |                         ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp4 | ||||||
|  |   if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 2 scales, 2 recipes for each input | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]} | ||||||
|  |   if (recipe_a[0] != ScalingType::BlockWise1x16) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::BlockWise1x16) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp8 | ||||||
|  |  * A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float) | ||||||
|  |  */ | ||||||
|  | bool check_deepseek_recipe(ScalingType expected_recipe_a, | ||||||
|  |                            ScalingType expected_recipe_b, | ||||||
|  |                            c10::ScalarType type_a, | ||||||
|  |                            std::vector<ScalingType>& recipe_a, | ||||||
|  |                            ArrayRef<Tensor>& scales_a, | ||||||
|  |                            c10::ScalarType type_b, | ||||||
|  |                            std::vector<ScalingType>& recipe_b, | ||||||
|  |                            ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp8 | ||||||
|  |   if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scales, 1 recipes for each input | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B | ||||||
|  |   if (recipe_a[0] != expected_recipe_a) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |   if (recipe_b[0] != expected_recipe_b) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp8 | ||||||
|  |  * A, B must have 1 scale each, {Blockwise_1x32, e8m0} | ||||||
|  |  */ | ||||||
|  | bool check_mxfp8_recipe(c10::ScalarType type_a, | ||||||
|  |                         std::vector<ScalingType>& recipe_a, | ||||||
|  |                         ArrayRef<Tensor>& scales_a, | ||||||
|  |                         c10::ScalarType type_b, | ||||||
|  |                         std::vector<ScalingType>& recipe_b, | ||||||
|  |                         ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp8 | ||||||
|  |   if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scales, 1 recipes for each input | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x32, e8m0} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp4 | ||||||
|  |  * A, B must have 1 scale each, {Blockwise_1x32, e8m0} | ||||||
|  |  */ | ||||||
|  | bool check_mxfp4_recipe(c10::ScalarType type_a, | ||||||
|  |                         std::vector<ScalingType>& recipe_a, | ||||||
|  |                         ArrayRef<Tensor>& scales_a, | ||||||
|  |                         c10::ScalarType type_b, | ||||||
|  |                         std::vector<ScalingType>& recipe_b, | ||||||
|  |                         ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp4 | ||||||
|  |   if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scales, 1 recipes for each input | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x32, e8m0} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace at::native::cuda::blas::scaled | ||||||
							
								
								
									
										174
									
								
								aten/src/ATen/cuda/CUDAScaledBlas.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										174
									
								
								aten/src/ATen/cuda/CUDAScaledBlas.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,174 @@ | |||||||
|  | #include <cstdint> | ||||||
|  | #include <c10/util/typeid.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #include <c10/util/SmallVector.h> | ||||||
|  | #include <c10/core/Scalar.h> | ||||||
|  | #include <c10/core/ScalarType.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||||
|  | #include <ATen/core/Tensor.h> | ||||||
|  | #include <ATen/core/NamedTensor.h> | ||||||
|  | #include <ATen/Dispatch.h> | ||||||
|  | #include <ATen/ExpandUtils.h> | ||||||
|  | #include <ATen/OpMathType.h> | ||||||
|  | #include <ATen/TensorUtils.h> | ||||||
|  | #include <ATen/cuda/CUDABlas.h> | ||||||
|  | #include <ATen/cuda/tunable/Tunable.h> | ||||||
|  | #include <ATen/cuda/tunable/TunableGemm.h> | ||||||
|  | #include <ATen/native/Resize.h> | ||||||
|  | #include <c10/util/MaybeOwned.h> | ||||||
|  | #include <ATen/native/GroupedMMUtils.h> | ||||||
|  | #include <ATen/native/cuda/RowwiseScaledMM.h> | ||||||
|  | #include <ATen/native/cuda/ScaledGroupMM.h> | ||||||
|  | #include <ATen/native/cuda/GroupMM.h> | ||||||
|  | #include <ATen/ceil_div.h> | ||||||
|  |  | ||||||
|  | #ifdef USE_FBGEMM_GENAI | ||||||
|  | #include <fbgemm_gpu/torch_ops.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
|  | #include <ATen/Functions.h> | ||||||
|  | #include <ATen/NativeFunctions.h> | ||||||
|  | #else | ||||||
|  | #include <ATen/ops/_addmm_activation_native.h> | ||||||
|  | #include <ATen/ops/_efficientzerotensor.h> | ||||||
|  | #include <ATen/ops/_scaled_mm_native.h> | ||||||
|  | #include <ATen/ops/_unsafe_view_native.h> | ||||||
|  | #include <ATen/ops/abs.h> | ||||||
|  | #include <ATen/ops/addmm_native.h> | ||||||
|  | #include <ATen/ops/addmv_native.h> | ||||||
|  | #include <ATen/ops/baddbmm_native.h> | ||||||
|  | #include <ATen/ops/bmm_native.h> | ||||||
|  | #include <ATen/ops/copy_native.h> | ||||||
|  | #include <ATen/ops/dot_native.h> | ||||||
|  | #include <ATen/ops/empty.h> | ||||||
|  | #include <ATen/ops/empty_strided.h> | ||||||
|  | #include <ATen/ops/gelu.h> | ||||||
|  | #include <ATen/ops/max.h> | ||||||
|  | #include <ATen/ops/mm_native.h> | ||||||
|  | #include <ATen/ops/mul.h> | ||||||
|  | #include <ATen/ops/relu.h> | ||||||
|  | #include <ATen/ops/ones.h> | ||||||
|  | #include <ATen/ops/scalar_tensor_native.h> | ||||||
|  | #include <ATen/ops/vdot_native.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | using at::blas::ScalingType; | ||||||
|  | using at::blas::SwizzleType; | ||||||
|  |  | ||||||
|  | namespace at::cuda::scaled { | ||||||
|  |  | ||||||
|  | static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) { | ||||||
|  | #ifdef USE_ROCM | ||||||
|  |     static const std::vector<std::string> archs = { | ||||||
|  |         "gfx942", | ||||||
|  | #if ROCM_VERSION >= 60300 | ||||||
|  |         "gfx1200", "gfx1201", | ||||||
|  | #endif | ||||||
|  | #if ROCM_VERSION >= 60500 | ||||||
|  |         "gfx950" | ||||||
|  | #endif | ||||||
|  |     }; | ||||||
|  |     return at::detail::getCUDAHooks().isGPUArch(archs); | ||||||
|  | #else | ||||||
|  |     auto dprops = at::cuda::getCurrentDeviceProperties(); | ||||||
|  |  | ||||||
|  |     if (sm90_only || sm100_only) { | ||||||
|  |       return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10); | ||||||
|  |     } else { | ||||||
|  |       return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #ifdef USE_ROCM | ||||||
|  | static bool _scaled_mm_is_fnuz() { | ||||||
|  |     return at::detail::getCUDAHooks().isGPUArch({"gfx942"}); | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  | /** | ||||||
|  |  * Track concrete implementations available | ||||||
|  |  */ | ||||||
|  | enum class ScaledGemmImplementation { | ||||||
|  |   NONE = 0, | ||||||
|  |   TENSORWISE_TENSORWISE = 1, | ||||||
|  |   ROWWISE_ROWWISE = 2, | ||||||
|  |   BLOCK_128x128_1x128 = 3, | ||||||
|  |   BLOCK_1x128_128x128 = 4, | ||||||
|  |   BLOCK_1x128_1x128 = 5, | ||||||
|  |   MXFP8_MXFP8 = 6, | ||||||
|  |   NVFP4_NVFP4 = 7, | ||||||
|  |   NVFP4_NVFP4_SINGLE_SCALE = 8, | ||||||
|  |   MXFP4_MXFP4 = 9, | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Convert passed int (enum) from python back into a | ||||||
|  |  * strictly-typed enum | ||||||
|  |  */ | ||||||
|  | template <class EnumType, class ArrayType> | ||||||
|  | std::vector<EnumType> convert_int_to_enum(ArrayType& v) { | ||||||
|  |   std::vector<EnumType> converted; | ||||||
|  |   converted.reserve(v.size()); | ||||||
|  |  | ||||||
|  |   for (auto vi : v) { | ||||||
|  |     converted.push_back(static_cast<EnumType>(vi)); | ||||||
|  |   } | ||||||
|  |   return converted; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | bool check_tensorwise_recipe(c10::ScalarType, | ||||||
|  |                              std::vector<ScalingType>&, | ||||||
|  |                              ArrayRef<Tensor>&, | ||||||
|  |                              c10::ScalarType, | ||||||
|  |                              std::vector<ScalingType>&, | ||||||
|  |                              ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  |  | ||||||
|  | bool check_rowwise_recipe(c10::ScalarType, | ||||||
|  |                              std::vector<ScalingType>&, | ||||||
|  |                              ArrayRef<Tensor>&, | ||||||
|  |                              c10::ScalarType, | ||||||
|  |                              std::vector<ScalingType>&, | ||||||
|  |                              ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_nvfp4_recipe(c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&, | ||||||
|  |                         c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_nvfp4_recipe_single_scale | ||||||
|  |                        (c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&, | ||||||
|  |                         c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_deepseek_recipe(ScalingType, | ||||||
|  |                            ScalingType, | ||||||
|  |                            c10::ScalarType, | ||||||
|  |                            std::vector<ScalingType>&, | ||||||
|  |                            ArrayRef<Tensor>&, | ||||||
|  |                            c10::ScalarType, | ||||||
|  |                            std::vector<ScalingType>&, | ||||||
|  |                            ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_mxfp8_recipe(c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&, | ||||||
|  |                         c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_mxfp4_recipe(c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&, | ||||||
|  |                         c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | } // namespace at::native::cuda::blas::scaled | ||||||
| @ -137,7 +137,7 @@ struct CUDACachingHostAllocatorImpl | |||||||
|   void free_block_slowpath(Block* block) { |   void free_block_slowpath(Block* block) { | ||||||
|     auto start = std::chrono::steady_clock::now(); |     auto start = std::chrono::steady_clock::now(); | ||||||
|     // Users may change the allocator config at will. torch unit tests do this. |     // Users may change the allocator config at will. torch unit tests do this. | ||||||
|     // However, allocations using cudaHostRegister should use corresonding |     // However, allocations using cudaHostRegister should use corresponding | ||||||
|     // cudaHostUnregister and similarly for cudaHostAlloc / cudaFreeHost. |     // cudaHostUnregister and similarly for cudaHostAlloc / cudaFreeHost. | ||||||
|     void* ptr = block->ptr_; |     void* ptr = block->ptr_; | ||||||
|     bool use_register = false; |     bool use_register = false; | ||||||
|  | |||||||
| @ -70,11 +70,7 @@ | |||||||
| #define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max() | #define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max() | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM) | #if defined(USE_ROCM) | ||||||
|  |  | ||||||
| #if !defined(USE_ROCM) |  | ||||||
| namespace at_cuda_detail { |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16 | // backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16 | ||||||
|  |  | ||||||
| @ -96,10 +92,6 @@ template <> | |||||||
| struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>: | struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>: | ||||||
|        ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {}; |        ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {}; | ||||||
|  |  | ||||||
| #if !defined(USE_ROCM) |  | ||||||
| } // namespace at_cuda_detail |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #if !defined(USE_ROCM) | #if !defined(USE_ROCM) | ||||||
| @ -121,7 +113,7 @@ struct cuda_type<c10::Half> { | |||||||
|   using type = __half; |   using type = __half; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| #if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16() | #if !defined(USE_ROCM) | ||||||
|  |  | ||||||
| template<> | template<> | ||||||
| struct cuda_type<c10::BFloat16> { | struct cuda_type<c10::BFloat16> { | ||||||
| @ -203,36 +195,6 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera | |||||||
|   *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b)); |   *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b)); | ||||||
| } | } | ||||||
|  |  | ||||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() |  | ||||||
| template<typename ValueT, typename InputIteratorT> |  | ||||||
| struct chained_iterator { |  | ||||||
|   using iterator_category = std::random_access_iterator_tag; |  | ||||||
|   using difference_type   = std::ptrdiff_t; |  | ||||||
|   using value_type        = ValueT; |  | ||||||
|   using pointer           = ValueT*; |  | ||||||
|   using reference         = ValueT&; |  | ||||||
|  |  | ||||||
|   InputIteratorT iter; |  | ||||||
|   ValueT *first; |  | ||||||
|   difference_type offset = 0; |  | ||||||
|  |  | ||||||
|   __device__ ValueT operator[](difference_type i) { |  | ||||||
|     i +=  offset; |  | ||||||
|     if (i == 0) { |  | ||||||
|       return *first; |  | ||||||
|     } else { |  | ||||||
|       return ValueT(iter[i - 1]); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|   __device__ chained_iterator operator+(difference_type i) { |  | ||||||
|     return chained_iterator{iter, first, i}; |  | ||||||
|   } |  | ||||||
|   __device__ ValueT operator*() { |  | ||||||
|     return (*this)[0]; |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, | // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, | ||||||
| // so split at int_max/2 | // so split at int_max/2 | ||||||
| constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30 | constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30 | ||||||
| @ -277,25 +239,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | |||||||
|         first_elem_ptr, |         first_elem_ptr, | ||||||
|         scan_op); |         scan_op); | ||||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); |     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() |  | ||||||
|     using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>; |  | ||||||
|     using tuple = typename ArgIndexInputIterator::value_type; |  | ||||||
|     auto input_iter_transform = [=] __device__ (const tuple &x)->input_t  { |  | ||||||
|       if (x.key == 0) { |  | ||||||
|         return *first_elem_ptr; |  | ||||||
|       } else { |  | ||||||
|         return x.value; |  | ||||||
|       } |  | ||||||
|     }; |  | ||||||
|     auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)( |  | ||||||
|       ArgIndexInputIterator(input + i), input_iter_transform); |  | ||||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, |  | ||||||
|         input_, |  | ||||||
|         output + i, |  | ||||||
|         scan_op, |  | ||||||
|         size_cub, |  | ||||||
|         at::cuda::getCurrentCUDAStream()); |  | ||||||
| #else |  | ||||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, |     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, | ||||||
|         input + i + 1, |         input + i + 1, | ||||||
|         output + i, |         output + i, | ||||||
| @ -303,7 +246,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | |||||||
|         ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr), |         ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr), | ||||||
|         size_cub, |         size_cub, | ||||||
|         at::cuda::getCurrentCUDAStream()); |         at::cuda::getCurrentCUDAStream()); | ||||||
| #endif |  | ||||||
|   } |   } | ||||||
| #endif | #endif | ||||||
| } | } | ||||||
| @ -555,16 +497,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | |||||||
|         first_elem_ptr, |         first_elem_ptr, | ||||||
|         scan_op); |         scan_op); | ||||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); |     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() |  | ||||||
|     auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{ |  | ||||||
|       input + i, first_elem_ptr}; |  | ||||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, |  | ||||||
|         input_, |  | ||||||
|         output + i, |  | ||||||
|         scan_op, |  | ||||||
|         size_cub, |  | ||||||
|         at::cuda::getCurrentCUDAStream()); |  | ||||||
| #else |  | ||||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, |     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, | ||||||
|         input + i, |         input + i, | ||||||
|         output + i, |         output + i, | ||||||
| @ -572,7 +504,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | |||||||
|         ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr), |         ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr), | ||||||
|         size_cub, |         size_cub, | ||||||
|         at::cuda::getCurrentCUDAStream()); |         at::cuda::getCurrentCUDAStream()); | ||||||
| #endif |  | ||||||
|   } |   } | ||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ | |||||||
| #include <ATen/cuda/CUDAConfig.h> | #include <ATen/cuda/CUDAConfig.h> | ||||||
|  |  | ||||||
| // NOTE: These templates are intentionally not defined in this header, | // NOTE: These templates are intentionally not defined in this header, | ||||||
| // which aviods re-compiling them for each translation unit. If you get | // which avoids re-compiling them for each translation unit. If you get | ||||||
| // a link error, you need to add an explicit instantiation for your | // a link error, you need to add an explicit instantiation for your | ||||||
| // types in cub.cu | // types in cub.cu | ||||||
|  |  | ||||||
|  | |||||||
| @ -10,14 +10,6 @@ | |||||||
| #define CUB_VERSION 200001 | #define CUB_VERSION 200001 | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| // cub sort support for __nv_bfloat16 is added to cub 1.13 in: |  | ||||||
| // https://github.com/NVIDIA/cub/pull/306 |  | ||||||
| #if CUB_VERSION >= 101300 |  | ||||||
| #define CUB_SUPPORTS_NV_BFLOAT16() true |  | ||||||
| #else |  | ||||||
| #define CUB_SUPPORTS_NV_BFLOAT16() false |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: | // cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: | ||||||
| // https://github.com/NVIDIA/cub/pull/326 | // https://github.com/NVIDIA/cub/pull/326 | ||||||
| // CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake | // CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake | ||||||
| @ -28,14 +20,6 @@ | |||||||
| #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false | #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| // cub support for cub::FutureValue is added to cub 1.15 in: |  | ||||||
| // https://github.com/NVIDIA/cub/pull/305 |  | ||||||
| #if CUB_VERSION >= 101500 |  | ||||||
| #define CUB_SUPPORTS_FUTURE_VALUE() true |  | ||||||
| #else |  | ||||||
| #define CUB_SUPPORTS_FUTURE_VALUE() false |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // There were many bc-breaking changes in major version release of CCCL v3.0.0 | // There were many bc-breaking changes in major version release of CCCL v3.0.0 | ||||||
| // Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html | // Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html | ||||||
| #if CUB_VERSION >= 200800 | #if CUB_VERSION >= 200800 | ||||||
|  | |||||||
| @ -38,7 +38,7 @@ GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262 | |||||||
| GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033 | GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| Note the "Validator" lines. If you change a library verison, or ROCm version, or PyTorch version, TunableOp will detect | Note the "Validator" lines. If you change a library version, or ROCm version, or PyTorch version, TunableOp will detect | ||||||
| this and reject the tunings file because the prior tunings are likely affected by other software changes. | this and reject the tunings file because the prior tunings are likely affected by other software changes. | ||||||
|  |  | ||||||
| The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of | The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of | ||||||
|  | |||||||
| @ -235,7 +235,7 @@ class TunableOp { | |||||||
|       // numeric check option is controlled by non-static env var, so check it once per tuned operator |       // numeric check option is controlled by non-static env var, so check it once per tuned operator | ||||||
|       bool do_numerics_check = ctx->IsNumericsCheckEnabled(); |       bool do_numerics_check = ctx->IsNumericsCheckEnabled(); | ||||||
|  |  | ||||||
|       // calcaulte a reference answer for numerical check |       // calculate a reference answer for numerical check | ||||||
|       if (do_numerics_check) { |       if (do_numerics_check) { | ||||||
|         reference_params = params->DeepCopy(false); |         reference_params = params->DeepCopy(false); | ||||||
|         TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); |         TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ namespace at { | |||||||
|  |  | ||||||
| // AcceleratorHooksInterface is a shared interface provided by all | // AcceleratorHooksInterface is a shared interface provided by all | ||||||
| // accelerators to allow generic code. | // accelerators to allow generic code. | ||||||
| // This inferface is hook-based as it corresponds to all the functions | // This interface is hook-based as it corresponds to all the functions | ||||||
| // that are going to be called in a generic way from the CPU code. | // that are going to be called in a generic way from the CPU code. | ||||||
|  |  | ||||||
| struct TORCH_API AcceleratorHooksInterface { | struct TORCH_API AcceleratorHooksInterface { | ||||||
|  | |||||||
| @ -38,7 +38,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { | |||||||
|  |  | ||||||
|   Generator getNewGenerator( |   Generator getNewGenerator( | ||||||
|       [[maybe_unused]] DeviceIndex device_index = -1) const override { |       [[maybe_unused]] DeviceIndex device_index = -1) const override { | ||||||
|     // TODO(FFFrog): Perserved for BC and will be removed in the future. |     // TODO(FFFrog): Preserved for BC and will be removed in the future. | ||||||
|     if (at::GetGeneratorPrivate().has_value()) |     if (at::GetGeneratorPrivate().has_value()) | ||||||
|       return at::GetGeneratorForPrivateuse1(device_index); |       return at::GetGeneratorForPrivateuse1(device_index); | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										23
									
								
								aten/src/ATen/detail/XLAHooksInterface.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								aten/src/ATen/detail/XLAHooksInterface.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,23 @@ | |||||||
|  | #include <ATen/detail/XLAHooksInterface.h> | ||||||
|  |  | ||||||
|  | namespace at { | ||||||
|  | namespace detail { | ||||||
|  |  | ||||||
|  | const XLAHooksInterface& getXLAHooks() { | ||||||
|  |   auto create_impl = [] { | ||||||
|  |     // Create XLA hooks using the registry | ||||||
|  |     auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{}); | ||||||
|  |     if (hooks) { | ||||||
|  |       return hooks; | ||||||
|  |     } | ||||||
|  |     // If hooks creation fails, fall back to default implementation | ||||||
|  |     return std::make_unique<XLAHooksInterface>(); | ||||||
|  |   }; | ||||||
|  |   static auto hooks = create_impl(); | ||||||
|  |   return *hooks; | ||||||
|  | } | ||||||
|  | } // namespace detail | ||||||
|  |  | ||||||
|  | C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs) | ||||||
|  |  | ||||||
|  | } // namespace at | ||||||
							
								
								
									
										79
									
								
								aten/src/ATen/detail/XLAHooksInterface.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								aten/src/ATen/detail/XLAHooksInterface.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,79 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <c10/core/Device.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #include <c10/util/Registry.h> | ||||||
|  |  | ||||||
|  | #include <ATen/detail/AcceleratorHooksInterface.h> | ||||||
|  |  | ||||||
|  | C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") | ||||||
|  |  | ||||||
|  | namespace at { | ||||||
|  |  | ||||||
|  | constexpr const char* XLA_HELP = | ||||||
|  |   "This error has occurred because you are trying " | ||||||
|  |   "to use some XLA functionality, but the XLA library has not been " | ||||||
|  |   "loaded by the dynamic linker. You must load xla libraries by `import torch_xla`"; | ||||||
|  |  | ||||||
|  | struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface { | ||||||
|  |   ~XLAHooksInterface() override = default; | ||||||
|  |  | ||||||
|  |   void init() const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   virtual bool hasXLA() const { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   virtual std::string showConfig() const { | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         false, | ||||||
|  |         "Cannot query detailed XLA version without torch_xla library. ", | ||||||
|  |         XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   const Generator& getDefaultGenerator( | ||||||
|  |       [[maybe_unused]] DeviceIndex device_index = -1) const override { | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Generator getNewGenerator( | ||||||
|  |       [[maybe_unused]] DeviceIndex device_index = -1) const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   virtual DeviceIndex getCurrentDevice() const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Device getDeviceFromPtr(void* /*data*/) const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Allocator* getPinnedMemoryAllocator() const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   bool isPinnedPtr(const void* data) const override { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   bool hasPrimaryContext(DeviceIndex device_index) const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | struct TORCH_API XLAHooksArgs {}; | ||||||
|  |  | ||||||
|  | TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs); | ||||||
|  | #define REGISTER_XLA_HOOKS(clsname) \ | ||||||
|  |   C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname) | ||||||
|  |  | ||||||
|  | namespace detail { | ||||||
|  | TORCH_API const XLAHooksInterface& getXLAHooks(); | ||||||
|  | } // namespace detail | ||||||
|  | } // namespace at | ||||||
|  | C10_DIAGNOSTIC_POP() | ||||||
| @ -283,7 +283,7 @@ inline void boxed_existing_bdim_all_batch_rule( | |||||||
| // Use when all tensors arguments accept one (normal) batch dim. | // Use when all tensors arguments accept one (normal) batch dim. | ||||||
| // This batching rule expands the batch dim on all Tensors, reshapes it into | // This batching rule expands the batch dim on all Tensors, reshapes it into | ||||||
| // dim 0, calls the op, and then reshapes the batch dim out of dim 0. | // dim 0, calls the op, and then reshapes the batch dim out of dim 0. | ||||||
| // This is not the most efficient thing; if there are alternatives, plese try | // This is not the most efficient thing; if there are alternatives, please try | ||||||
| // to use them. Use this only as a last resort. | // to use them. Use this only as a last resort. | ||||||
| #define EXISTING_BDIM_ALL_BOXED(op) \ | #define EXISTING_BDIM_ALL_BOXED(op) \ | ||||||
|   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>()); |   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>()); | ||||||
|  | |||||||
| @ -384,7 +384,7 @@ fourOutputs solve_ex_batch_rule( | |||||||
|  |  | ||||||
|   // NOTE [ solve_ex Batch Rule Contiguity ] |   // NOTE [ solve_ex Batch Rule Contiguity ] | ||||||
|   // A determines whether or not linalg_solve takes an optimized path. We need the check on A_ to match the one run on |   // A determines whether or not linalg_solve takes an optimized path. We need the check on A_ to match the one run on | ||||||
|   // A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behvaior |   // A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behavior | ||||||
|   // differs based on whether or not the optimized path was taken |   // differs based on whether or not the optimized path was taken | ||||||
|   const auto batched_A_was_contiguous = A_bdim.has_value() ? at::select(A, *A_bdim, 0).is_contiguous() : A.is_contiguous(); |   const auto batched_A_was_contiguous = A_bdim.has_value() ? at::select(A, *A_bdim, 0).is_contiguous() : A.is_contiguous(); | ||||||
|   if (batched_A_was_contiguous && !A.is_complex()) { |   if (batched_A_was_contiguous && !A.is_complex()) { | ||||||
|  | |||||||
| @ -282,7 +282,7 @@ static std::tuple<Tensor, std::optional<int64_t>> _softmax_backward_batch_rule( | |||||||
|  |  | ||||||
|   dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim); |   dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim); | ||||||
|  |  | ||||||
|   // Not sure why output_ needs to be marked as .contiguous(). Someting must |   // Not sure why output_ needs to be marked as .contiguous(). Something must | ||||||
|   // have changed in PyTorch (and output of softmax is probably always contiguous) |   // have changed in PyTorch (and output of softmax is probably always contiguous) | ||||||
|   return std::make_tuple(at::_softmax_backward_data(grad_output_, output_.contiguous(), dim, input_dtype), 0); |   return std::make_tuple(at::_softmax_backward_data(grad_output_, output_.contiguous(), dim, input_dtype), 0); | ||||||
| } | } | ||||||
|  | |||||||
| @ -224,7 +224,7 @@ static Tensor safeStack(TensorList tensors) { | |||||||
|   // is possible for the backward function to return an undefined grad for some |   // is possible for the backward function to return an undefined grad for some | ||||||
|   // grad_input for each example. In that case, we return an undefined grad. |   // grad_input for each example. In that case, we return an undefined grad. | ||||||
|   // |   // | ||||||
|   // It is theoretically posssible for *some* of the examples to produce an |   // It is theoretically possible for *some* of the examples to produce an | ||||||
|   // undefined grad (a kernel could peek at the gradient values and return an |   // undefined grad (a kernel could peek at the gradient values and return an | ||||||
|   // undefined tensor if it determines the gradient is full of zeros). We |   // undefined tensor if it determines the gradient is full of zeros). We | ||||||
|   // could handle this by treating the undefined grad as a zero-filled tensor |   // could handle this by treating the undefined grad as a zero-filled tensor | ||||||
|  | |||||||
| @ -113,7 +113,7 @@ SymIntArrayRef BatchedTensorImpl::sym_sizes_custom() const { | |||||||
|   return sym_sizes_default(); |   return sym_sizes_default(); | ||||||
| } | } | ||||||
|  |  | ||||||
| // The following are publically exposed as methods of Tensor | // The following are publicly exposed as methods of Tensor | ||||||
|  |  | ||||||
| IntArrayRef BatchedTensorImpl::strides_custom() const { | IntArrayRef BatchedTensorImpl::strides_custom() const { | ||||||
|   return strides_default(); |   return strides_default(); | ||||||
|  | |||||||
| @ -37,7 +37,7 @@ namespace at::functorch  { | |||||||
| // how to perform the transform. | // how to perform the transform. | ||||||
| // | // | ||||||
| // TODO: we can excise DynamicLayer in favor of Interpreter, | // TODO: we can excise DynamicLayer in favor of Interpreter, | ||||||
| // But I am going to leave it for now as a compatiblity shim to avoid | // But I am going to leave it for now as a compatibility shim to avoid | ||||||
| // needing to refactor a lot of callsites... | // needing to refactor a lot of callsites... | ||||||
| struct TORCH_API DynamicLayer { | struct TORCH_API DynamicLayer { | ||||||
|   explicit DynamicLayer( |   explicit DynamicLayer( | ||||||
|  | |||||||
| @ -88,7 +88,7 @@ std::ostream& operator<<(std::ostream& os, const TransformType& t); | |||||||
| // >>> VmapInterpreterPtr(&interpreter).batchSize() | // >>> VmapInterpreterPtr(&interpreter).batchSize() | ||||||
| // | // | ||||||
| // Finally, Interpreter::process switches on the type of the interpreter | // Finally, Interpreter::process switches on the type of the interpreter | ||||||
| // and calls one of {Transform}Intepreter::processImpl under the hood. | // and calls one of {Transform}Interpreter::processImpl under the hood. | ||||||
| // Same for Interpreter::sendToNextInterpreter :) | // Same for Interpreter::sendToNextInterpreter :) | ||||||
|  |  | ||||||
| struct VmapInterpreterMeta { | struct VmapInterpreterMeta { | ||||||
|  | |||||||
| @ -733,7 +733,7 @@ TORCH_LIBRARY_IMPL(_, FuncTorchBatched, m) { | |||||||
| } | } | ||||||
|  |  | ||||||
| TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { | TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { | ||||||
|   // still legacy b/c teturns multiple tensors |   // still legacy b/c returns multiple tensors | ||||||
|   m.impl("split.Tensor", split_batching_rule); |   m.impl("split.Tensor", split_batching_rule); | ||||||
|   m.impl("split_with_sizes", split_with_sizes_batching_rule); |   m.impl("split_with_sizes", split_with_sizes_batching_rule); | ||||||
|   m.impl("split_with_sizes_copy", split_with_sizes_copy_batching_rule); |   m.impl("split_with_sizes_copy", split_with_sizes_copy_batching_rule); | ||||||
|  | |||||||
| @ -158,7 +158,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t | |||||||
|       endKernelCoalescing(); |       endKernelCoalescing(); | ||||||
|       id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder]; |       id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder]; | ||||||
|  |  | ||||||
|       // For some reason fillBufferfor stopped working for lengh > 4Gb on MacOS 26 |       // For some reason fillBufferfor stopped working for length > 4Gb on MacOS 26 | ||||||
|       // See https://github.com/pytorch/pytorch/issues/163962 |       // See https://github.com/pytorch/pytorch/issues/163962 | ||||||
|       // Workaround by batching copy commands into 4Gb chunks |       // Workaround by batching copy commands into 4Gb chunks | ||||||
|       constexpr size_t max_copy_size = 0x100000000; // 4GB |       constexpr size_t max_copy_size = 0x100000000; // 4GB | ||||||
|  | |||||||
| @ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) | |||||||
|     try { |     try { | ||||||
|       mkldnn_matmul_i8i8i32(self, mat2, result); |       mkldnn_matmul_i8i8i32(self, mat2, result); | ||||||
|       dispatched = true; |       dispatched = true; | ||||||
|     } catch (const std::exception& e) { |     } catch ([[maybe_unused]] const std::exception& e) { | ||||||
|       TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what()); |       TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what()); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -148,7 +148,7 @@ inline void checkInputsSolver(const Tensor& A, | |||||||
|  |  | ||||||
| inline bool is_row_or_column_contiguous(const Tensor& t) { | inline bool is_row_or_column_contiguous(const Tensor& t) { | ||||||
|   // This could be made more general, similar to how it's checked in matmul, which would allow to |   // This could be made more general, similar to how it's checked in matmul, which would allow to | ||||||
|   // ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky. |   // elide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky. | ||||||
|   // We choose to be conservative for simplicity |   // We choose to be conservative for simplicity | ||||||
|   return t.is_contiguous() || t.transpose(-2, -1).is_contiguous(); |   return t.is_contiguous() || t.transpose(-2, -1).is_contiguous(); | ||||||
| } | } | ||||||
|  | |||||||
| @ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto | |||||||
|               "pixel_shuffle expects a positive upscale_factor, but got ", |               "pixel_shuffle expects a positive upscale_factor, but got ", | ||||||
|               upscale_factor); |               upscale_factor); | ||||||
|   int64_t c = self.size(-3); |   int64_t c = self.size(-3); | ||||||
|  |   TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor, | ||||||
|  |         "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor); | ||||||
|   int64_t upscale_factor_squared = upscale_factor * upscale_factor; |   int64_t upscale_factor_squared = upscale_factor * upscale_factor; | ||||||
|   TORCH_CHECK(c % upscale_factor_squared == 0, |   TORCH_CHECK(c % upscale_factor_squared == 0, | ||||||
|               "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " |               "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " | ||||||
|  | |||||||
| @ -21,7 +21,7 @@ enum class fft_norm_mode { | |||||||
| // NOTE [ Fourier Transform Conjugate Symmetry ] | // NOTE [ Fourier Transform Conjugate Symmetry ] | ||||||
| // | // | ||||||
| // Real-to-complex Fourier transform satisfies the conjugate symmetry. That is, | // Real-to-complex Fourier transform satisfies the conjugate symmetry. That is, | ||||||
| // assuming X is the transformed K-dimensionsal signal, we have | // assuming X is the transformed K-dimensional signal, we have | ||||||
| // | // | ||||||
| //     X[i_1, ..., i_K] = X[j_i, ..., j_K]*, | //     X[i_1, ..., i_K] = X[j_i, ..., j_K]*, | ||||||
| // | // | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	