Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 20:34:54 +08:00)

Compare commits: cpp-docs-d ... gh/slayton (176 commits)
| SHA1 |
|---|
| ae37a754c4 |
| e9965f02da |
| e214af6ae8 |
| 7ce723d21c |
| 4295a9a158 |
| 90d7be35e9 |
| 8d4e48831e |
| 90b30ebf7e |
| 173bcda436 |
| 6530bc70fb |
| 4c38887346 |
| 81fa4a204c |
| 4e6afa8c07 |
| 79aa88cc5d |
| fa4cb91846 |
| c58d0ad85d |
| 000f49551b |
| 9940e894ea |
| 27302a4932 |
| 507614ba43 |
| 86f9f1d0ab |
| 154e4d36e9 |
| a2b6afeac5 |
| 262830d86c |
| e4c01011c2 |
| a60d9e1f6d |
| f863550192 |
| 84b14f3a10 |
| 5121499f6b |
| 8f80892359 |
| cdb60e44eb |
| 25909d2629 |
| c7eee49525 |
| 621ba05107 |
| 39a70cead1 |
| d97f6550a2 |
| 516e58965a |
| b55b779ad3 |
| 74e53d0761 |
| 798a6d2be1 |
| b0e9c86971 |
| 661a56002f |
| c9bc00f016 |
| ec51b139e1 |
| eb83c3ca23 |
| 7924e3aacf |
| 78bcfcf870 |
| 1e2e7cb18b |
| 003601a70d |
| 1d58d5fe25 |
| de7fdfe41a |
| b31bad1b8f |
| 2efcf3ca98 |
| 761f946043 |
| 8aa465f18e |
| 0a5d68d92d |
| 42bd210fff |
| 1d13c314b3 |
| 0c9763a5a0 |
| 79a4a9c02e |
| 9d0b77f4cd |
| d486eee234 |
| cddd5f74ab |
| dfdb68e51f |
| 98c818320a |
| cc20b7ad72 |
| bc11a42b3f |
| 4fc06f2e0a |
| 82473c3d59 |
| b6a4236e5d |
| b04173be9b |
| 32ac38f85d |
| c9b49e506e |
| 6038e476e8 |
| 2c851c16e5 |
| 31584f2d91 |
| 0442125362 |
| fdcf402d82 |
| 13cda9b89e |
| fa6d911dda |
| 0db6bcc015 |
| 60ac039998 |
| 380d440d1c |
| 9038a30cee |
| 690c8c13b9 |
| 28ee6b62ed |
| 81577bdb3f |
| e67e3d95f3 |
| 27af8480ea |
| 6494cdc40c |
| ac7074efa2 |
| 263901cec4 |
| c12293dcbe |
| 5a4997dcae |
| 47f638eae7 |
| 882b834082 |
| b146ea411e |
| 8625ffbd45 |
| 0977cc4474 |
| d9a55faccc |
| 75b8295868 |
| defb6a80d8 |
| f8fccb1e48 |
| 5aac4cfce4 |
| baf91bbbfc |
| cbcb4f7768 |
| 2b93d5b450 |
| 6b7cd48e7e |
| bf5aa9e42e |
| b1eb6dede5 |
| 673060beae |
| 2e8e9a59a8 |
| fb277a5916 |
| 73fa0d0c63 |
| 36c21cc84e |
| 0b68814b44 |
| e64a814ae7 |
| 0b58d87aec |
| 757975ad50 |
| 291712026b |
| 3e77a2b478 |
| 82ef1b5db3 |
| 5f370f5c42 |
| 05b2e02cb4 |
| 12f742941d |
| 35180fafee |
| c746feb86a |
| c5f26db5bf |
| 18e99b6d45 |
| ab9e466928 |
| af4ba78543 |
| 282f39a4bc |
| a479769488 |
| 26c7375477 |
| d01f15152c |
| 4fae6968b1 |
| f9953e0f61 |
| 34ed7a8f0d |
| 2fde10d914 |
| 0a93295da0 |
| 4b898b51b9 |
| 550e3e6efb |
| 715449ca76 |
| 84d8d06fc3 |
| 60992d98b2 |
| 59e015e3a1 |
| 8904a5a7c9 |
| f5df9ca03a |
| 2998abd777 |
| e13580e41c |
| f3b8e15f20 |
| 5211f4c108 |
| ad9027b80d |
| a1005427bf |
| 35153d0846 |
| 7773a22cdb |
| 7cb467a169 |
| 12aac12b8d |
| 2b748d0a56 |
| 16745a882a |
| 8daef35cf1 |
| 51319ca090 |
| d311a3d1dc |
| 04adfe5ba9 |
| 4be1e3bf92 |
| e7592f4005 |
| d334c3649d |
| 9f82535c5a |
| 5b35fc8777 |
| 2f38eece7c |
| 830e789a55 |
| ad4dc52bf6 |
| dac9ed9790 |
| 1c7fe8f861 |
| 4e643422f6 |
| 3c3b278872 |
```diff
@@ -19,7 +19,7 @@ pip_install \
   transformers==4.36.2

 pip_install coloredlogs packaging
-pip_install onnxruntime==1.23.0
+pip_install onnxruntime==1.23.1
 pip_install onnxscript==0.5.4

 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
```
```diff
@@ -334,12 +334,12 @@ sympy==1.13.3
 #Pinned versions:
 #test that import:

-onnx==1.18.0
+onnx==1.19.1
 #Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
 #Pinned versions:
 #test that import:

-onnxscript==0.5.3
+onnxscript==0.5.4
 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
 #Pinned versions:
 #test that import:
```
```diff
@@ -1,11 +1,15 @@
-sphinx==7.2.6
+sphinx==5.3.0
 #Description: This is used to generate PyTorch docs
-#Pinned versions: 7.2.6
+#Pinned versions: 5.3.0

-pytorch_sphinx_theme2==0.1.0
-#Description: This is needed to generate PyTorch docs
-#Pinned versions: 0.1.0
+standard-imghdr==3.13.0; python_version >= "3.13"
+#Description: This is needed by Sphinx, so it needs to be added here.
+# The reasons are as follows:
+# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
+# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
+# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.

+-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
 # but it doesn't seem to work and hangs around idly. The initial thought that it is probably
 # something related to Docker setup. We can investigate this later.
@@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
 #Description: This is used to generate PyTorch docs
 #Pinned versions: 2.13.0

-breathe==4.36.0
+breathe==4.34.0
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 4.36.0
+#Pinned versions: 4.34.0

-exhale==0.3.7
+exhale==0.2.3
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 0.3.7
+#Pinned versions: 0.2.3

-docutils==0.20
+docutils==0.16
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 0.20
+#Pinned versions: 0.16

 bs4==0.0.1
 #Description: This is used to generate PyTorch C++ docs
@@ -52,13 +56,13 @@ IPython==8.12.0
 #Description: This is used to generate PyTorch functorch docs
 #Pinned versions: 8.12.0

-myst-nb==1.3.0
+myst-nb==0.17.2
 #Description: This is used to generate PyTorch functorch and torch.compile docs.
-#Pinned versions: 1.3.0
+#Pinned versions: 0.17.2

 # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
 python-etcd==0.4.5
 sphinx-copybutton==0.5.0
-sphinx-design==0.6.1
+sphinx-design==0.4.0
 sphinxcontrib-mermaid==1.0.0
-myst-parser==4.0.1
+myst-parser==0.18.1
```
```diff
@@ -6,7 +6,7 @@ dependencies = [
     "GitPython==3.1.45",
     "docker==7.1.0",
     "pytest==7.3.2",
-    "uv==0.8.6"
+    "uv==0.9.5"
 ]

 [tool.setuptools]
```
```diff
@@ -102,18 +102,8 @@ if [ "$is_main_doc" = true ]; then
     echo coverage output not found
     exit 1
   elif [ $undocumented -gt 0 ]; then
-    echo "======================================"
-    echo "ERROR: $undocumented undocumented objects found!"
-    echo "======================================"
-    echo ""
-    echo "Full coverage report:"
+    echo undocumented objects found:
     cat build/coverage/python.txt
-    echo ""
-    echo "======================================"
-    echo "Undocumented modules/objects (lines after TOTAL):"
-    tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
-    echo "======================================"
-    echo ""
     echo "Make sure you've updated relevant .rsts in docs/source!"
     echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
     exit 1
```
							
								
								
									
New file: `.claude/skills/pytorch-docstring.md` (354 lines)
# PyTorch Docstring Writing Guide

This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.

## General Principles

- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
- Follow **Sphinx/reStructuredText** (reST) format for documentation
- Be **concise but complete** - include all essential information
- Always include **examples** when possible
- Use **cross-references** to related functions/classes

## Docstring Structure

### 1. Function Signature (First Line)

Start with the function signature showing all parameters:

```python
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
```

**Notes:**
- Include the function name
- Show positional and keyword-only arguments (use `*` separator)
- Include default values
- Show return type annotation
- This line should NOT end with a period
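For instance, a concrete first line for an elementwise op might look like the following (a sketch based on the public `torch.add` signature; verify against the source before reusing):

```python
r"""add(input, other, *, alpha=1, out=None) -> Tensor
```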
### 2. Brief Description

Provide a one-line description of what the function does:

```python
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 2D convolution over an input image composed of several input
planes.
```

### 3. Mathematical Formulas (if applicable)

Use Sphinx math directives for mathematical expressions:

```python
.. math::
    \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
```

Or inline math: `:math:\`x^2\``

### 4. Cross-References

Link to related classes and functions using Sphinx roles:

- `:class:\`~torch.nn.ModuleName\`` - Link to a class
- `:func:\`torch.function_name\`` - Link to a function
- `:meth:\`~Tensor.method_name\`` - Link to a method
- `:attr:\`attribute_name\`` - Reference an attribute
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)

**Example:**
```python
See :class:`~torch.nn.Conv2d` for details and output shape.
```

### 5. Notes and Warnings

Use admonitions for important information:

```python
.. note::
    This function doesn't work directly with NLLLoss,
    which expects the Log to be computed between the Softmax and itself.
    Use log_softmax instead (it's faster and has better numerical properties).

.. warning::
    :func:`new_tensor` always copies :attr:`data`. If you have a Tensor
    ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
    or :func:`torch.Tensor.detach`.
```

### 6. Args Section

Document all parameters with type annotations and descriptions:

```python
Args:
    input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)`
    bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
      tuple `(sH, sW)`. Default: 1
```

**Formatting rules:**
- Parameter name in **lowercase**
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
- Description follows the type
- For optional parameters, include "Default: ``value``" at the end
- Use double backticks for inline code: ``` ``None`` ```
- Indent continuation lines by 2 spaces

### 7. Keyword Args Section (if applicable)

Sometimes keyword arguments are documented separately:

```python
Keyword args:
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
        Default: if None, same :class:`torch.dtype` as this tensor.
    device (:class:`torch.device`, optional): the desired device of returned tensor.
        Default: if None, same :class:`torch.device` as this tensor.
    requires_grad (bool, optional): If autograd should record operations on the
        returned tensor. Default: ``False``.
```

### 8. Returns Section (if needed)

Document the return value:

```python
Returns:
    Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
        If ``hard=True``, the returned samples will be one-hot, otherwise they will
        be probability distributions that sum to 1 across `dim`.
```

Or simply include it in the function signature line if obvious from context.

### 9. Examples Section

Always include examples when possible:

```python
Examples::

    >>> inputs = torch.randn(33, 16, 30)
    >>> filters = torch.randn(20, 16, 5)
    >>> F.conv1d(inputs, filters)

    >>> # With square kernels and equal stride
    >>> filters = torch.randn(8, 4, 3, 3)
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> F.conv2d(inputs, filters, padding=1)
```

**Formatting rules:**
- Use `Examples::` with double colon
- Use `>>>` prompt for Python code
- Include comments with `#` when helpful
- Show actual output when it helps understanding (indent without `>>>`)
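As a small sketch of the output rule (expected output sits on the line below the call, with no `>>>`; the `tensor([1., 2.])` repr assumes default dtype and print options):

```python
Examples::

    >>> torch.abs(torch.tensor([1.0, -2.0]))
    tensor([1., 2.])
```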

### 10. External References

Link to papers or external documentation:

```python
.. _Link Name:
    https://arxiv.org/abs/1611.00712
```

Reference them in text: ```See `Link Name`_```

## Method Types

### Native Python Functions

For regular Python functions, use a standard docstring:

```python
def relu(input: Tensor, inplace: bool = False) -> Tensor:
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise. See
    :class:`~torch.nn.ReLU` for more details.
    """
    # implementation
```

### C-Bound Functions (using add_docstr)

For C-bound functions, use `_add_docstr`:

```python
conv1d = _add_docstr(
    torch.conv1d,
    r"""
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 1D convolution over an input signal composed of several input
planes.

See :class:`~torch.nn.Conv1d` for details and output shape.

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , kW)`
    ...
""",
)
```

### In-Place Variants

For in-place operations (ending with `_`), reference the original:

```python
add_docstr_all(
    "abs_",
    r"""
abs_() -> Tensor

In-place version of :meth:`~Tensor.abs`
""",
)
```

### Alias Functions

For aliases, simply reference the original:

```python
add_docstr_all(
    "absolute",
    r"""
absolute() -> Tensor

Alias for :func:`abs`
""",
)
```

## Common Patterns

### Shape Documentation

Use LaTeX math notation for tensor shapes:

```python
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
```

### Reusable Argument Definitions

For commonly used arguments, define them once and reuse:

```python
common_args = parse_kwargs(
    """
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
        Default: if None, same as this tensor.
"""
)

# Then use with .format():
r"""
...

Keyword args:
    {dtype}
    {device}
""".format(**common_args)
```

### Template Insertion

Insert reproducibility notes or other common text:

```python
r"""
{tf32_note}

{cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)
```

## Complete Example

Here's a complete example showing all elements:

```python
def gumbel_softmax(
    logits: Tensor,
    tau: float = 1,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
) -> Tensor:
    r"""
    Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits (Tensor): `[..., num_features]` unnormalized log probabilities
        tau (float): non-negative scalar temperature
        hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
              but will be differentiated as if it is the soft sample in autograd. Default: ``False``
        dim (int): A dimension along which softmax will be computed. Default: -1

    Returns:
        Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
            If ``hard=True``, the returned samples will be one-hot, otherwise they will
            be probability distributions that sum to 1 across `dim`.

    .. note::
        This function is here for legacy reasons, may be removed from nn.Functional in the future.

    Examples::
        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)

    .. _Link 1:
        https://arxiv.org/abs/1611.00712
    """
    # implementation
```

## Quick Checklist

When writing a PyTorch docstring, ensure:

- [ ] Use raw string (`r"""`)
- [ ] Include function signature on first line
- [ ] Provide brief description
- [ ] Document all parameters in Args section with types
- [ ] Include default values for optional parameters
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
- [ ] Add mathematical formulas if applicable
- [ ] Include at least one example in Examples section
- [ ] Add warnings/notes for important caveats
- [ ] Link to related module class with `:class:`
- [ ] Use proper math notation for tensor shapes
- [ ] Follow consistent formatting and indentation

## Common Sphinx Roles Reference

- `:class:\`~torch.nn.Module\`` - Class reference
- `:func:\`torch.function\`` - Function reference
- `:meth:\`~Tensor.method\`` - Method reference
- `:attr:\`attribute\`` - Attribute reference
- `:math:\`equation\`` - Inline math
- `:ref:\`label\`` - Internal reference
- ``` ``code`` ``` - Inline code (use double backticks)
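Putting several roles together, an illustrative fragment (paraphrased for this guide, not copied from the real `torch.clamp` docs) might read:

```python
r"""clamp(input, min=None, max=None) -> Tensor

Clamps all elements in :attr:`input` into the range :math:`[\text{min}, \text{max}]`.
See :meth:`~Tensor.clamp_` for the in-place variant.
"""
```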

## Additional Notes

- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
- **Line length**: Try to keep lines under 100 characters when possible
- **Periods**: End sentences with periods, but not the signature line
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.
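To spot-check that a docstring written this way actually runs, the standard-library `doctest` module can execute its `Examples::` block. A minimal, self-contained sketch (the `scale` function is hypothetical; PyTorch CI uses its own test harness):

```python
import doctest


def scale(x: float, factor: float = 2.0) -> float:
    r"""scale(x, factor=2.0) -> float

    Multiplies ``x`` by ``factor``.

    Examples::

        >>> scale(3.0)
        6.0
    """
    return x * factor


# Extract and run the >>> examples embedded in the docstring;
# verbose=True prints each example as it is checked.
doctest.run_docstring_examples(scale, {"scale": scale}, verbose=True)
```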
							
								
								
									
.github/actions/setup-rocm/action.yml (vendored, 7 lines changed)

```diff
@@ -124,3 +124,10 @@ runs:
       id: login-ecr
       continue-on-error: true
       uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
+
+    - name: Preserve github env variables for use in docker
+      shell: bash
+      run: |
+        env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
+        env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
+        env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
```
							
								
								
									
.github/ci_commit_pins/vision.txt (vendored, 2 lines changed)

```diff
@@ -1 +1 @@
-faffd5cf673615583da6517275e361cb3dbc77e6
+1752fe6809b74921644866275ab80244b96e80bc
```
							
								
								
									
.github/ci_commit_pins/xla.txt (vendored, 2 lines changed)

```diff
@@ -1 +1 @@
-0fa6e3129e61143224663e1ec67980d12b7ec4eb
+df6798dfb931ce7c7fe5bed2447cd1092a5981af
```
							
								
								
									
.github/ci_configs/vllm/Dockerfile (vendored, 5 lines changed)

```diff
@@ -283,6 +283,9 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
         uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
     fi

+RUN --mount=type=cache,target=/root/.cache/uv \
+    uv pip install --system --pre apache-tvm-ffi==0.1.0b15
+
 # Install the vllm wheel from previous stage
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv pip install --system /wheels/vllm/*.whl --verbose
@@ -295,6 +298,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
 ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}

+# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
+# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
 ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
 ARG FLASHINFER_GIT_REF="v0.2.14.post1"
```
							
								
								
									
.github/label_to_label.yml (vendored, 9 lines changed)

```diff
@@ -15,6 +15,11 @@
   - "module: reinplacing"
   then:
   - "module: pt2-dispatcher"
+- any:
+  - "vllm-compile"
+  then:
+  - "module: vllm"
+  - "oncall: pt2"
 - any:
   - "module: vmap"
   then:
@@ -27,10 +32,6 @@
   - "module: pt2 optimizer"
   then:
   - "module: dynamo"
-- any:
-  - "module: flex attention"
-  then:
-  - "module: higher order operators"
 - any:
   - "module: aotinductor"
   then:
```
							
								
								
									
.github/workflows/inductor-periodic.yml (vendored, 1 line changed)

```diff
@@ -88,7 +88,6 @@ jobs:
     with:
       build-environment: linux-jammy-rocm-py3_10
       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
-      sync-tag: rocm-build
       test-matrix: |
         { include: [
           { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
```
							
								
								
									
.github/workflows/periodic.yml (vendored, 15 lines changed)

```diff
@@ -147,15 +147,16 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
+      cuda-arch-list: 8.9
       test-matrix: |
         { include: [
-          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
-          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
-          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
-          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
-          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
-          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
-          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
+          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
+          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
+          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
+          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
+          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
+          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
+          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
         ]}
     secrets: inherit
```
							
								
								
									
.github/workflows/pull.yml (vendored, 3 lines changed)

```diff
@@ -347,7 +347,8 @@ jobs:
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
-      sync-tag: linux-xpu-n-build
+      # This should sync with the build in xpu.yml but xpu uses a larger runner
+      # sync-tag: linux-xpu-n-build
       runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
       build-environment: linux-jammy-xpu-n-py3.10
       docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
```
							
								
								
									
.github/workflows/rocm-mi300.yml (vendored, 1 line changed)

```diff
@@ -45,7 +45,6 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build-environment: linux-noble-rocm-py3.12-mi300
       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
-      sync-tag: rocm-build
      test-matrix: |
         { include: [
           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
```
							
								
								
									
.github/workflows/rocm-mi355.yml (vendored, 1 line changed)

```diff
@@ -42,7 +42,6 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build-environment: linux-noble-rocm-py3.12-mi355
       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
-      sync-tag: rocm-build
       test-matrix: |
         { include: [
           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
```
							
								
								
									
.github/workflows/rocm-navi31.yml (vendored, 12 lines changed)

```diff
@@ -26,11 +26,23 @@ jobs:
       id-token: write
       contents: read

+  get-label-type:
+    name: get-label-type
+    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
+    with:
+      triggering_actor: ${{ github.triggering_actor }}
+      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
+      curr_branch: ${{ github.head_ref || github.ref_name }}
+      curr_ref_type: ${{ github.ref_type }}
+
   linux-jammy-rocm-py3_10-build:
     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
     name: linux-jammy-rocm-py3.10
     uses: ./.github/workflows/_linux-build.yml
+    needs: get-label-type
     with:
+      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build-environment: linux-jammy-rocm-py3.10
       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
       sync-tag: rocm-build
```
							
								
								
									
.github/workflows/rocm.yml (vendored, 12 lines changed)

```diff
@@ -26,11 +26,23 @@ jobs:
       id-token: write
       contents: read

+  get-label-type:
+    name: get-label-type
+    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
+    with:
+      triggering_actor: ${{ github.triggering_actor }}
+      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
+      curr_branch: ${{ github.head_ref || github.ref_name }}
+      curr_ref_type: ${{ github.ref_type }}
+
   linux-jammy-rocm-py3_10-build:
     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
     name: linux-jammy-rocm-py3.10
     uses: ./.github/workflows/_linux-build.yml
+    needs: get-label-type
     with:
+      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build-environment: linux-jammy-rocm-py3.10
       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
       sync-tag: rocm-build
```
							
								
								
									
.github/workflows/trunk-tagging.yml (vendored, 147 lines changed)

```diff
@@ -58,8 +58,10 @@ jobs:
           else
             COMMIT_SHA="${{ github.sha }}"
           fi
-          echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
-          echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
+          {
+            echo "sha=${COMMIT_SHA}"
+            echo "tag_name=trunk/${COMMIT_SHA}"
+          } >> "${GITHUB_OUTPUT}"

       - name: Validate commit SHA
         run: |
@@ -87,7 +89,7 @@
             echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
           fi

-      - name: Create and push tag with retry
+      - name: Create and push tag(s) with retry
         id: check_tag
         env:
           TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@@ -112,14 +114,23 @@
             return 1
           }

-          # Exit early if tag already exists
-          if check_tag_exists; then
-            echo "✅ Tag already exists - no action needed"
-            echo "exists=true" >> "${GITHUB_OUTPUT}"
-            exit 0
-          fi
-
-          echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
+          # Counters for summary reporting
+          created_count=0
+          skipped_count=0
+          failed_count=0

+          # Always write outputs once on exit
+          finish() {
+            set +e
+            if [ -n "${GITHUB_OUTPUT:-}" ]; then
+              {
+                echo "created_count=${created_count}"
+                echo "skipped_count=${skipped_count}"
+                echo "failed_count=${failed_count}"
+              } >> "${GITHUB_OUTPUT}"
+            fi
+          }
+          trap finish EXIT

           # Retry configuration
           MAX_RETRIES=5
@@ -194,31 +205,111 @@
             }
           }

-          # Execute with retry
-          if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
-            echo "exists=false" >> "${GITHUB_OUTPUT}"
+          # New behavior for push events: enumerate commits in the push and tag each one.
+          # For workflow_dispatch, retain existing single-SHA behavior.
+
+          # Always fetch tags once up front to improve idempotency in loops
+          git fetch origin --tags --quiet || true
+
+          if [ "${{ github.event_name }}" = "push" ]; then
+            BEFORE_SHA="${{ github.event.before }}"
+            AFTER_SHA="${{ github.sha }}"  # same as event.after
+
+            # List commits introduced by this push (old..new), oldest first for stable ordering
+            commits_file="$(mktemp)"
+            git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
+
+            if [ ! -s "${commits_file}" ]; then
+              echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
+              rm -f "${commits_file}"
+              exit 0
+            fi
+
+            commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
+            echo "Found ${commit_count} commit(s) to tag for push:"
+            while IFS= read -r sha; do
+              printf '  %s\n' "${sha}"
+            done < "${commits_file}"
+
+            while IFS= read -r sha; do
+              TAG_NAME="trunk/${sha}"
+              COMMIT_SHA="${sha}"
+
+              # If tag already exists locally or remotely, skip (idempotent)
+              if check_tag_exists; then
+                echo "✅ Tag ${TAG_NAME} already exists - skipping"
+                skipped_count=$((skipped_count + 1))
+                continue
+              fi
+
+              echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
+
+              if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
+                created_count=$((created_count + 1))
+              else
+                echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
+                failed_count=$((failed_count + 1))
+              fi
+            done < "${commits_file}"
+
+            rm -f "${commits_file}"
+
+            if [ "${failed_count}" -gt 0 ]; then
+              exit 1
+            fi
             exit 0
           else
-            echo "Tag creation failed after all retry attempts"
-            exit 1
+            # workflow_dispatch path (single SHA tagging preserved)
+
+            # Exit early if tag already exists
+            if check_tag_exists; then
+              echo "✅ Tag already exists - no action needed"
+              skipped_count=1
+              exit 0
+            fi
+
+            echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
+
+            if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
+              created_count=1
+              exit 0
+            else
+              echo "Tag creation failed after all retry attempts"
+              failed_count=1
+              exit 1
+            fi
           fi

       - name: Tag creation summary
         if: always()
         run: |
-          if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
-            echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
-          elif [ "${{ job.status }}" = "success" ]; then
-            echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
+          if [ "${{ github.event_name }}" = "push" ]; then
+            echo "Trigger: push on main"
+            echo "Created: ${{ steps.check_tag.outputs.created_count }}"
+            echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
+            echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
+            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
+              echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
+            else
+              echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
+            fi
           else
-            echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
-          fi
+            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
+              if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
+                echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
+              else
+                echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
+              fi
+            else
+              echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
+            fi

-          echo ""
-          echo "Tag details:"
-          echo "  Name: ${{ steps.commit.outputs.tag_name }}"
-          echo "  Commit: ${{ steps.commit.outputs.sha }}"
-          echo "  Trigger: ${{ github.event_name }}"
-          if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
-            echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
+            echo ""
+            echo "Tag details:"
+            echo "  Name: ${{ steps.commit.outputs.tag_name }}"
+            echo "  Commit: ${{ steps.commit.outputs.sha }}"
+            echo "  Trigger: ${{ github.event_name }}"
+            if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
+              echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
+            fi
           fi
```
```diff
@@ -833,8 +833,7 @@ exclude_patterns = [
 command = [
     'python3',
     'tools/linter/adapters/grep_linter.py',
-    '--pattern=cudaSetDevice(',
-    '--pattern=cudaGetDevice(',
+    '--pattern=(cudaSetDevice|cudaGetDevice)\\(',
     '--linter-name=RAWCUDADEVICE',
     '--error-name=raw CUDA API usage',
     """--error-description=\
```
```diff
@@ -1138,11 +1137,8 @@ command = [
 [[linter]]
 code = 'WORKFLOWSYNC'
 include_patterns = [
-    '.github/workflows/pull.yml',
-    '.github/workflows/trunk.yml',
-    '.github/workflows/periodic.yml',
-    '.github/workflows/mac-mps.yml',
-    '.github/workflows/slow.yml',
+    '.github/workflows/*.yml',
+    '.github/workflows/*.yaml',
 ]
 command = [
     'python3',
```
```diff
@@ -289,14 +289,15 @@ IF(USE_FBGEMM_GENAI)

     set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)

-    set(fbgemm_genai_mx8mx8bf16_grouped
+    set(fbgemm_genai_cuh
       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
+      "${FBGEMM_GENAI_SRCS}/"
     )

     target_include_directories(fbgemm_genai PRIVATE
       ${FBGEMM_THIRD_PARTY}/cutlass/include
       ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
-      ${fbgemm_genai_mx8mx8bf16_grouped}
+      ${fbgemm_genai_cuh}
       ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
       ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
     )
```
@@ -19,6 +19,7 @@
 #include <ATen/detail/MPSHooksInterface.h>
 #include <ATen/detail/MTIAHooksInterface.h>
 #include <ATen/detail/PrivateUse1HooksInterface.h>
+#include <ATen/detail/XLAHooksInterface.h>
 #include <ATen/detail/XPUHooksInterface.h>
 #include <c10/core/QEngine.h>
 #include <c10/core/impl/DeviceGuardImplInterface.h>
@@ -88,6 +89,8 @@ class TORCH_API Context {
       return at::detail::getHIPHooks();
     } else if (opt_device_type == at::kHPU) {
       return at::detail::getHPUHooks();
+    } else if (opt_device_type == at::kXLA) {
+      return at::detail::getXLAHooks();
     } else {
       TORCH_CHECK(
           false,
@@ -196,7 +199,7 @@ class TORCH_API Context {
     return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
   }
   static bool hasXLA() {
-    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
+    return detail::getXLAHooks().hasXLA();
   }
   static bool hasXPU() {
     return detail::getXPUHooks().hasXPU();

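The hasXLA() change above swaps a device-guard registry probe for the hooks interface, the mechanism ATen uses to let out-of-tree backends report availability. A minimal sketch of that pattern (simplified, hypothetical names; not the actual ATen interfaces):

#include <memory>

struct BackendHooks {
  virtual ~BackendHooks() = default;
  // Default stub: the backend is unavailable until real hooks are registered.
  virtual bool hasBackend() const { return false; }
};

BackendHooks& getBackendHooks() {
  static std::unique_ptr<BackendHooks> registered;  // set at backend load time
  static BackendHooks stub;
  return registered ? *registered : stub;
}

bool hasXLA() {
  return getBackendHooks().hasBackend();
}
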
@@ -59,9 +59,7 @@ struct TORCH_API Generator {
 
   explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
    : impl_(std::move(gen_impl)) {
-    if (impl_.get() == nullptr) {
-      throw std::runtime_error("GeneratorImpl with nullptr is not supported");
-    }
+    TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
   }
 
   bool operator==(const Generator& rhs) const {

@@ -111,9 +111,7 @@ class TORCH_API TensorBase {
   explicit TensorBase(
       c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
       : impl_(std::move(tensor_impl)) {
-    if (impl_.get() == nullptr) {
-      throw std::runtime_error("TensorImpl with nullptr is not supported");
-    }
+    TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
   }
   TensorBase(const TensorBase&) = default;
   TensorBase(TensorBase&&) noexcept = default;

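This hunk and the Generator hunk above follow the conversion idiom used throughout this change set: replace a hand-rolled null check plus std::runtime_error with TORCH_CHECK, which throws c10::Error carrying the message together with source location. A minimal sketch of the idiom (assumes a PyTorch/c10 build environment for the header):

#include <c10/util/Exception.h>

void require_nonnull(const void* impl) {
  // On failure TORCH_CHECK throws c10::Error instead of std::runtime_error,
  // so callers get file/line context and PyTorch's usual error formatting.
  TORCH_CHECK(impl != nullptr, "impl with nullptr is not supported");
}
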
@@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
   m.fallback(AUTOGRAD_FALLBACK);
 }
 
+TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
+  m.fallback(AUTOGRAD_FALLBACK);
+}
+
 #undef AUTOGRAD_FALLBACK
 
 } // namespace

@@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
 
   auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
   TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
+  // NB: Preserve BC for registering a fallback for AutogradPrivateUse1
+  // multiple times; refer to https://github.com/pytorch/pytorch/issues/163979
+  // for more information.
   TORCH_CHECK(
-    !backendFallbackKernels_[idx].kernel.isValid(),
-    "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
-    backendFallbackKernels_[idx].debug, ", new registration ", debug
-  );
+      dispatchKey == DispatchKey::AutogradPrivateUse1 ||
+          !backendFallbackKernels_[idx].kernel.isValid(),
+      "Tried to register multiple backend fallbacks for the same dispatch key ",
+      dispatchKey,
+      "; previous registration ",
+      backendFallbackKernels_[idx].debug,
+      ", new registration ",
+      debug);
   // NB: inferred function schema is always nullptr for fallbacks, as fallbacks
   // cannot be unboxed
   backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));

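The carve-out exists because out-of-tree PrivateUse1 backends often register their own AutogradPrivateUse1 fallback, which would otherwise collide with the default registration added in the previous hunk. A typical out-of-tree registration looks like this sketch (standard dispatcher API, shown for illustration):

#include <torch/library.h>

TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
  // Skip autograd handling and fall through to the backend kernel.
  m.fallback(torch::CppFunction::makeFallthrough());
}
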
@@ -68,11 +68,7 @@ Symbol InternedStrings::_symbol(const std::string& s) {
     return it->second;
 
   auto pos = s.find("::");
-  if (pos == std::string::npos) {
-    std::stringstream ss;
-    ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
-    throw std::runtime_error(ss.str());
-  }
+  TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
   Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
 
   Symbol sym(sym_to_info_.size());
@@ -121,12 +117,7 @@ std::string Symbol::domainString() const {
 }
 
 Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
-  if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
-    std::ostringstream ss;
-    ss << "Symbol: domain string is expected to be prefixed with '"
-       << domain_prefix() << "', e.g. 'org.pytorch.aten'";
-    throw std::runtime_error(ss.str());
-  }
+  TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
   std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
   return fromQualString(qualString);
 }

@@ -7,6 +7,7 @@
 #include <ATen/core/jit_type.h>
 #include <ATen/core/stack.h>
 #include <ATen/core/type_factory.h>
+#include <c10/util/Exception.h>
 #include <c10/util/StringUtil.h>
 #include <c10/util/hash.h>
 #include <c10/util/irange.h>
@@ -412,7 +413,7 @@ size_t IValue::hash(const IValue& v) {
     case Tag::Enum:
     case Tag::Stream:
     case Tag::Uninitialized:
-      throw std::runtime_error(
+      TORCH_CHECK(false,
           "unhashable type: '" + v.type()->repr_str() + "'");
   }
   // the above switch should be exhaustive

@@ -8,6 +8,7 @@
 #include <ATen/core/type_factory.h>
 #include <ATen/core/qualified_name.h>
 #include <c10/util/TypeList.h>
+#include <c10/util/Exception.h>
 #include <optional>
 #include <c10/core/SymFloat.h>
 #include <c10/core/SymBool.h>
@@ -116,10 +117,8 @@ struct SingleElementType : public SharedType {
 
  protected:
   SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
-    if (!this->elem) {
-      throw std::runtime_error(c10::str(
-            "Can not create ", typeKindToString(Kind), " with None type"));
-    }
+    TORCH_CHECK(this->elem, c10::str(
+            "Can not create ", typeKindToString(Kind), " with None type"));
   }
 
  private:
@@ -416,16 +415,12 @@ struct TORCH_API SymbolicShape {
   }
 
   ShapeSymbol operator[](size_t i) const {
-    if (!dims_) {
-      throw std::runtime_error("Rank isn't fixed");
-    }
+    TORCH_CHECK(dims_, "Rank isn't fixed");
     return (*dims_).at(i);
   }
 
   ShapeSymbol at(size_t i) const {
-    if (!dims_) {
-      throw std::runtime_error("Rank isn't fixed");
-    }
+    TORCH_CHECK(dims_, "Rank isn't fixed");
     return (*dims_).at(i);
   }
 
@@ -520,9 +515,7 @@ struct VaryingShape {
   }
 
   const std::optional<T> &operator[](size_t i) const {
-    if (!dims_) {
-      throw std::runtime_error("Rank isn't fixed");
-    }
+    TORCH_CHECK(dims_, "Rank isn't fixed");
     return (*dims_).at(i);
  }
 
@@ -957,9 +950,7 @@ struct TORCH_API DictType : public SharedType {
 
   TypePtr createWithContained(
       std::vector<TypePtr> contained_types) const override {
-    if (contained_types.size() != 2) {
-      throw std::runtime_error("Expected 2 contained types");
-    }
+    TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
     return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
   }
 

@@ -185,11 +185,11 @@ struct TORCH_API Type {
         : repr_(nullptr) {}
 
     /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
-        : repr_(p) {}
+        : repr_(makeSingletonSharedPtr(p.get())) {}
 
     template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
     /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
-        : repr_(SingletonTypePtr<T>(p.get())) {}
+        : repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {}
 
 
     // We need to support construction from T* for pybind. The problem
@@ -202,8 +202,8 @@ struct TORCH_API Type {
     // Case 2: if T is exactly Type, we need to do a dynamic_cast to
     // check if it's a SharedType and do the right thing.
     //
-    // Case 3: Otherwise, T is not a SharedType. (debug-check this
-    // assumption!) Use a singleton pointer.
+    // Case 3: Otherwise, T is not a SharedType. Use a singleton
+    // pointer.
 
     template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
     /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
@@ -211,15 +211,15 @@ struct TORCH_API Type {
     template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
     /* implicit */ SingletonOrSharedTypePtr(T* p) {
       if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
-        repr_ = Repr(shared_p->shared_from_this());
+        repr_ = shared_p->shared_from_this();
       } else {
-        repr_ = Repr(p);
+        repr_ = makeSingletonSharedPtr(p);
       }
     }
 
     template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
     /* implicit */ SingletonOrSharedTypePtr(T* p)
-        : repr_(p) {
+        : repr_(makeSingletonSharedPtr(p)) {
       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
     }
 
@@ -230,19 +230,19 @@ struct TORCH_API Type {
     ~SingletonOrSharedTypePtr() = default;
 
     T* get() const {
-      return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
+      return repr_.get();
     }
 
     operator bool() const {
-      return repr_.isNonNull();
+      return repr_ != nullptr;
     }
 
     bool operator==(std::nullptr_t) const {
-      return !repr_.isNonNull();
+      return repr_ == nullptr;
     }
 
     bool operator!=(std::nullptr_t) const {
-      return repr_.isNonNull();
+      return repr_ != nullptr;
     }
 
     template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
@@ -255,138 +255,14 @@ struct TORCH_API Type {
     }
 
   private:
-    // NOTE: SharedPtrWrapper exists to work around a baffling bug in
-    // nvcc; see comment in destroy() below.
-    struct SharedPtrWrapper {
-      SharedPtrWrapper(std::shared_ptr<T> &&x)
-          : repr_(std::move(x)) {}
-      std::shared_ptr<T> repr_;
-    };
-    union Repr {
-      Repr() : Repr(nullptr) {}
-
-      explicit Repr(std::shared_ptr<T> x)
-          : shared_(std::move(x)) {}
-
-      explicit Repr(std::nullptr_t)
-          : singletonRepr_(nullptr) {}
-
-      explicit Repr(SingletonTypePtr<T> p)
-          : singletonRepr_(p.get()) {}
-
-      ~Repr() {
-        destroy();
-      }
-
-      // NOTE: the only non-UB way to access our null state is through
-      // rawRepr(), because our copy operation doesn't preserve which
-      // union member is active for null pointers.
-      Repr(const Repr& rhs) {
-        if (rhs.isSharedAndNonNull()) {
-          new (&shared_) SharedPtrWrapper(rhs.shared_);
-        } else {
-          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
-          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
-          singletonRepr_.unused_ = nullptr;
-        }
-      }
-
-      Repr(Repr&& rhs) noexcept {
-        if (rhs.isSharedAndNonNull()) {
-          new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
-        } else {
-          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
-          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
-          singletonRepr_.unused_ = nullptr;
-        }
-      }
-
-      Repr& operator=(const Repr& rhs) {
-        if (&rhs == this) {
-          return *this;
-        }
-        if (rhs.isSharedAndNonNull()) {
-          if (isSharedAndNonNull()) {
-            shared_ = rhs.shared_;
-          } else {
-            new (&shared_) SharedPtrWrapper(rhs.shared_);
-          }
-        } else {
-          if (isSharedAndNonNull()) {
-            destroy();
-          }
-          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
-          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
-          singletonRepr_.unused_ = nullptr;
-        }
-        return *this;
-      }
-
-      Repr& operator=(Repr&& rhs) noexcept {
-        if (&rhs == this) {
-          return *this;
-        }
-        if (rhs.isSharedAndNonNull()) {
-          if (isSharedAndNonNull()) {
-            shared_ = std::move(rhs.shared_);
-          } else {
-            new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
-          }
-        } else {
-          if (isSharedAndNonNull()) {
-            destroy();
-          }
-          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
-          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
-          singletonRepr_.unused_ = nullptr;
-        }
-        return *this;
-      }
-
-      SharedPtrWrapper shared_;
-
-      struct SingletonRepr {
-        explicit SingletonRepr(T* s) : singleton_(s) {}
-        T* singleton_;
-        void* unused_ = nullptr;
-      } singletonRepr_;
-      struct RawRepr {
-        void* first;
-        void* nullIfSingleton_;
-      };
-
-      // It is UB to read the singleton part of Repr if it was
-      // constructed as a shared_ptr and vice versa, but memcpying out
-      // the representation is always OK, so here's an accessor to obey
-      // the letter of the law.
-      RawRepr rawRepr() const {
-        RawRepr repr{};
-        memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
-        return repr;
-      }
-
-      bool isNonNull() const {
-        auto repr = rawRepr();
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
-        return repr.first != nullptr;
-      }
-
-      bool isSharedAndNonNull() const {
-        return rawRepr().nullIfSingleton_ != nullptr;
-      }
-
-     private:
-      void destroy() {
-        if (isSharedAndNonNull()) {
-          // Without SharedPtrWrapper, this line would read
-          // `shared_.~shared_ptr()` and nvcc would complain with
-          // "error: expected primary-expression before '>' token"
-          // referring to the "t" in "shared_ptr". SharedPtrWrapper
-          // exists to work around this compiler bug.
-          shared_.~SharedPtrWrapper();
-        }
-      }
-    } repr_;
+    // Use shared_ptr's aliasing constructor to create a non-owning pointer
+    // to a singleton. The lifetime is tied to the null shared_ptr, so there's
+    // no reference counting overhead for the singleton itself.
+    static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) {
+      return std::shared_ptr<T>(std::shared_ptr<T>(), ptr);
+    }
+
+    std::shared_ptr<T> repr_;
   };
 
   using TypePtr = SingletonOrSharedTypePtr<Type>;

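The refactor above replaces the hand-rolled tagged union with std::shared_ptr's aliasing constructor. The trick in isolation (a standalone sketch): constructing from an empty shared_ptr plus a raw pointer yields a pointer that dereferences to the target but owns nothing, so singleton types are never deleted and never touch a reference count.

#include <cassert>
#include <memory>

struct Widget {
  int x = 42;
};

int main() {
  static Widget singleton;  // stands in for a static type singleton
  // Aliasing constructor: share the (empty) control block of the first
  // argument, but store a pointer to `singleton`.
  std::shared_ptr<Widget> p(std::shared_ptr<Widget>(), &singleton);
  assert(p.get() == &singleton);
  assert(p.use_count() == 0);  // no control block: non-owning
  return 0;
}
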
@@ -8,6 +8,7 @@
 #include <ATen/core/jit_type.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/env.h>
+#include <c10/util/Exception.h>
 #include <c10/util/flat_hash_map.h>
 #include <c10/util/irange.h>
 #include <array>
@@ -826,9 +827,7 @@ TupleType::TupleType(
     : NamedType(TypeKind::TupleType, std::move(name)),
       elements_(std::move(elements)),
       has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
-        if (!v) {
-          throw std::runtime_error("Can not create tuple with None type");
-        }
+        TORCH_CHECK(v, "Can not create tuple with None type");
         return v->hasFreeVariables();
       })), schema_(std::move(schema)) {
 

@@ -104,71 +104,6 @@ class Vectorized<float> {
     }
     return b;
   }
-  // Implementation is picked from
-  // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
-  inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
-    const auto c1 =
-        svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
-    const auto c2 =
-        svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
-    const auto c3 =
-        svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
-    const auto c4 =
-        svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
-    const auto c5 =
-        svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
-    const auto shift = svreinterpret_f32_u32(
-        svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
-    const auto inv_ln2 = svreinterpret_f32_u32(
-        svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
-    const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
-        0xbf317200)); // -ln(2) from bits  -1 to -19: -0x1.62e400p-1f
-    const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
-        0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
-    const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
-    const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
-    const auto zero = svdup_n_f32(0.f);
-    const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
-    // Range reduction:
-    //   e^x = 2^n * e^r
-    // where:
-    //   n = floor(x / ln(2))
-    //   r = x - n * ln(2)
-    //
-    // By adding x / ln(2) with 2^23 + 127 (shift):
-    //   * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
-    //   forces decimal part
-    //     of x / ln(2) out of the result. The integer part of x / ln(2) (i.e.
-    //     n) + 127 will occupy the whole fraction part of z in FP32 format.
-    //     Subtracting 2^23 + 127 (shift) from z will result in the integer part
-    //     of x / ln(2) (i.e. n) because the decimal part has been pushed out
-    //     and lost.
-    //   * The addition of 127 makes the FP32 fraction part of z ready to be
-    //   used as the exponent
-    //     in FP32 format. Left shifting z by 23 bits will result in 2^n.
-    const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
-    const auto n = svsub_f32_z(pg, z, shift);
-    const auto scale = svreinterpret_f32_u32(
-        svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
-    // The calculation of n * ln(2) is done using 2 steps to achieve accuracy
-    // beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in
-    // term of accuracy and performance.
-    const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
-    const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
-    // Compute the truncated Taylor series of e^r.
-    //   poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
-    const auto r2 = svmul_f32_z(pg, r, r);
-    const auto p1 = svmul_f32_z(pg, c1, r);
-    const auto p23 = svmla_f32_z(pg, c2, c3, r);
-    const auto p45 = svmla_f32_z(pg, c4, c5, r);
-    const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
-    const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
-    auto poly = svmla_f32_z(pg, scale, p12345, scale);
-    // Handle underflow and overflow.
-    poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
-    poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
-    return poly;
-  }
   static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
     if (count == size())
       return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
@@ -313,11 +248,41 @@ class Vectorized<float> {
     return USE_SLEEF(
         Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
   }
+  // Implementation copied from Arm Optimized Routines:
+  // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
   Vectorized<float> exp_u20() const {
-    return exp();
+    // Fall back to exp() for special-case inputs that are too large or too
+    // small, i.e. when at least one element x satisfies |x| >= 87.3...
+    svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f);
+    if (svptest_any(svptrue_b32(), is_special_case)) {
+      return exp();
+    }
+    const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
+    const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
+    const svfloat32_t c1 = svdup_n_f32(0.5f);
+    const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
+
+    const float shift = 0x1.803f8p17f;
+
+    /* n = round(x/(ln2/N)).  */
+    svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift);
+    svfloat32_t n = svsub_x(svptrue_b32(), z, shift);
+
+    /* r = x - n*ln2/N.  */
+    svfloat32_t r = values;
+    r = svmls_x(svptrue_b32(), r, n, ln2_hi);
+    r = svmls_x(svptrue_b32(), r, n, ln2_lo);
+
+    /* scale = 2^(n/N).  */
+    svfloat32_t scale = svexpa(svreinterpret_u32(z));
+
+    /* poly(r) = exp(r) - 1 ~= r + 0.5 r^2.  */
+    svfloat32_t r2 = svmul_x(svptrue_b32(), r, r);
+    svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
+    return svmla_x(svptrue_b32(), scale, scale, poly);
   }
   Vectorized<float> fexp_u20() const {
-    return exp();
+    return exp_u20();
   }
   Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
       { return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
@@ -453,9 +418,11 @@ class Vectorized<float> {
         ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
 
     // Step 2: Calculate exp(2 * x), where x is the clamped value.
-    // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
-    // the result.
-    svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
+    // svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of
+    // the result (via Vectorized<float>, then auto-converts back to
+    // svfloat32_t).
+    svfloat32_t exp2x =
+        Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20();
 
     // Step 3: Calculate the numerator of the tanh function, which is exp(2x)
     // - 1.

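The new exp_u20 follows the classic exp range reduction: write e^x = 2^n * e^r with r small, evaluate a short polynomial for e^r, and scale. A scalar model for reference (illustrative only; the SVE version obtains the 2^n scale from the FEXPA instruction via svexpa and folds the rounding into the 0x1.803f8p17f shift constant):

#include <cmath>

float exp_sketch(float x) {
  const float inv_ln2 = 0x1.715476p+0f;    // 1/ln(2)
  const float ln2_hi = 0x1.62e4p-1f;       // ln(2), high part
  const float ln2_lo = 0x1.7f7d1cp-20f;    // ln(2), low part
  float n = std::nearbyint(x * inv_ln2);   // e^x = 2^n * e^r
  float r = x - n * ln2_hi;                // subtract ln(2) in two steps to
  r -= n * ln2_lo;                         // keep precision beyond one float
  float poly = r + 0.5f * r * r;           // e^r - 1 ~= r + r^2/2 for small r
  return std::ldexp(1.0f + poly, static_cast<int>(n));
}
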
@@ -6,9 +6,11 @@
 #ifdef __aarch64__
 #if !defined(CPU_CAPABILITY_SVE)
 #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
+#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
 #include <ATen/cpu/vec/vec128/vec128_float_neon.h>
 #include <ATen/cpu/vec/vec128/vec128_half_neon.h>
 #include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
+#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
 #endif
 
 #include <ATen/cpu/vec/vec128/vec128_convert.h>

@@ -354,9 +354,47 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
 
   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
   Vectorized frac() const;
-  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
 
+#ifdef __ARM_FEATURE_BF16
+  Vectorized<c10::BFloat16> neg() const {
+    return -values;
+  }
+  Vectorized<c10::BFloat16> reciprocal() const {
+    return 1.0f / values;
+  }
+  Vectorized<c10::BFloat16> operator==(
+      const Vectorized<c10::BFloat16>& other) const {
+    return values == other.values;
+  }
+
+  Vectorized<c10::BFloat16> operator!=(
+      const Vectorized<c10::BFloat16>& other) const {
+    return values != other.values;
+  }
+
+  Vectorized<c10::BFloat16> operator<(
+      const Vectorized<c10::BFloat16>& other) const {
+    return values < other.values;
+  }
+
+  Vectorized<c10::BFloat16> operator<=(
+      const Vectorized<c10::BFloat16>& other) const {
+    return values <= other.values;
+  }
+
+  Vectorized<c10::BFloat16> operator>(
+      const Vectorized<c10::BFloat16>& other) const {
+    return values > other.values;
+  }
+
+  Vectorized<c10::BFloat16> operator>=(
+      const Vectorized<c10::BFloat16>& other) const {
+    return values >= other.values;
+  }
+#else
+  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
@@ -364,6 +402,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
+#endif
 
 #undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 #undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
@@ -412,28 +451,52 @@ template <>
 Vectorized<c10::BFloat16> inline operator+(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b) {
+#ifdef __ARM_FEATURE_BF16
+  bfloat16x8_t x = a;
+  bfloat16x8_t y = b;
+  return x + y;
+#else
   return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
+#endif
 }
 
 template <>
 Vectorized<c10::BFloat16> inline operator-(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b) {
+#ifdef __ARM_FEATURE_BF16
+  bfloat16x8_t x = a;
+  bfloat16x8_t y = b;
+  return x - y;
+#else
   return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
+#endif
 }
 
 template <>
 Vectorized<c10::BFloat16> inline operator*(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b) {
+#ifdef __ARM_FEATURE_BF16
+  bfloat16x8_t x = a;
+  bfloat16x8_t y = b;
+  return x * y;
+#else
   return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
+#endif
 }
 
 template <>
Vectorized<c10::BFloat16> inline operator/(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b) {
+#ifdef __ARM_FEATURE_BF16
+  bfloat16x8_t x = a;
+  bfloat16x8_t y = b;
+  return x / y;
+#else
   return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
+#endif
 }
 
 // frac. Implement this here so we can use subtraction
@@ -544,12 +607,19 @@ Vectorized<c10::BFloat16> inline fmadd(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b,
     const Vectorized<c10::BFloat16>& c) {
+#ifdef __ARM_FEATURE_BF16
+  bfloat16x8_t x = a;
+  bfloat16x8_t y = b;
+  bfloat16x8_t z = c;
+  return x * y + z;
+#else
   // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16!  Also,
   // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
   // elements, not the bottom and top half, so they don't seem
   // particularly useful here. Ideally we would include dot product in
   // the Vectorized interface...
   return a * b + c;
+#endif
 }
 
 template <>
@@ -557,8 +627,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b,
     const Vectorized<c10::BFloat16>& c) {
+#ifdef __ARM_FEATURE_BF16
+  bfloat16x8_t x = a;
+  bfloat16x8_t y = b;
+  bfloat16x8_t z = c;
+  return (-x) * y + z;
+#else
   // See NOTE [BF16 FMA] above.
   return -a * b + c;
+#endif
 }
 
 template <>
@@ -566,8 +643,15 @@ Vectorized<c10::BFloat16> inline fmsub(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b,
     const Vectorized<c10::BFloat16>& c) {
+#ifdef __ARM_FEATURE_BF16
+  bfloat16x8_t x = a;
+  bfloat16x8_t y = b;
+  bfloat16x8_t z = c;
+  return x * y - z;
+#else
   // See NOTE [BF16 FMA] above.
   return a * b - c;
+#endif
 }
 
 template <>
@@ -575,8 +659,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b,
     const Vectorized<c10::BFloat16>& c) {
+#ifdef __ARM_FEATURE_BF16
+  bfloat16x8_t x = a;
+  bfloat16x8_t y = b;
+  bfloat16x8_t z = c;
+  return (-x) * y - z;
+#else
   // See NOTE [BF16 FMA] above.
   return -a * b - c;
+#endif
 }
 
 #endif // !defined(C10_MOBILE) && defined(__aarch64__)

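The #else branches above route BF16 math through float: bf16 is the top half of an IEEE float32, so widening is cheap and the float result is narrowed back afterwards. A scalar sketch of that round trip (truncating narrow shown for brevity; production code rounds to nearest even):

#include <cstdint>
#include <cstring>

// Widen: bf16 bits become the high 16 bits of a float32.
float bf16_to_float(uint16_t bits) {
  uint32_t wide = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &wide, sizeof(f));
  return f;
}

// Narrow by truncation (drops the low 16 mantissa bits).
uint16_t float_to_bf16_trunc(float f) {
  uint32_t wide;
  std::memcpy(&wide, &f, sizeof(wide));
  return static_cast<uint16_t>(wide >> 16);
}

uint16_t bf16_add(uint16_t a, uint16_t b) {
  return float_to_bf16_trunc(bf16_to_float(a) + bf16_to_float(b));
}
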
@@ -5,6 +5,114 @@
 namespace at::vec {
 inline namespace CPU_CAPABILITY {
 #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
+
+// Enable auto-vectorization for GCC-13+ and clang-17+
+// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
+#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
+
+template <typename from_type, typename to_type>
+inline void convertImpl(
+    const from_type* __restrict src,
+    to_type* __restrict dst,
+    int64_t n) {
+  uint64_t len = static_cast<uint64_t>(n);
+  for (uint64_t i = 0; i < len; i++) {
+    dst[i] = static_cast<to_type>(src[i]);
+  }
+}
+
+#define CONVERT_TEMPLATE(from_type, to_type)                           \
+  template <>                                                          \
+  inline void convert(const from_type* src, to_type* dst, int64_t n) { \
+    return convertImpl<from_type, to_type>(src, dst, n);               \
+  }
+
+CONVERT_TEMPLATE(uint8_t, uint8_t)
+CONVERT_TEMPLATE(uint8_t, int8_t)
+CONVERT_TEMPLATE(uint8_t, int16_t)
+CONVERT_TEMPLATE(uint8_t, int32_t)
+CONVERT_TEMPLATE(uint8_t, int64_t)
+CONVERT_TEMPLATE(uint8_t, float)
+CONVERT_TEMPLATE(uint8_t, double)
+CONVERT_TEMPLATE(int8_t, uint8_t)
+CONVERT_TEMPLATE(int8_t, int8_t)
+CONVERT_TEMPLATE(int8_t, int16_t)
+CONVERT_TEMPLATE(int8_t, int32_t)
+CONVERT_TEMPLATE(int8_t, int64_t)
+CONVERT_TEMPLATE(int8_t, float)
+CONVERT_TEMPLATE(int8_t, double)
+CONVERT_TEMPLATE(int16_t, uint8_t)
+CONVERT_TEMPLATE(int16_t, int8_t)
+CONVERT_TEMPLATE(int16_t, int16_t)
+CONVERT_TEMPLATE(int16_t, int32_t)
+CONVERT_TEMPLATE(int16_t, int64_t)
+CONVERT_TEMPLATE(int16_t, float)
+CONVERT_TEMPLATE(int16_t, double)
+CONVERT_TEMPLATE(int32_t, uint8_t)
+CONVERT_TEMPLATE(int32_t, int8_t)
+CONVERT_TEMPLATE(int32_t, int16_t)
+CONVERT_TEMPLATE(int32_t, int32_t)
+CONVERT_TEMPLATE(int32_t, int64_t)
+CONVERT_TEMPLATE(int32_t, float)
+CONVERT_TEMPLATE(int32_t, double)
+CONVERT_TEMPLATE(int64_t, uint8_t)
+CONVERT_TEMPLATE(int64_t, int8_t)
+CONVERT_TEMPLATE(int64_t, int16_t)
+CONVERT_TEMPLATE(int64_t, int32_t)
+CONVERT_TEMPLATE(int64_t, int64_t)
+CONVERT_TEMPLATE(int64_t, float)
+CONVERT_TEMPLATE(int64_t, double)
+CONVERT_TEMPLATE(float, uint8_t)
+CONVERT_TEMPLATE(float, int8_t)
+CONVERT_TEMPLATE(float, int16_t)
+CONVERT_TEMPLATE(float, int32_t)
+CONVERT_TEMPLATE(float, int64_t)
+CONVERT_TEMPLATE(float, float)
+CONVERT_TEMPLATE(float, double)
+CONVERT_TEMPLATE(double, uint8_t)
+CONVERT_TEMPLATE(double, int8_t)
+CONVERT_TEMPLATE(double, int16_t)
+CONVERT_TEMPLATE(double, int32_t)
+CONVERT_TEMPLATE(double, int64_t)
+CONVERT_TEMPLATE(double, float)
+CONVERT_TEMPLATE(double, double)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+CONVERT_TEMPLATE(float16_t, uint8_t)
+CONVERT_TEMPLATE(float16_t, int8_t)
+CONVERT_TEMPLATE(float16_t, int16_t)
+CONVERT_TEMPLATE(float16_t, int32_t)
+CONVERT_TEMPLATE(float16_t, int64_t)
+CONVERT_TEMPLATE(float16_t, float16_t)
+CONVERT_TEMPLATE(float16_t, float)
+CONVERT_TEMPLATE(float16_t, double)
+CONVERT_TEMPLATE(uint8_t, float16_t)
+CONVERT_TEMPLATE(int8_t, float16_t)
+CONVERT_TEMPLATE(int16_t, float16_t)
+CONVERT_TEMPLATE(int32_t, float16_t)
+CONVERT_TEMPLATE(int64_t, float16_t)
+CONVERT_TEMPLATE(float, float16_t)
+CONVERT_TEMPLATE(double, float16_t)
+#endif
+#ifdef __ARM_FEATURE_BF16
+CONVERT_TEMPLATE(bfloat16_t, uint8_t)
+CONVERT_TEMPLATE(bfloat16_t, int8_t)
+CONVERT_TEMPLATE(bfloat16_t, int16_t)
+CONVERT_TEMPLATE(bfloat16_t, int32_t)
+CONVERT_TEMPLATE(bfloat16_t, int64_t)
+CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
+CONVERT_TEMPLATE(bfloat16_t, float)
+CONVERT_TEMPLATE(bfloat16_t, double)
+CONVERT_TEMPLATE(uint8_t, bfloat16_t)
+CONVERT_TEMPLATE(int8_t, bfloat16_t)
+CONVERT_TEMPLATE(int16_t, bfloat16_t)
+CONVERT_TEMPLATE(int32_t, bfloat16_t)
+CONVERT_TEMPLATE(int64_t, bfloat16_t)
+CONVERT_TEMPLATE(float, bfloat16_t)
+CONVERT_TEMPLATE(double, bfloat16_t)
+#endif
+
+#endif
+
 template <typename src_t>
 struct VecConvert<
     float,

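These specializations give every (from, to) pair a plain cast loop that the gated compilers can auto-vectorize. Calls go through the generic at::vec::convert entry point that the macro specializes; a sketch of a call site (assumes an ATen build):

#include <ATen/cpu/vec/vec.h>

void widen(const int32_t* src, float* dst, int64_t n) {
  // On aarch64 builds with a new enough GCC/clang this dispatches to the
  // CONVERT_TEMPLATE specialization above.
  at::vec::convert(src, dst, n);
}
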
aten/src/ATen/cpu/vec/vec128/vec128_double_neon.h (new file, 586 lines)
@@ -0,0 +1,586 @@
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/cpu/vec/intrinsics.h> | ||||||
|  | #include <ATen/cpu/vec/vec_base.h> | ||||||
|  | #include <c10/macros/Macros.h> | ||||||
|  | #include <c10/util/irange.h> | ||||||
|  | #include <cmath> | ||||||
|  |  | ||||||
|  | namespace at::vec { | ||||||
|  | // Note [CPU_CAPABILITY namespace] | ||||||
|  | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  | // This header, and all of its subheaders, will be compiled with | ||||||
|  | // different architecture flags for each supported set of vector | ||||||
|  | // intrinsics. So we need to make sure they aren't inadvertently | ||||||
|  | // linked together. We do this by declaring objects in an `inline | ||||||
|  | // namespace` which changes the name mangling, but can still be | ||||||
|  | // accessed as `at::vec`. | ||||||
|  | inline namespace CPU_CAPABILITY { | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | struct is_vec_specialized_for<double> : std::bool_constant<true> {}; | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | class Vectorized<double> { | ||||||
|  |  private: | ||||||
|  |   float64x2_t values; | ||||||
|  |  | ||||||
|  |  public: | ||||||
|  |   using value_type = double; | ||||||
|  |   using size_type = int; | ||||||
|  |   static constexpr size_type size() { | ||||||
|  |     return 2; | ||||||
|  |   } | ||||||
|  |   Vectorized() { | ||||||
|  |     values = vdupq_n_f64(0.0); | ||||||
|  |   } | ||||||
|  |   Vectorized(float64x2_t v) : values(v) {} | ||||||
|  |   Vectorized(double val) { | ||||||
|  |     values = vdupq_n_f64(val); | ||||||
|  |   } | ||||||
|  |   template < | ||||||
|  |       typename... Args, | ||||||
|  |       typename = std::enable_if_t<(sizeof...(Args) == size())>> | ||||||
|  |   Vectorized(Args... vals) { | ||||||
|  |     __at_align__ double buffer[size()] = {vals...}; | ||||||
|  |     values = vld1q_f64(buffer); | ||||||
|  |   } | ||||||
|  |   operator float64x2_t() const { | ||||||
|  |     return values; | ||||||
|  |   } | ||||||
|  |   template <int64_t mask> | ||||||
|  |   static Vectorized<double> blend( | ||||||
|  |       const Vectorized<double>& a, | ||||||
|  |       const Vectorized<double>& b) { | ||||||
|  |     // Build an array of flags: each bit of element is 1 if the corresponding | ||||||
|  |     // bit in 'mask' is set, 0 otherwise. | ||||||
|  |     uint64x2_t maskArray = { | ||||||
|  |         (mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0, | ||||||
|  |         (mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0}; | ||||||
|  |     // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |     return vbslq_f64(maskArray, b.values, a.values); | ||||||
|  |   } | ||||||
|  |   static Vectorized<double> blendv( | ||||||
|  |       const Vectorized<double>& a, | ||||||
|  |       const Vectorized<double>& b, | ||||||
|  |       const Vectorized<double>& mask_) { | ||||||
|  |     return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values); | ||||||
|  |   } | ||||||
|  |   template <typename step_t> | ||||||
|  |   static Vectorized<double> arange( | ||||||
|  |       double base = 0., | ||||||
|  |       step_t step = static_cast<step_t>(1)) { | ||||||
|  |     return {base, base + static_cast<double>(step)}; | ||||||
|  |   } | ||||||
|  |   static inline Vectorized<double> set( | ||||||
|  |       const Vectorized<double>& a, | ||||||
|  |       const Vectorized<double>& b, | ||||||
|  |       int64_t count = size()) { | ||||||
|  |     if (count == 0) { | ||||||
|  |       return a; | ||||||
|  |     } else if (count >= 2) { | ||||||
|  |       return b; | ||||||
|  |     } else { | ||||||
|  |       float64x2_t c = {b.values[0], a.values[1]}; | ||||||
|  |       return c; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   static Vectorized<double> loadu(const void* ptr, int64_t count = size()) { | ||||||
|  |     if (count == size()) { | ||||||
|  |       return vld1q_f64(reinterpret_cast<const double*>(ptr)); | ||||||
|  |     } else if (count == 1) { | ||||||
|  |       float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr)); | ||||||
|  |       float64x1_t z = {0.0}; | ||||||
|  |       return vcombine_f64(x, z); | ||||||
|  |     } else { | ||||||
|  |       return vdupq_n_f64(0.0); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   void store(void* ptr, int64_t count = size()) const { | ||||||
|  |     if (count == size()) { | ||||||
|  |       vst1q_f64(reinterpret_cast<double*>(ptr), values); | ||||||
|  |     } else if (count == 1) { | ||||||
|  |       vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values)); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   const double& operator[](int idx) const = delete; | ||||||
|  |   double& operator[](int idx) = delete; | ||||||
|  |   int64_t zero_mask() const { | ||||||
|  |     // Returns an integer mask where bit i is set iff element i is zero. | ||||||
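|  |     // e.g. {0.0, 3.0}.zero_mask() == 0b01 (only lane 0 is zero). | ||||||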
|  |     uint64x2_t cmpReg = vceqzq_f64(values); | ||||||
|  |     uint64x2_t mask = {1, 2}; | ||||||
|  |     uint64x2_t res = vandq_u64(cmpReg, mask); | ||||||
|  |     return res[0] | res[1]; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> isnan() const { | ||||||
|  |     // NaN check | ||||||
|  |     return vreinterpretq_f64_u32( | ||||||
|  |         vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values)))); | ||||||
|  |   } | ||||||
|  |   bool has_inf_nan() const { | ||||||
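|  |     // x - x is NaN exactly when x is NaN or +/-inf, so one subtraction | ||||||
|  |     // followed by a NaN check detects both cases at once. | ||||||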
|  |     Vectorized<double> x = vsubq_f64(values, values); | ||||||
|  |     float64x2_t r = x.isnan(); | ||||||
|  |     uint64x2_t u = vreinterpretq_u64_f64(r); | ||||||
|  |     return u[0] | u[1]; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> map(double (*f)(double)) const { | ||||||
|  |     float64x2_t result; | ||||||
|  |     result[0] = f(values[0]); | ||||||
|  |     result[1] = f(values[1]); | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> map2( | ||||||
|  |       const Vectorized<double>& second, | ||||||
|  |       double (*const f)(double, double)) const { | ||||||
|  |     float64x2_t result; | ||||||
|  |     result[0] = f(values[0], second.values[0]); | ||||||
|  |     result[1] = f(values[1], second.values[1]); | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> abs() const { | ||||||
|  |     return vabsq_f64(values); | ||||||
|  |   } | ||||||
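|  |   // angle(): 0.0 for non-negative lanes, pi for negative lanes; NaN lanes | ||||||
|  |   // propagate unchanged. | ||||||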
|  |   Vectorized<double> angle() const { | ||||||
|  |     auto zero = Vectorized<double>(0.0); | ||||||
|  |     auto pi = Vectorized<double>(c10::pi<double>); | ||||||
|  |     auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values))); | ||||||
|  |     return blendv(tmp, *this, isnan()); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> real() const { | ||||||
|  |     return *this; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> imag() const { | ||||||
|  |     return Vectorized<double>(0.0); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> conj() const { | ||||||
|  |     return *this; | ||||||
|  |   } | ||||||
|  |   Vectorized<double> acos() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> acosh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> asin() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> asinh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> atan() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> atanh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> atan2(const Vectorized<double>& b) const { | ||||||
|  |     USE_SLEEF( | ||||||
|  |       { return Vectorized<double>(Sleef_atan2d2_u10(values, b)); }, | ||||||
|  |       { | ||||||
|  |         __at_align__ double tmp[size()]; | ||||||
|  |         __at_align__ double tmp_b[size()]; | ||||||
|  |         store(tmp); | ||||||
|  |         b.store(tmp_b); | ||||||
|  |         for (int64_t i = 0; i < size(); i++) { | ||||||
|  |           tmp[i] = std::atan2(tmp[i], tmp_b[i]); | ||||||
|  |         } | ||||||
|  |         return loadu(tmp); | ||||||
|  |       }) | ||||||
|  |   } | ||||||
|  |   Vectorized<double> copysign(const Vectorized<double>& sign) const { | ||||||
|  |       USE_SLEEF( | ||||||
|  |           { return Vectorized<double>(Sleef_copysignd2(values, sign)); }, | ||||||
|  |           { | ||||||
|  |             __at_align__ double tmp[size()]; | ||||||
|  |             __at_align__ double tmp_sign[size()]; | ||||||
|  |             store(tmp); | ||||||
|  |             sign.store(tmp_sign); | ||||||
|  |             for (int64_t i = 0; i < size(); i++) { | ||||||
|  |               tmp[i] = std::copysign(tmp[i], tmp_sign[i]); | ||||||
|  |             } | ||||||
|  |             return loadu(tmp); | ||||||
|  |           }) | ||||||
|  |   } | ||||||
|  |   Vectorized<double> erf() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> erfc() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> exp() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> exp2() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> expm1() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> fmod(const Vectorized<double>& q) const { | ||||||
|  |     USE_SLEEF( | ||||||
|  |       { return Vectorized<double>(Sleef_fmodd2(values, q)); }, | ||||||
|  |       { | ||||||
|  |         __at_align__ double tmp[size()]; | ||||||
|  |         __at_align__ double tmp_q[size()]; | ||||||
|  |         store(tmp); | ||||||
|  |         q.store(tmp_q); | ||||||
|  |         for (int64_t i = 0; i < size(); i++) { | ||||||
|  |           tmp[i] = std::fmod(tmp[i], tmp_q[i]); | ||||||
|  |         } | ||||||
|  |         return loadu(tmp); | ||||||
|  |       }) | ||||||
|  |   } | ||||||
|  |   Vectorized<double> hypot(const Vectorized<double>& b) const { | ||||||
|  |       USE_SLEEF( | ||||||
|  |           { return Vectorized<double>(Sleef_hypotd2_u05(values, b)); }, | ||||||
|  |           { | ||||||
|  |             __at_align__ double tmp[size()]; | ||||||
|  |             __at_align__ double tmp_b[size()]; | ||||||
|  |             store(tmp); | ||||||
|  |             b.store(tmp_b); | ||||||
|  |             for (int64_t i = 0; i < size(); i++) { | ||||||
|  |               tmp[i] = std::hypot(tmp[i], tmp_b[i]); | ||||||
|  |             } | ||||||
|  |             return loadu(tmp); | ||||||
|  |           }) | ||||||
|  |   } | ||||||
|  |   Vectorized<double> i0() const { | ||||||
|  |     return map(calc_i0); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> nextafter(const Vectorized<double>& b) const { | ||||||
|  |     USE_SLEEF( | ||||||
|  |       { return Vectorized<double>(Sleef_nextafterd2(values, b)); }, | ||||||
|  |       { | ||||||
|  |         __at_align__ double tmp[size()]; | ||||||
|  |         __at_align__ double tmp_b[size()]; | ||||||
|  |         store(tmp); | ||||||
|  |         b.store(tmp_b); | ||||||
|  |         for (int64_t i = 0; i < size(); ++i) { | ||||||
|  |           tmp[i] = std::nextafter(tmp[i], tmp_b[i]); | ||||||
|  |         } | ||||||
|  |         return loadu(tmp); | ||||||
|  |       }) | ||||||
|  |   } | ||||||
|  |   Vectorized<double> log() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_logd2_u10(values)), map(std::log)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> log2() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> log10() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> log1p() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> frac() const; | ||||||
|  |   Vectorized<double> sin() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> sinh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> cos() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> cosh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> pow(const Vectorized<double>& b) const { | ||||||
|  |     USE_SLEEF( | ||||||
|  |       { return Vectorized<double>(Sleef_powd2_u10(values, b)); }, | ||||||
|  |       { | ||||||
|  |         __at_align__ double tmp[size()]; | ||||||
|  |         __at_align__ double tmp_b[size()]; | ||||||
|  |         store(tmp); | ||||||
|  |         b.store(tmp_b); | ||||||
|  |         for (int64_t i = 0; i < size(); i++) { | ||||||
|  |           tmp[i] = std::pow(tmp[i], tmp_b[i]); | ||||||
|  |         } | ||||||
|  |         return loadu(tmp); | ||||||
|  |       }) | ||||||
|  |   } | ||||||
|  |   // Comparisons follow the semantics of the x86 _CMP_**_OQ predicates: | ||||||
|  |   //   `O` (ordered): compare as false if an operand is NaN | ||||||
|  |   //   `Q` (quiet): do not raise an exception if an operand is NaN | ||||||
|  |   Vectorized<double> tan() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> tanh() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> lgamma() const { | ||||||
|  |     return USE_SLEEF( | ||||||
|  |         Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma)); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> erfinv() const { | ||||||
|  |     return map(calc_erfinv); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> exp_u20() const { | ||||||
|  |     return exp(); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> fexp_u20() const { | ||||||
|  |     return exp(); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> i0e() const { | ||||||
|  |     return map(calc_i0e); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> digamma() const { | ||||||
|  |     return map(calc_digamma); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> igamma(const Vectorized<double>& x) const { | ||||||
|  |     __at_align__ double tmp[size()]; | ||||||
|  |     __at_align__ double tmp_x[size()]; | ||||||
|  |     store(tmp); | ||||||
|  |     x.store(tmp_x); | ||||||
|  |     for (int64_t i = 0; i < size(); i++) { | ||||||
|  |       tmp[i] = calc_igamma(tmp[i], tmp_x[i]); | ||||||
|  |     } | ||||||
|  |     return loadu(tmp); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> igammac(const Vectorized<double>& x) const { | ||||||
|  |     __at_align__ double tmp[size()]; | ||||||
|  |     __at_align__ double tmp_x[size()]; | ||||||
|  |     store(tmp); | ||||||
|  |     x.store(tmp_x); | ||||||
|  |     for (int64_t i = 0; i < size(); i++) { | ||||||
|  |       tmp[i] = calc_igammac(tmp[i], tmp_x[i]); | ||||||
|  |     } | ||||||
|  |     return loadu(tmp); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> ceil() const { | ||||||
|  |     return vrndpq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> floor() const { | ||||||
|  |     return vrndmq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> neg() const { | ||||||
|  |     return vnegq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> round() const { | ||||||
|  |     return vrndiq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> trunc() const { | ||||||
|  |     return vrndq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> sqrt() const { | ||||||
|  |     return vsqrtq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> reciprocal() const { | ||||||
|  |     return vdivq_f64(vdupq_n_f64(1.0), values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> rsqrt() const { | ||||||
|  |     return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values)); | ||||||
|  |   } | ||||||
|  |   double reduce_add() const { | ||||||
|  |     return vaddvq_f64(values); | ||||||
|  |   } | ||||||
|  |   double reduce_max() const { | ||||||
|  |     return vmaxvq_f64(values); | ||||||
|  |   } | ||||||
|  |   Vectorized<double> operator==(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vceqq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator!=(const Vectorized<double>& other) const { | ||||||
|  |     float64x2_t r0 = vreinterpretq_f64_u32( | ||||||
|  |         vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values)))); | ||||||
|  |     return Vectorized<double>(r0); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator<(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vcltq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator<=(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vcleq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator>(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vcgtq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> operator>=(const Vectorized<double>& other) const { | ||||||
|  |     return Vectorized<double>( | ||||||
|  |         vreinterpretq_f64_u64(vcgeq_f64(values, other.values))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Vectorized<double> eq(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> ne(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> gt(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> ge(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> lt(const Vectorized<double>& other) const; | ||||||
|  |   Vectorized<double> le(const Vectorized<double>& other) const; | ||||||
|  | }; | ||||||
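|  |  | ||||||
|  | // Usage sketch: load two lanes, compute element-wise, store the result. | ||||||
|  | //   __at_align__ double buf[2] = {4.0, 9.0}; | ||||||
|  | //   auto v = Vectorized<double>::loadu(buf); | ||||||
|  | //   v.sqrt().store(buf); // buf becomes {2.0, 3.0} | ||||||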
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator+( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vaddq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator-( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vsubq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator*( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vmulq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator/( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vdivq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // frac. Implement this here so we can use subtraction | ||||||
|  | Vectorized<double> inline Vectorized<double>::frac() const { | ||||||
|  |   return *this - this->trunc(); | ||||||
|  | } | ||||||
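|  | // e.g. frac() of {1.75, -2.5} is {0.75, -0.5}. | ||||||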
|  |  | ||||||
|  | // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if | ||||||
|  | // either input is a NaN. | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline maximum( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vmaxq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if | ||||||
|  | // either input is a NaN. | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline minimum( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vminq_f64(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline clamp( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& min, | ||||||
|  |     const Vectorized<double>& max) { | ||||||
|  |   return vminq_f64(max, vmaxq_f64(min, a)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline clamp_max( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& max) { | ||||||
|  |   return vminq_f64(max, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline clamp_min( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& min) { | ||||||
|  |   return vmaxq_f64(min, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator&( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vreinterpretq_f64_u64( | ||||||
|  |       vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator|( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vreinterpretq_f64_u64( | ||||||
|  |       vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline operator^( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b) { | ||||||
|  |   return vreinterpretq_f64_u64( | ||||||
|  |       veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::eq( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this == other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::ne( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this != other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::gt( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this > other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::ge( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this >= other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::lt( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this < other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<double> Vectorized<double>::le( | ||||||
|  |     const Vectorized<double>& other) const { | ||||||
|  |   return (*this <= other) & Vectorized<double>(1.0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline fmadd( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b, | ||||||
|  |     const Vectorized<double>& c) { | ||||||
|  |   return vfmaq_f64(c, a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline fnmadd( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b, | ||||||
|  |     const Vectorized<double>& c) { | ||||||
|  |   return vfmsq_f64(c, a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline fmsub( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b, | ||||||
|  |     const Vectorized<double>& c) { | ||||||
|  |   return vfmaq_f64(vnegq_f64(c), a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<double> inline fnmsub( | ||||||
|  |     const Vectorized<double>& a, | ||||||
|  |     const Vectorized<double>& b, | ||||||
|  |     const Vectorized<double>& c) { | ||||||
|  |   return vfmsq_f64(vnegq_f64(c), a, b); | ||||||
|  | } | ||||||
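|  |  | ||||||
|  | // Summary of the fused ops above (single rounding per lane): | ||||||
|  | //   fmadd(a, b, c)  =  a * b + c | ||||||
|  | //   fnmadd(a, b, c) = -a * b + c | ||||||
|  | //   fmsub(a, b, c)  =  a * b - c | ||||||
|  | //   fnmsub(a, b, c) = -a * b - c | ||||||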
|  |  | ||||||
|  | } // namespace CPU_CAPABILITY | ||||||
|  | } // namespace at::vec | ||||||
| @ -307,11 +307,49 @@ class Vectorized<float> { | |||||||
|   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp) |   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp) | ||||||
|   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2) |   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2) | ||||||
|   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) |   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) | ||||||
|  |   // Implementation copied from the Arm Optimized Routines expf: | ||||||
|  |   // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c | ||||||
|   Vectorized<float> exp_u20() const { |   Vectorized<float> exp_u20() const { | ||||||
    return exp(); |     // Fall back to the accurate exp() (SLEEF when available) if any lane | ||||||
|  |     // is a special case, i.e. |input| > 0x1.5d5e2ap+6f (about 87.34). | ||||||
|  |     const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); | ||||||
|  |     uint32x4_t cmp = vcagtq_f32(values, special_bound); | ||||||
|  |     if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) { | ||||||
|  |       return exp(); | ||||||
|  |     } | ||||||
|  |  | ||||||
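|  |     // Hex-float constants from the Arm routine: inv_ln2 = 1/ln2; ln2 is | ||||||
|  |     // split into hi and lo parts for exact range reduction; c0..c4 are the | ||||||
|  |     // coefficients of a polynomial approximating exp(r) - 1. | ||||||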
|  |     const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); | ||||||
|  |     const float ln2_hi = 0x1.62e4p-1f; | ||||||
|  |     const float ln2_lo = 0x1.7f7d1cp-20f; | ||||||
|  |     const float c0 = 0x1.0e4020p-7f; | ||||||
|  |     const float c2 = 0x1.555e66p-3f; | ||||||
|  |     const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; | ||||||
|  |  | ||||||
|  |     const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); | ||||||
|  |     const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); | ||||||
|  |     const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); | ||||||
|  |     const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); | ||||||
|  |  | ||||||
|  |     /* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)] | ||||||
|  |       x = ln2*n + r, with r in [-ln2/2, ln2/2].  */ | ||||||
|  |  | ||||||
|  |     float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); | ||||||
|  |     float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); | ||||||
|  |     r = vfmsq_laneq_f32(r, n, ln2_c02, 1); | ||||||
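|  |     // Shift the integer part n into the exponent field: scale == 2^n. | ||||||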
|  |     uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); | ||||||
|  |     float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); | ||||||
|  |  | ||||||
|  |     float32x4_t r2 = vmulq_f32(r, r); | ||||||
|  |     float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); | ||||||
|  |     float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); | ||||||
|  |     q = vfmaq_f32(q, p, r2); | ||||||
|  |     p = vmulq_f32(c4, r); | ||||||
|  |     float32x4_t poly = vfmaq_f32(p, q, r2); | ||||||
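|  |     // poly(r) approximates exp(r) - 1, so exp(x) = scale + poly * scale. | ||||||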
|  |  | ||||||
|  |     return vfmaq_f32(scale, poly, scale); | ||||||
|   } |   } | ||||||
|   Vectorized<float> fexp_u20() const { |   Vectorized<float> fexp_u20() const { | ||||||
|     return exp(); |     return exp_u20(); | ||||||
|   } |   } | ||||||
|   DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( |   DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( | ||||||
|       fmod, |       fmod, | ||||||
| @ -540,42 +578,6 @@ inline Vectorized<float> Vectorized<float>::le( | |||||||
|   return (*this <= other) & Vectorized<float>(1.0f); |   return (*this <= other) & Vectorized<float>(1.0f); | ||||||
| } | } | ||||||
|  |  | ||||||
| template <> |  | ||||||
| inline void convert(const float* src, int32_t* dst, int64_t n) { |  | ||||||
|   int64_t i; |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (i = 0; i <= (n - Vectorized<float>::size()); |  | ||||||
|        i += Vectorized<float>::size()) { |  | ||||||
|     vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i))); |  | ||||||
|   } |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (; i < n; i++) { |  | ||||||
|     dst[i] = static_cast<int32_t>(src[i]); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template <> |  | ||||||
| inline void convert(const int32_t* src, float* dst, int64_t n) { |  | ||||||
|   int64_t i; |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (i = 0; i <= (n - Vectorized<float>::size()); |  | ||||||
|        i += Vectorized<float>::size()) { |  | ||||||
|     vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i))); |  | ||||||
|   } |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (; i < n; i++) { |  | ||||||
|     dst[i] = static_cast<float>(src[i]); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| Vectorized<float> inline fmadd( | Vectorized<float> inline fmadd( | ||||||
|     const Vectorized<float>& a, |     const Vectorized<float>& a, | ||||||
| @ -632,8 +634,7 @@ inline Vectorized<float> Vectorized<float>::erf() const { | |||||||
|   // - exp(- x * x) |   // - exp(- x * x) | ||||||
|   auto pow_2 = (*this) * (*this); |   auto pow_2 = (*this) * (*this); | ||||||
|   auto neg_pow_2 = pow_2 ^ neg_zero_vec; |   auto neg_pow_2 = pow_2 ^ neg_zero_vec; | ||||||
|   auto tmp4 = neg_pow_2.map( |   auto tmp4 = neg_pow_2.exp(); | ||||||
|       std::exp); // This can be swapped for a faster implementation of exp. |  | ||||||
|   auto tmp5 = tmp4 ^ neg_zero_vec; |   auto tmp5 = tmp4 ^ neg_zero_vec; | ||||||
|   // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) |   // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) | ||||||
|   auto tmp6 = t * tmp5; |   auto tmp6 = t * tmp5; | ||||||
|  | |||||||
| @ -569,46 +569,6 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le( | |||||||
|   return (*this <= other) & Vectorized<c10::Half>(1); |   return (*this <= other) & Vectorized<c10::Half>(1); | ||||||
| } | } | ||||||
|  |  | ||||||
| // These are global functions, so the defaults in vec_base.h should |  | ||||||
| // work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available. |  | ||||||
| #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |  | ||||||
| template <> |  | ||||||
| inline void convert(const float16_t* src, int16_t* dst, int64_t n) { |  | ||||||
|   int64_t i; |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (i = 0; i <= (n - Vectorized<c10::Half>::size()); |  | ||||||
|        i += Vectorized<c10::Half>::size()) { |  | ||||||
|     vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i))); |  | ||||||
|   } |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (; i < n; i++) { |  | ||||||
|     dst[i] = static_cast<int16_t>(src[i]); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template <> |  | ||||||
| inline void convert(const int16_t* src, float16_t* dst, int64_t n) { |  | ||||||
|   int64_t i; |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (i = 0; i <= (n - Vectorized<c10::Half>::size()); |  | ||||||
|        i += Vectorized<c10::Half>::size()) { |  | ||||||
|     vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i))); |  | ||||||
|   } |  | ||||||
| #ifndef __msvc_cl__ |  | ||||||
| #pragma unroll |  | ||||||
| #endif |  | ||||||
|   for (; i < n; i++) { |  | ||||||
|     dst[i] = static_cast<float16_t>(src[i]); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |  | ||||||
|  |  | ||||||
| template <> | template <> | ||||||
| Vectorized<c10::Half> inline fmadd( | Vectorized<c10::Half> inline fmadd( | ||||||
|     const Vectorized<c10::Half>& a, |     const Vectorized<c10::Half>& a, | ||||||
|  | |||||||
aten/src/ATen/cpu/vec/vec128/vec128_uint_aarch64.h (new file, 378 lines)
| @ -0,0 +1,378 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/cpu/vec/intrinsics.h> | ||||||
|  | #include <ATen/cpu/vec/vec_base.h> | ||||||
|  | #include <c10/macros/Macros.h> | ||||||
|  | #include <c10/util/irange.h> | ||||||
|  |  | ||||||
|  | namespace at::vec { | ||||||
|  | // Note [CPU_CAPABILITY namespace] | ||||||
|  | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  | // This header, and all of its subheaders, will be compiled with | ||||||
|  | // different architecture flags for each supported set of vector | ||||||
|  | // intrinsics. So we need to make sure they aren't inadvertently | ||||||
|  | // linked together. We do this by declaring objects in an `inline | ||||||
|  | // namespace` which changes the name mangling, but can still be | ||||||
|  | // accessed as `at::vec`. | ||||||
|  | inline namespace CPU_CAPABILITY { | ||||||
|  |  | ||||||
|  | #define VEC_UINT_NEON_TEMPLATE(vl, bit)                                       \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \ | ||||||
|  |                                                                               \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   class Vectorized<uint##bit##_t> {                                           \ | ||||||
|  |     using neon_type = uint##bit##x##vl##_t;                                   \ | ||||||
|  |                                                                               \ | ||||||
|  |    private:                                                                   \ | ||||||
|  |     neon_type values;                                                         \ | ||||||
|  |                                                                               \ | ||||||
|  |    public:                                                                    \ | ||||||
|  |     using value_type = uint##bit##_t;                                         \ | ||||||
|  |     using size_type = int;                                                    \ | ||||||
|  |     static constexpr size_type size() {                                       \ | ||||||
|  |       return vl;                                                              \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized() {                                                            \ | ||||||
|  |       values = vdupq_n_u##bit(0);                                             \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized(neon_type v) : values(v) {}                                    \ | ||||||
|  |     Vectorized(uint##bit##_t val);                                            \ | ||||||
|  |     template <                                                                \ | ||||||
|  |         typename... Args,                                                     \ | ||||||
|  |         typename = std::enable_if_t<(sizeof...(Args) == size())>>             \ | ||||||
|  |     Vectorized(Args... vals) {                                                \ | ||||||
|  |       __at_align__ uint##bit##_t buffer[size()] = {vals...};                  \ | ||||||
|  |       values = vld1q_u##bit(buffer);                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     operator neon_type() const {                                              \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     static Vectorized<uint##bit##_t> loadu(                                   \ | ||||||
|  |         const void* ptr,                                                      \ | ||||||
|  |         uint64_t count = size());                                             \ | ||||||
|  |     void store(void* ptr, uint64_t count = size()) const;                     \ | ||||||
|  |     template <uint64_t mask>                                                  \ | ||||||
|  |     static Vectorized<uint##bit##_t> blend(                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& a,                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& b);                                  \ | ||||||
|  |     static Vectorized<uint##bit##_t> blendv(                                  \ | ||||||
|  |         const Vectorized<uint##bit##_t>& a,                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& b,                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& mask_) {                             \ | ||||||
|  |       return vbslq_u##bit(mask_.values, b, a);                                \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     template <typename step_t>                                                \ | ||||||
|  |     static Vectorized<uint##bit##_t> arange(                                  \ | ||||||
|  |         value_type base = 0,                                                  \ | ||||||
|  |         step_t step = static_cast<step_t>(1));                                \ | ||||||
|  |     static Vectorized<uint##bit##_t> set(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& a,                                   \ | ||||||
|  |         const Vectorized<uint##bit##_t>& b,                                   \ | ||||||
|  |         uint64_t count = size());                                             \ | ||||||
|  |     const uint##bit##_t& operator[](uint idx) const = delete;                 \ | ||||||
|  |     uint##bit##_t& operator[](uint idx) = delete;                             \ | ||||||
|  |     Vectorized<uint##bit##_t> abs() const {                                   \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> real() const {                                  \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> imag() const {                                  \ | ||||||
|  |       return vdupq_n_u##bit(0);                                               \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> conj() const {                                  \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> neg() const {                                   \ | ||||||
|  |       return vreinterpretq_u##bit##_s##bit(                                   \ | ||||||
|  |           vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values)));               \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     uint##bit##_t reduce_add() const {                                        \ | ||||||
|  |       return vaddvq_u##bit(values);                                           \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     uint##bit##_t reduce_max() const;                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator==(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vceqq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator!=(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> operator<(                                      \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vcltq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator<=(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vcleq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator>(                                      \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vcgtq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> operator>=(                                     \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||||
|  |       return Vectorized<value_type>(vcgeq_u##bit(values, other.values));      \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<uint##bit##_t> eq(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> ne(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> gt(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> ge(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> lt(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |     Vectorized<uint##bit##_t> le(                                             \ | ||||||
|  |         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||||
|  |   };                                                                          \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator+(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return vaddq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator-(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return vsubq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator&(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return vandq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator|(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return vorrq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<uint##bit##_t> inline operator^(                                 \ | ||||||
|  |       const Vectorized<uint##bit##_t>& a,                                     \ | ||||||
|  |       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||||
|  |     return veorq_u##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this == other) & Vectorized<uint##bit##_t>(1);                   \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this != other) & Vectorized<uint##bit##_t>(1);                   \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this > other) & Vectorized<uint##bit##_t>(1);                    \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this >= other) & Vectorized<uint##bit##_t>(1);                   \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this < other) & Vectorized<uint##bit##_t>(1);                    \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le(             \ | ||||||
|  |       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||||
|  |     return (*this <= other) & Vectorized<uint##bit##_t>(1);                   \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | VEC_UINT_NEON_TEMPLATE(16, 8) | ||||||
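|  | // Instantiates Vectorized<uint8_t>: 16 uint8_t lanes in a 128-bit register. | ||||||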
|  |  | ||||||
|  | inline uint8_t Vectorized<uint8_t>::reduce_max() const { | ||||||
|  |   return vmaxvq_u8(values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline operator*( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   return vmulq_u8(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) { | ||||||
|  |   return vmvnq_u8(a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=( | ||||||
|  |     const Vectorized<uint8_t>& other) const { | ||||||
|  |   return ~(*this == other); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline minimum( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   return vminq_u8(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline maximum( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   return vmaxq_u8(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <uint64_t mask> | ||||||
|  | Vectorized<uint8_t> Vectorized<uint8_t>::blend( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   // Build a mask vector: byte i is 0xFF if bit i of 'mask' is set, | ||||||
|  |   // 0x00 otherwise. | ||||||
|  |   uint8x16_t maskArray = { | ||||||
|  |       (mask & 1LL) ? 0xFF : 0, | ||||||
|  |       (mask & 2LL) ? 0xFF : 0, | ||||||
|  |       (mask & 4LL) ? 0xFF : 0, | ||||||
|  |       (mask & 8LL) ? 0xFF : 0, | ||||||
|  |       (mask & 16LL) ? 0xFF : 0, | ||||||
|  |       (mask & 32LL) ? 0xFF : 0, | ||||||
|  |       (mask & 64LL) ? 0xFF : 0, | ||||||
|  |       (mask & 128LL) ? 0xFF : 0, | ||||||
|  |       (mask & 256LL) ? 0xFF : 0, | ||||||
|  |       (mask & 512LL) ? 0xFF : 0, | ||||||
|  |       (mask & 1024LL) ? 0xFF : 0, | ||||||
|  |       (mask & 2048LL) ? 0xFF : 0, | ||||||
|  |       (mask & 4096LL) ? 0xFF : 0, | ||||||
|  |       (mask & 8192LL) ? 0xFF : 0, | ||||||
|  |       (mask & 16384LL) ? 0xFF : 0, | ||||||
|  |       (mask & 32768LL) ? 0xFF : 0}; | ||||||
|  |   // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |   return vbslq_u8(maskArray, b.values, a.values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #define VEC_UINT_NEON_OPS(vl, bit)                                             \ | ||||||
|  |   inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) {            \ | ||||||
|  |     values = vdupq_n_u##bit(val);                                              \ | ||||||
|  |   }                                                                            \ | ||||||
|  |   inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu(           \ | ||||||
|  |       const void* ptr, uint64_t count) {                                       \ | ||||||
|  |     if (count == size()) {                                                     \ | ||||||
|  |       return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr));        \ | ||||||
|  |     } else {                                                                   \ | ||||||
|  |       __at_align__ uint##bit##_t tmp_values[size()];                           \ | ||||||
|  |       for (const auto i : c10::irange(size())) {                               \ | ||||||
|  |         tmp_values[i] = 0;                                                     \ | ||||||
|  |       }                                                                        \ | ||||||
|  |       std::memcpy(                                                             \ | ||||||
|  |           tmp_values,                                                          \ | ||||||
|  |           reinterpret_cast<const uint##bit##_t*>(ptr),                         \ | ||||||
|  |           count * sizeof(uint##bit##_t));                                      \ | ||||||
|  |       return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \ | ||||||
|  |     }                                                                          \ | ||||||
|  |   }                                                                            \ | ||||||
|  |   inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count)      \ | ||||||
|  |       const {                                                                  \ | ||||||
|  |     if (count == size()) {                                                     \ | ||||||
|  |       vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values);             \ | ||||||
|  |     } else {                                                                   \ | ||||||
|  |       uint##bit##_t tmp_values[size()];                                        \ | ||||||
|  |       vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values);      \ | ||||||
|  |       std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t));             \ | ||||||
|  |     }                                                                          \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | VEC_UINT_NEON_OPS(16, 8) | ||||||
|  |  | ||||||
|  | template <typename step_t> | ||||||
|  | inline Vectorized<uint8_t> Vectorized<uint8_t>::arange( | ||||||
|  |     uint8_t base, | ||||||
|  |     step_t step) { | ||||||
|  |   const Vectorized<uint8_t> base_vec(base); | ||||||
|  |   const Vectorized<uint8_t> step_vec(step); | ||||||
|  |   const uint8x16_t step_sizes = { | ||||||
|  |       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; | ||||||
|  |   return vmlaq_u8(base_vec, step_sizes, step_vec); | ||||||
|  | } | ||||||
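|  | // e.g. arange(10, 3) == {10, 13, ..., 55}; lane arithmetic wraps mod 256. | ||||||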
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline operator>>( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   uint8x16_t x = a; | ||||||
|  |   uint8x16_t bound = vdupq_n_u8(8); | ||||||
|  |   uint8x16_t z = vminq_u8(b, bound); | ||||||
|  |   return x >> z; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline operator<<( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   uint8x16_t bound = vdupq_n_u8(8); | ||||||
|  |   uint8x16_t z = vminq_u8(b, bound); | ||||||
|  |   return vshlq_u8(a, vreinterpretq_s8_u8(z)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<uint8_t> Vectorized<uint8_t>::set( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b, | ||||||
|  |     uint64_t count) { | ||||||
|  |   if (count == 0) { | ||||||
|  |     return a; | ||||||
|  |   } else if (count >= 16) { | ||||||
|  |     return b; | ||||||
|  |   } else { | ||||||
|  |     // Build a mask vector: byte i is 0xFF if i < count, 0x00 otherwise, | ||||||
|  |     // so the first 'count' lanes come from b and the rest from a. | ||||||
|  |     uint8x16_t maskArray = { | ||||||
|  |         static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0), | ||||||
|  |         0}; | ||||||
|  |  | ||||||
|  |     // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |     return vbslq_u8(maskArray, b.values, a.values); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
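|  | // NEON has no integer-divide instruction; the lane-wise division below uses | ||||||
|  | // the GCC/Clang vector-extension operator, which lowers to per-lane scalar | ||||||
|  | // divides. | ||||||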
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline operator/( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& b) { | ||||||
|  |   uint8x16_t x = a; | ||||||
|  |   uint8x16_t y = b; | ||||||
|  |   return x / y; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline clamp( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& min, | ||||||
|  |     const Vectorized<uint8_t>& max) { | ||||||
|  |   return minimum(max, maximum(min, a)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline clamp_max( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& max) { | ||||||
|  |   return minimum(max, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<uint8_t> inline clamp_min( | ||||||
|  |     const Vectorized<uint8_t>& a, | ||||||
|  |     const Vectorized<uint8_t>& min) { | ||||||
|  |   return maximum(min, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace CPU_CAPABILITY | ||||||
|  | } // namespace at::vec | ||||||
| @ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | |||||||
|  |  | ||||||
| std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||||
|     at::vec::Vectorized<uint8_t> src) { |     at::vec::Vectorized<uint8_t> src) { | ||||||
|   auto u8x8 = vld1_u8(src.operator const uint8_t*()); |   auto u8x8 = vget_low_u8(src); | ||||||
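|  |   // vget_low_u8 reads the low 8 lanes straight from the register, avoiding | ||||||
|  |   // the old round-trip through memory via operator const uint8_t*(). | ||||||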
|   auto u16x8 = vmovl_u8(u8x8); |   auto u16x8 = vmovl_u8(u8x8); | ||||||
|   auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); |   auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); | ||||||
|   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); |   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); | ||||||
| @ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float( | |||||||
|  |  | ||||||
| Vectorized<float> inline convert_int8_half_register_to_float( | Vectorized<float> inline convert_int8_half_register_to_float( | ||||||
|     at::vec::Vectorized<uint8_t> src) { |     at::vec::Vectorized<uint8_t> src) { | ||||||
|   auto u8x8 = vld1_u8(src.operator const uint8_t*()); |   auto u8x8 = vget_low_u8(src); | ||||||
|   auto u16x8 = vmovl_u8(u8x8); |   auto u16x8 = vmovl_u8(u8x8); | ||||||
|   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); |   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); | ||||||
|  |  | ||||||
|  | |||||||
| @ -168,11 +168,9 @@ void CUDAGraph::instantiate() { | |||||||
|   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597 |   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597 | ||||||
|   // cudaGraphInstantiateWithFlags |   // cudaGraphInstantiateWithFlags | ||||||
|   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233 |   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233 | ||||||
| #if !defined(USE_ROCM) || ROCM_VERSION >= 60200 |  | ||||||
|   int version = 0; |   int version = 0; | ||||||
|   AT_CUDA_CHECK(cudaDriverGetVersion(&version)); |   AT_CUDA_CHECK(cudaDriverGetVersion(&version)); | ||||||
|   if (version < 11040) { |   if (version < 11040) { | ||||||
| #endif |  | ||||||
|     // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people, |     // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people, | ||||||
|     // who prefer not to report error message through these arguments moving forward |     // who prefer not to report error message through these arguments moving forward | ||||||
|     // (they prefer return value, or errors on api calls internal to the capture) |     // (they prefer return value, or errors on api calls internal to the capture) | ||||||
| @ -183,13 +181,11 @@ void CUDAGraph::instantiate() { | |||||||
| #endif | #endif | ||||||
| //Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory. | //Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory. | ||||||
| //It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch. | //It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch. | ||||||
| #if !defined(USE_ROCM) || ROCM_VERSION >= 60200 |  | ||||||
|   } else { |   } else { | ||||||
|     AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_, |     AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_, | ||||||
|                                                 graph_, |                                                 graph_, | ||||||
|                                                 cudaGraphInstantiateFlagAutoFreeOnLaunch)); |                                                 cudaGraphInstantiateFlagAutoFreeOnLaunch)); | ||||||
|   } |   } | ||||||
| #endif |  | ||||||
|   has_graph_exec_ = true; |   has_graph_exec_ = true; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
aten/src/ATen/cuda/CUDAGreenContext.cpp (new file, +192 lines)
@ -0,0 +1,192 @@
|  | #include <ATen/cuda/CUDAGreenContext.h> | ||||||
|  |  | ||||||
|  | namespace at::cuda { | ||||||
|  |   GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     int driver_version; | ||||||
|  |     C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         driver_version >= 12080, "cuda driver too old to use green context!"); | ||||||
|  |     CUcontext pctx = nullptr; | ||||||
|  |     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); | ||||||
|  |     if (C10_UNLIKELY(!pctx)) { | ||||||
|  |       TORCH_WARN( | ||||||
|  |           "Attempted to create a green context but" | ||||||
|  |           " there was no primary context! Creating a primary context..."); | ||||||
|  |  | ||||||
|  |       cudaFree(0); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     CUdevice device; | ||||||
|  |     device_id_ = device_id; | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); | ||||||
|  |  | ||||||
|  |     // Get device resources | ||||||
|  |     CUdevResource device_resource; | ||||||
|  |     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( | ||||||
|  |         device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); | ||||||
|  |  | ||||||
|  |     // Split resources | ||||||
|  |     std::vector<CUdevResource> result(1); | ||||||
|  |     auto result_data = result.data(); | ||||||
|  |     unsigned int nb_groups = 1; | ||||||
|  |     CUdevResource remaining; | ||||||
|  |  | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( | ||||||
|  |             result_data, | ||||||
|  |             &nb_groups, | ||||||
|  |             &device_resource, | ||||||
|  |             &remaining, | ||||||
|  |             0, // default flags | ||||||
|  |             num_sms)); | ||||||
|  |  | ||||||
|  |     TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); | ||||||
|  |  | ||||||
|  |     // Generate resource descriptor | ||||||
|  |     CUdevResourceDesc desc; | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( | ||||||
|  |             &desc, result_data, 1)); | ||||||
|  |  | ||||||
|  |     // Create green context | ||||||
|  |     // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: | ||||||
|  |     // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html | ||||||
|  |     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( | ||||||
|  |         &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||||||
|  |  | ||||||
|  |     // Convert to regular context | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); | ||||||
|  |     TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   std::unique_ptr<GreenContext> GreenContext::create( | ||||||
|  |       uint32_t num_sms, | ||||||
|  |       std::optional<uint32_t> device_id) { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     if (!device_id.has_value()) { | ||||||
|  |       device_id = at::cuda::current_device(); | ||||||
|  |     } | ||||||
|  |     return std::make_unique<GreenContext>(device_id.value(), num_sms); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Implement move operations | ||||||
|  |   GreenContext::GreenContext(GreenContext&& other) noexcept { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     device_id_ = std::exchange(other.device_id_, -1); | ||||||
|  |     green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||||
|  |     context_ = std::exchange(other.context_, nullptr); | ||||||
|  |     parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   GreenContext& GreenContext::operator=(GreenContext&& other) noexcept { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     if (this != &other) { | ||||||
|  |       // Clean up current resources | ||||||
|  |       if (green_ctx_) { | ||||||
|  |         CUcontext current = nullptr; | ||||||
|  |         C10_CUDA_DRIVER_CHECK( | ||||||
|  |             c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||||
|  |         if (current == context_) { | ||||||
|  |           TORCH_CHECK( | ||||||
|  |               false, | ||||||
|  |               "attempting to overwrite current green ctx " | ||||||
|  |               "when it is active!"); | ||||||
|  |         } | ||||||
|  |         C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Take ownership of other's resources | ||||||
|  |       device_id_ = std::exchange(other.device_id_, -1); | ||||||
|  |       green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||||
|  |       context_ = std::exchange(other.context_, nullptr); | ||||||
|  |       parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||||
|  |     } | ||||||
|  |     return *this; | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   GreenContext::~GreenContext() noexcept { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     // Guard for the moved-from state, where green_ctx_ is null. | ||||||
|  |     if (green_ctx_) { | ||||||
|  |       C10_CUDA_DRIVER_CHECK( | ||||||
|  |           c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Get the underlying CUDA context | ||||||
|  |   CUcontext GreenContext::getContext() const { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     return context_; | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Get the underlying green context | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |   CUgreenCtx GreenContext::getGreenContext() const { | ||||||
|  |     return green_ctx_; | ||||||
|  |   } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |   // Make this context current | ||||||
|  |   void GreenContext::setContext() { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     auto current_stream = c10::cuda::getCurrentCUDAStream(); | ||||||
|  |     parent_stream_ = current_stream.stream(); | ||||||
|  |  | ||||||
|  |     at::cuda::CUDAEvent ev; | ||||||
|  |     ev.record(current_stream); | ||||||
|  |  | ||||||
|  |     CUcontext current = nullptr; | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||||
|  |     if (!current) { | ||||||
|  |       C10_CUDA_DRIVER_CHECK( | ||||||
|  |           c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_)); | ||||||
|  |     } else { | ||||||
|  |       C10_CUDA_DRIVER_CHECK( | ||||||
|  |           c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_)); | ||||||
|  |     } | ||||||
|  |     // currently hardcodes the new green context to use the default stream | ||||||
|  |     // TODO(eqy): consider creating a new stream if e.g., it allows interop | ||||||
|  |     // with CUDA Graph captures etc. | ||||||
|  |     auto default_stream = c10::cuda::getDefaultCUDAStream(); | ||||||
|  |     ev.block(default_stream); | ||||||
|  |     c10::cuda::setCurrentCUDAStream(default_stream); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void GreenContext::popContext() { | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |     // see above note about stream being hardcoded to the default stream | ||||||
|  |     at::cuda::CUDAEvent ev; | ||||||
|  |     ev.record(c10::cuda::getCurrentCUDAStream()); | ||||||
|  |     CUcontext popped; | ||||||
|  |     C10_CUDA_DRIVER_CHECK( | ||||||
|  |         c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped)); | ||||||
|  |     TORCH_INTERNAL_ASSERT( | ||||||
|  |         popped == context_, "expected popped context to be the current ctx"); | ||||||
|  |     ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_)); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  | } // namespace at::cuda | ||||||
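A hypothetical usage sketch for the API above (illustrative only; assumes a CUDA 12.8+ driver and that 16 SMs is a valid partition size on the device):

  #include <ATen/cuda/CUDAGreenContext.h>

  void run_on_sm_partition() {
    // Carve out 16 SMs on the current device.
    auto gc = at::cuda::GreenContext::create(/*num_sms=*/16, /*device_id=*/std::nullopt);
    gc->setContext();   // switch to the green context (default stream, per the note above)
    // ... launch kernels here; they are restricted to the partitioned SMs ...
    gc->popContext();   // restore the parent context and parent stream
  }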
							
								
								
									
aten/src/ATen/cuda/CUDAGreenContext.h (new file, +53 lines)
@ -0,0 +1,53 @@
|  | #pragma once | ||||||
|  | #include <ATen/cuda/CUDAEvent.h> | ||||||
|  |  | ||||||
|  | #if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) | ||||||
|  | #include <c10/cuda/driver_api.h> | ||||||
|  | #include <cuda.h> | ||||||
|  | #include <memory> | ||||||
|  | #include <stdexcept> | ||||||
|  | #include <vector> | ||||||
|  | #define CUDA_HAS_GREEN_CONTEXT 1 | ||||||
|  | #else | ||||||
|  | #define CUDA_HAS_GREEN_CONTEXT 0 | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | namespace at::cuda { | ||||||
|  |  | ||||||
|  | class TORCH_CUDA_CPP_API GreenContext { | ||||||
|  |  public: | ||||||
|  |   GreenContext(uint32_t device_id, uint32_t num_sms); | ||||||
|  |  | ||||||
|  |   static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id); | ||||||
|  |  | ||||||
|  |   // Delete copy constructor and assignment | ||||||
|  |   GreenContext(const GreenContext&) = delete; | ||||||
|  |   GreenContext& operator=(const GreenContext&) = delete; | ||||||
|  |  | ||||||
|  |   // Implement move operations | ||||||
|  |   GreenContext(GreenContext&& other) noexcept; | ||||||
|  |   GreenContext& operator=(GreenContext&& other) noexcept; | ||||||
|  |   ~GreenContext() noexcept; | ||||||
|  |  | ||||||
|  |   // Get the underlying CUDA context | ||||||
|  |   CUcontext getContext() const; | ||||||
|  |  | ||||||
|  |   // Get the underlying green context | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |   CUgreenCtx getGreenContext() const; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |   // Make this context current | ||||||
|  |   void setContext(); | ||||||
|  |  | ||||||
|  |   void popContext(); | ||||||
|  |  | ||||||
|  |  private: | ||||||
|  | #if CUDA_HAS_GREEN_CONTEXT | ||||||
|  |   int32_t device_id_ = -1; | ||||||
|  |   CUgreenCtx green_ctx_ = nullptr; | ||||||
|  |   CUcontext context_ = nullptr; | ||||||
|  |   cudaStream_t parent_stream_ = nullptr; | ||||||
|  | #endif | ||||||
|  | }; | ||||||
|  | } // namespace at::cuda | ||||||
							
								
								
									
aten/src/ATen/cuda/CUDAScaledBlas.cpp (new file, +270 lines)
@ -0,0 +1,270 @@
|  | #include <cstdint> | ||||||
|  | #include <c10/util/typeid.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #include <c10/util/SmallVector.h> | ||||||
|  | #include <c10/core/Scalar.h> | ||||||
|  | #include <c10/core/ScalarType.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||||
|  | #include <ATen/core/Tensor.h> | ||||||
|  | #include <ATen/core/NamedTensor.h> | ||||||
|  | #include <ATen/Dispatch.h> | ||||||
|  | #include <ATen/ExpandUtils.h> | ||||||
|  | #include <ATen/OpMathType.h> | ||||||
|  | #include <ATen/TensorUtils.h> | ||||||
|  | #include <ATen/cuda/CUDABlas.h> | ||||||
|  | #include <ATen/cuda/tunable/Tunable.h> | ||||||
|  | #include <ATen/cuda/tunable/TunableGemm.h> | ||||||
|  | #include <ATen/native/Resize.h> | ||||||
|  | #include <c10/util/MaybeOwned.h> | ||||||
|  | #include <ATen/native/GroupedMMUtils.h> | ||||||
|  | #include <ATen/native/cuda/RowwiseScaledMM.h> | ||||||
|  | #include <ATen/native/cuda/ScaledGroupMM.h> | ||||||
|  | #include <ATen/native/cuda/GroupMM.h> | ||||||
|  | #include <ATen/ceil_div.h> | ||||||
|  |  | ||||||
|  | #ifdef USE_FBGEMM_GENAI | ||||||
|  | #include <fbgemm_gpu/torch_ops.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
|  | #include <ATen/Functions.h> | ||||||
|  | #include <ATen/NativeFunctions.h> | ||||||
|  | #else | ||||||
|  | #include <ATen/ops/_addmm_activation_native.h> | ||||||
|  | #include <ATen/ops/_efficientzerotensor.h> | ||||||
|  | #include <ATen/ops/_scaled_mm_native.h> | ||||||
|  | #include <ATen/ops/_unsafe_view_native.h> | ||||||
|  | #include <ATen/ops/abs.h> | ||||||
|  | #include <ATen/ops/addmm_native.h> | ||||||
|  | #include <ATen/ops/addmv_native.h> | ||||||
|  | #include <ATen/ops/baddbmm_native.h> | ||||||
|  | #include <ATen/ops/bmm_native.h> | ||||||
|  | #include <ATen/ops/copy_native.h> | ||||||
|  | #include <ATen/ops/dot_native.h> | ||||||
|  | #include <ATen/ops/empty.h> | ||||||
|  | #include <ATen/ops/empty_strided.h> | ||||||
|  | #include <ATen/ops/gelu.h> | ||||||
|  | #include <ATen/ops/max.h> | ||||||
|  | #include <ATen/ops/mm_native.h> | ||||||
|  | #include <ATen/ops/mul.h> | ||||||
|  | #include <ATen/ops/relu.h> | ||||||
|  | #include <ATen/ops/ones.h> | ||||||
|  | #include <ATen/ops/scalar_tensor_native.h> | ||||||
|  | #include <ATen/ops/vdot_native.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | using at::blas::ScalingType; | ||||||
|  | using at::blas::SwizzleType; | ||||||
|  |  | ||||||
|  | namespace at::cuda::scaled { | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp8; | ||||||
|  |  * each needs a single scale: {TensorWise (float)} | ||||||
|  |  */ | ||||||
|  | bool check_tensorwise_recipe(c10::ScalarType type_a, | ||||||
|  |                              std::vector<ScalingType>& recipe_a, | ||||||
|  |                              ArrayRef<Tensor>& scales_a, | ||||||
|  |                              c10::ScalarType type_b, | ||||||
|  |                              std::vector<ScalingType>& recipe_b, | ||||||
|  |                              ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp8 | ||||||
|  |   if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scale each, {Tensorwise, float} | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |   // Need {TensorWise, float} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::TensorWise) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::TensorWise) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp8; | ||||||
|  |  * each needs a single scale: {RowWise (float)} | ||||||
|  |  */ | ||||||
|  | bool check_rowwise_recipe(c10::ScalarType type_a, | ||||||
|  |                              std::vector<ScalingType>& recipe_a, | ||||||
|  |                              ArrayRef<Tensor>& scales_a, | ||||||
|  |                              c10::ScalarType type_b, | ||||||
|  |                              std::vector<ScalingType>& recipe_b, | ||||||
|  |                              ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp8 | ||||||
|  |   if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scale each, {RowWise, float} | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {RowWise, fp32} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::RowWise) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::RowWise) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Two-level scaling, canonical NVFP4 | ||||||
|  |  * Both inputs must be fp4 | ||||||
|  |  * A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)} | ||||||
|  |  */ | ||||||
|  | bool check_nvfp4_recipe(c10::ScalarType type_a, | ||||||
|  |                         std::vector<ScalingType>& recipe_a, | ||||||
|  |                         ArrayRef<Tensor>& scales_a, | ||||||
|  |                         c10::ScalarType type_b, | ||||||
|  |                         std::vector<ScalingType>& recipe_b, | ||||||
|  |                         ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp4 | ||||||
|  |   if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 2 scales, 2 recipes for each input | ||||||
|  |   if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x16, e4m3} for scale[0], {TensorWise, fp32} for scale[1] | ||||||
|  |   if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Single-level scaling, what PyTorch currently understands | ||||||
|  |  * Both inputs must be fp4 | ||||||
|  |  * A, B need 1 scale, {Blockwise_1x16 (e4m3)} | ||||||
|  |  */ | ||||||
|  | bool check_nvfp4_recipe_single_scale | ||||||
|  |                        (c10::ScalarType type_a, | ||||||
|  |                         std::vector<ScalingType>& recipe_a, | ||||||
|  |                         ArrayRef<Tensor>& scales_a, | ||||||
|  |                         c10::ScalarType type_b, | ||||||
|  |                         std::vector<ScalingType>& recipe_b, | ||||||
|  |                         ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp4 | ||||||
|  |   if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scale, 1 recipe for each input | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x16, e4m3} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::BlockWise1x16) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::BlockWise1x16) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp8 | ||||||
|  |  * A, B must have exactly 1 scale each, e.g. A: {Blockwise_1x128 (float)}, B: {Blockwise_128x128 (float)} | ||||||
|  |  */ | ||||||
|  | bool check_deepseek_recipe(ScalingType expected_recipe_a, | ||||||
|  |                            ScalingType expected_recipe_b, | ||||||
|  |                            c10::ScalarType type_a, | ||||||
|  |                            std::vector<ScalingType>& recipe_a, | ||||||
|  |                            ArrayRef<Tensor>& scales_a, | ||||||
|  |                            c10::ScalarType type_b, | ||||||
|  |                            std::vector<ScalingType>& recipe_b, | ||||||
|  |                            ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp8 | ||||||
|  |   if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scale, 1 recipe for each input | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B | ||||||
|  |   if (recipe_a[0] != expected_recipe_a) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |   if (recipe_b[0] != expected_recipe_b) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp8 | ||||||
|  |  * A, B must have 1 scale each, {Blockwise_1x32, e8m0} | ||||||
|  |  */ | ||||||
|  | bool check_mxfp8_recipe(c10::ScalarType type_a, | ||||||
|  |                         std::vector<ScalingType>& recipe_a, | ||||||
|  |                         ArrayRef<Tensor>& scales_a, | ||||||
|  |                         c10::ScalarType type_b, | ||||||
|  |                         std::vector<ScalingType>& recipe_b, | ||||||
|  |                         ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp8 | ||||||
|  |   if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scale, 1 recipe for each input | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x32, e8m0} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp4 | ||||||
|  |  * A, B must have 1 scale each, {Blockwise_1x32, e8m0} | ||||||
|  |  */ | ||||||
|  | bool check_mxfp4_recipe(c10::ScalarType type_a, | ||||||
|  |                         std::vector<ScalingType>& recipe_a, | ||||||
|  |                         ArrayRef<Tensor>& scales_a, | ||||||
|  |                         c10::ScalarType type_b, | ||||||
|  |                         std::vector<ScalingType>& recipe_b, | ||||||
|  |                         ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp4 | ||||||
|  |   if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scale, 1 recipe for each input | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x32, e8m0} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace at::cuda::scaled | ||||||
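A minimal sketch of how a caller might chain these predicates to pick an implementation (hypothetical dispatcher; ScaledGemmImplementation comes from the matching header below):

  ScaledGemmImplementation choose_impl(
      c10::ScalarType ta, std::vector<ScalingType>& ra, ArrayRef<Tensor>& sa,
      c10::ScalarType tb, std::vector<ScalingType>& rb, ArrayRef<Tensor>& sb) {
    if (check_tensorwise_recipe(ta, ra, sa, tb, rb, sb))
      return ScaledGemmImplementation::TENSORWISE_TENSORWISE;
    if (check_rowwise_recipe(ta, ra, sa, tb, rb, sb))
      return ScaledGemmImplementation::ROWWISE_ROWWISE;
    if (check_mxfp8_recipe(ta, ra, sa, tb, rb, sb))
      return ScaledGemmImplementation::MXFP8_MXFP8;
    return ScaledGemmImplementation::NONE;
  }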
							
								
								
									
aten/src/ATen/cuda/CUDAScaledBlas.h (new file, +174 lines)
@ -0,0 +1,174 @@
|  | #include <cstdint> | ||||||
|  | #include <c10/util/typeid.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #include <c10/util/SmallVector.h> | ||||||
|  | #include <c10/core/Scalar.h> | ||||||
|  | #include <c10/core/ScalarType.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||||
|  | #include <ATen/core/Tensor.h> | ||||||
|  | #include <ATen/core/NamedTensor.h> | ||||||
|  | #include <ATen/Dispatch.h> | ||||||
|  | #include <ATen/ExpandUtils.h> | ||||||
|  | #include <ATen/OpMathType.h> | ||||||
|  | #include <ATen/TensorUtils.h> | ||||||
|  | #include <ATen/cuda/CUDABlas.h> | ||||||
|  | #include <ATen/cuda/tunable/Tunable.h> | ||||||
|  | #include <ATen/cuda/tunable/TunableGemm.h> | ||||||
|  | #include <ATen/native/Resize.h> | ||||||
|  | #include <c10/util/MaybeOwned.h> | ||||||
|  | #include <ATen/native/GroupedMMUtils.h> | ||||||
|  | #include <ATen/native/cuda/RowwiseScaledMM.h> | ||||||
|  | #include <ATen/native/cuda/ScaledGroupMM.h> | ||||||
|  | #include <ATen/native/cuda/GroupMM.h> | ||||||
|  | #include <ATen/ceil_div.h> | ||||||
|  |  | ||||||
|  | #ifdef USE_FBGEMM_GENAI | ||||||
|  | #include <fbgemm_gpu/torch_ops.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
|  | #include <ATen/Functions.h> | ||||||
|  | #include <ATen/NativeFunctions.h> | ||||||
|  | #else | ||||||
|  | #include <ATen/ops/_addmm_activation_native.h> | ||||||
|  | #include <ATen/ops/_efficientzerotensor.h> | ||||||
|  | #include <ATen/ops/_scaled_mm_native.h> | ||||||
|  | #include <ATen/ops/_unsafe_view_native.h> | ||||||
|  | #include <ATen/ops/abs.h> | ||||||
|  | #include <ATen/ops/addmm_native.h> | ||||||
|  | #include <ATen/ops/addmv_native.h> | ||||||
|  | #include <ATen/ops/baddbmm_native.h> | ||||||
|  | #include <ATen/ops/bmm_native.h> | ||||||
|  | #include <ATen/ops/copy_native.h> | ||||||
|  | #include <ATen/ops/dot_native.h> | ||||||
|  | #include <ATen/ops/empty.h> | ||||||
|  | #include <ATen/ops/empty_strided.h> | ||||||
|  | #include <ATen/ops/gelu.h> | ||||||
|  | #include <ATen/ops/max.h> | ||||||
|  | #include <ATen/ops/mm_native.h> | ||||||
|  | #include <ATen/ops/mul.h> | ||||||
|  | #include <ATen/ops/relu.h> | ||||||
|  | #include <ATen/ops/ones.h> | ||||||
|  | #include <ATen/ops/scalar_tensor_native.h> | ||||||
|  | #include <ATen/ops/vdot_native.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | using at::blas::ScalingType; | ||||||
|  | using at::blas::SwizzleType; | ||||||
|  |  | ||||||
|  | namespace at::cuda::scaled { | ||||||
|  |  | ||||||
|  | static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) { | ||||||
|  | #ifdef USE_ROCM | ||||||
|  |     static const std::vector<std::string> archs = { | ||||||
|  |         "gfx942", | ||||||
|  | #if ROCM_VERSION >= 60300 | ||||||
|  |         "gfx1200", "gfx1201", | ||||||
|  | #endif | ||||||
|  | #if ROCM_VERSION >= 60500 | ||||||
|  |         "gfx950" | ||||||
|  | #endif | ||||||
|  |     }; | ||||||
|  |     return at::detail::getCUDAHooks().isGPUArch(archs); | ||||||
|  | #else | ||||||
|  |     auto dprops = at::cuda::getCurrentDeviceProperties(); | ||||||
|  |  | ||||||
|  |     if (sm90_only || sm100_only) { | ||||||
|  |       return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10); | ||||||
|  |     } else { | ||||||
|  |       return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  | } | ||||||
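An illustrative guard at a hypothetical call site:

  // Restrict a scaled-GEMM path to SM 9.0 (Hopper) devices.
  TORCH_CHECK(at::cuda::scaled::_scaled_mm_allowed_device(/*sm90_only=*/true),
              "this scaled-mm recipe requires an SM 9.0 GPU");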
|  |  | ||||||
|  | #ifdef USE_ROCM | ||||||
|  | static bool _scaled_mm_is_fnuz() { | ||||||
|  |     return at::detail::getCUDAHooks().isGPUArch({"gfx942"}); | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  | /** | ||||||
|  |  * Track concrete implementations available | ||||||
|  |  */ | ||||||
|  | enum class ScaledGemmImplementation { | ||||||
|  |   NONE = 0, | ||||||
|  |   TENSORWISE_TENSORWISE = 1, | ||||||
|  |   ROWWISE_ROWWISE = 2, | ||||||
|  |   BLOCK_128x128_1x128 = 3, | ||||||
|  |   BLOCK_1x128_128x128 = 4, | ||||||
|  |   BLOCK_1x128_1x128 = 5, | ||||||
|  |   MXFP8_MXFP8 = 6, | ||||||
|  |   NVFP4_NVFP4 = 7, | ||||||
|  |   NVFP4_NVFP4_SINGLE_SCALE = 8, | ||||||
|  |   MXFP4_MXFP4 = 9, | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Convert enum values passed from Python as ints back into the | ||||||
|  |  * strongly-typed enum | ||||||
|  |  */ | ||||||
|  | template <class EnumType, class ArrayType> | ||||||
|  | std::vector<EnumType> convert_int_to_enum(ArrayType& v) { | ||||||
|  |   std::vector<EnumType> converted; | ||||||
|  |   converted.reserve(v.size()); | ||||||
|  |  | ||||||
|  |   for (auto vi : v) { | ||||||
|  |     converted.push_back(static_cast<EnumType>(vi)); | ||||||
|  |   } | ||||||
|  |   return converted; | ||||||
|  | } | ||||||
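For example (hypothetical input; the raw ints typically arrive from Python as an integer list):

  std::vector<int64_t> raw = {static_cast<int64_t>(ScalingType::RowWise)};
  auto recipes = convert_int_to_enum<ScalingType>(raw);  // -> std::vector<ScalingType>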
|  |  | ||||||
|  | bool check_tensorwise_recipe(c10::ScalarType, | ||||||
|  |                              std::vector<ScalingType>&, | ||||||
|  |                              ArrayRef<Tensor>&, | ||||||
|  |                              c10::ScalarType, | ||||||
|  |                              std::vector<ScalingType>&, | ||||||
|  |                              ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  |  | ||||||
|  | bool check_rowwise_recipe(c10::ScalarType, | ||||||
|  |                              std::vector<ScalingType>&, | ||||||
|  |                              ArrayRef<Tensor>&, | ||||||
|  |                              c10::ScalarType, | ||||||
|  |                              std::vector<ScalingType>&, | ||||||
|  |                              ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_nvfp4_recipe(c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&, | ||||||
|  |                         c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_nvfp4_recipe_single_scale | ||||||
|  |                        (c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&, | ||||||
|  |                         c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_deepseek_recipe(ScalingType, | ||||||
|  |                            ScalingType, | ||||||
|  |                            c10::ScalarType, | ||||||
|  |                            std::vector<ScalingType>&, | ||||||
|  |                            ArrayRef<Tensor>&, | ||||||
|  |                            c10::ScalarType, | ||||||
|  |                            std::vector<ScalingType>&, | ||||||
|  |                            ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_mxfp8_recipe(c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&, | ||||||
|  |                         c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | bool check_mxfp4_recipe(c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&, | ||||||
|  |                         c10::ScalarType, | ||||||
|  |                         std::vector<ScalingType>&, | ||||||
|  |                         ArrayRef<Tensor>&); | ||||||
|  |  | ||||||
|  | } // namespace at::cuda::scaled | ||||||
| @ -70,11 +70,7 @@ | |||||||
| #define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max() | #define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max() | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM) | #if defined(USE_ROCM) | ||||||
|  |  | ||||||
| #if !defined(USE_ROCM) |  | ||||||
| namespace at_cuda_detail { |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16 | // backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16 | ||||||
|  |  | ||||||
| @ -96,10 +92,6 @@ template <> | |||||||
| struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>: | struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>: | ||||||
|        ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {}; |        ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {}; | ||||||
|  |  | ||||||
| #if !defined(USE_ROCM) |  | ||||||
| } // namespace at_cuda_detail |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #if !defined(USE_ROCM) | #if !defined(USE_ROCM) | ||||||
| @ -121,7 +113,7 @@ struct cuda_type<c10::Half> { | |||||||
|   using type = __half; |   using type = __half; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| #if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16() | #if !defined(USE_ROCM) | ||||||
|  |  | ||||||
| template<> | template<> | ||||||
| struct cuda_type<c10::BFloat16> { | struct cuda_type<c10::BFloat16> { | ||||||
| @ -203,36 +195,6 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera | |||||||
|   *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b)); |   *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b)); | ||||||
| } | } | ||||||
|  |  | ||||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() |  | ||||||
| template<typename ValueT, typename InputIteratorT> |  | ||||||
| struct chained_iterator { |  | ||||||
|   using iterator_category = std::random_access_iterator_tag; |  | ||||||
|   using difference_type   = std::ptrdiff_t; |  | ||||||
|   using value_type        = ValueT; |  | ||||||
|   using pointer           = ValueT*; |  | ||||||
|   using reference         = ValueT&; |  | ||||||
|  |  | ||||||
|   InputIteratorT iter; |  | ||||||
|   ValueT *first; |  | ||||||
|   difference_type offset = 0; |  | ||||||
|  |  | ||||||
|   __device__ ValueT operator[](difference_type i) { |  | ||||||
|     i +=  offset; |  | ||||||
|     if (i == 0) { |  | ||||||
|       return *first; |  | ||||||
|     } else { |  | ||||||
|       return ValueT(iter[i - 1]); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|   __device__ chained_iterator operator+(difference_type i) { |  | ||||||
|     return chained_iterator{iter, first, i}; |  | ||||||
|   } |  | ||||||
|   __device__ ValueT operator*() { |  | ||||||
|     return (*this)[0]; |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, | // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, | ||||||
| // so split at int_max/2 | // so split at int_max/2 | ||||||
| constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30 | constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30 | ||||||
| @ -277,25 +239,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | |||||||
|         first_elem_ptr, |         first_elem_ptr, | ||||||
|         scan_op); |         scan_op); | ||||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); |     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() |  | ||||||
|     using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>; |  | ||||||
|     using tuple = typename ArgIndexInputIterator::value_type; |  | ||||||
|     auto input_iter_transform = [=] __device__ (const tuple &x)->input_t  { |  | ||||||
|       if (x.key == 0) { |  | ||||||
|         return *first_elem_ptr; |  | ||||||
|       } else { |  | ||||||
|         return x.value; |  | ||||||
|       } |  | ||||||
|     }; |  | ||||||
|     auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)( |  | ||||||
|       ArgIndexInputIterator(input + i), input_iter_transform); |  | ||||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, |  | ||||||
|         input_, |  | ||||||
|         output + i, |  | ||||||
|         scan_op, |  | ||||||
|         size_cub, |  | ||||||
|         at::cuda::getCurrentCUDAStream()); |  | ||||||
| #else |  | ||||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, |     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, | ||||||
|         input + i + 1, |         input + i + 1, | ||||||
|         output + i, |         output + i, | ||||||
| @ -303,7 +246,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | |||||||
|         ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr), |         ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr), | ||||||
|         size_cub, |         size_cub, | ||||||
|         at::cuda::getCurrentCUDAStream()); |         at::cuda::getCurrentCUDAStream()); | ||||||
| #endif |  | ||||||
|   } |   } | ||||||
| #endif | #endif | ||||||
| } | } | ||||||
| @ -555,16 +497,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | |||||||
|         first_elem_ptr, |         first_elem_ptr, | ||||||
|         scan_op); |         scan_op); | ||||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); |     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() |  | ||||||
|     auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{ |  | ||||||
|       input + i, first_elem_ptr}; |  | ||||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, |  | ||||||
|         input_, |  | ||||||
|         output + i, |  | ||||||
|         scan_op, |  | ||||||
|         size_cub, |  | ||||||
|         at::cuda::getCurrentCUDAStream()); |  | ||||||
| #else |  | ||||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, |     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, | ||||||
|         input + i, |         input + i, | ||||||
|         output + i, |         output + i, | ||||||
| @ -572,7 +504,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | |||||||
|         ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr), |         ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr), | ||||||
|         size_cub, |         size_cub, | ||||||
|         at::cuda::getCurrentCUDAStream()); |         at::cuda::getCurrentCUDAStream()); | ||||||
| #endif |  | ||||||
|   } |   } | ||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  | |||||||
| @ -10,14 +10,6 @@ | |||||||
| #define CUB_VERSION 200001 | #define CUB_VERSION 200001 | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| // cub sort support for __nv_bfloat16 is added to cub 1.13 in: |  | ||||||
| // https://github.com/NVIDIA/cub/pull/306 |  | ||||||
| #if CUB_VERSION >= 101300 |  | ||||||
| #define CUB_SUPPORTS_NV_BFLOAT16() true |  | ||||||
| #else |  | ||||||
| #define CUB_SUPPORTS_NV_BFLOAT16() false |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: | // cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: | ||||||
| // https://github.com/NVIDIA/cub/pull/326 | // https://github.com/NVIDIA/cub/pull/326 | ||||||
| // CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake | // CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake | ||||||
| @ -28,14 +20,6 @@ | |||||||
| #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false | #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| // cub support for cub::FutureValue is added to cub 1.15 in: |  | ||||||
| // https://github.com/NVIDIA/cub/pull/305 |  | ||||||
| #if CUB_VERSION >= 101500 |  | ||||||
| #define CUB_SUPPORTS_FUTURE_VALUE() true |  | ||||||
| #else |  | ||||||
| #define CUB_SUPPORTS_FUTURE_VALUE() false |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // There were many bc-breaking changes in major version release of CCCL v3.0.0 | // There were many bc-breaking changes in major version release of CCCL v3.0.0 | ||||||
| // Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html | // Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html | ||||||
| #if CUB_VERSION >= 200800 | #if CUB_VERSION >= 200800 | ||||||
|  | |||||||
							
								
								
									
aten/src/ATen/detail/XLAHooksInterface.cpp (new file, +23 lines)
@ -0,0 +1,23 @@
|  | #include <ATen/detail/XLAHooksInterface.h> | ||||||
|  |  | ||||||
|  | namespace at { | ||||||
|  | namespace detail { | ||||||
|  |  | ||||||
|  | const XLAHooksInterface& getXLAHooks() { | ||||||
|  |   auto create_impl = [] { | ||||||
|  |     // Create XLA hooks using the registry | ||||||
|  |     auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{}); | ||||||
|  |     if (hooks) { | ||||||
|  |       return hooks; | ||||||
|  |     } | ||||||
|  |     // If hooks creation fails, fall back to default implementation | ||||||
|  |     return std::make_unique<XLAHooksInterface>(); | ||||||
|  |   }; | ||||||
|  |   static auto hooks = create_impl(); | ||||||
|  |   return *hooks; | ||||||
|  | } | ||||||
|  | } // namespace detail | ||||||
|  |  | ||||||
|  | C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs) | ||||||
|  |  | ||||||
|  | } // namespace at | ||||||
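An illustrative call site (hypothetical): the hooks are safe to query even when torch_xla is not loaded, because the fallback interface answers conservatively.

  if (at::detail::getXLAHooks().hasXLA()) {
    // torch_xla is loaded; take the XLA-specific path
  }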
							
								
								
									
aten/src/ATen/detail/XLAHooksInterface.h (new file, +79 lines)
@ -0,0 +1,79 @@
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <c10/core/Device.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #include <c10/util/Registry.h> | ||||||
|  |  | ||||||
|  | #include <ATen/detail/AcceleratorHooksInterface.h> | ||||||
|  |  | ||||||
|  | C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") | ||||||
|  |  | ||||||
|  | namespace at { | ||||||
|  |  | ||||||
|  | constexpr const char* XLA_HELP = | ||||||
|  |   "This error has occurred because you are trying " | ||||||
|  |   "to use some XLA functionality, but the XLA library has not been " | ||||||
|  |   "loaded by the dynamic linker. You must load xla libraries by `import torch_xla`"; | ||||||
|  |  | ||||||
|  | struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface { | ||||||
|  |   ~XLAHooksInterface() override = default; | ||||||
|  |  | ||||||
|  |   void init() const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   virtual bool hasXLA() const { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   virtual std::string showConfig() const { | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         false, | ||||||
|  |         "Cannot query detailed XLA version without torch_xla library. ", | ||||||
|  |         XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   const Generator& getDefaultGenerator( | ||||||
|  |       [[maybe_unused]] DeviceIndex device_index = -1) const override { | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Generator getNewGenerator( | ||||||
|  |       [[maybe_unused]] DeviceIndex device_index = -1) const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   DeviceIndex getCurrentDevice() const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Device getDeviceFromPtr(void* /*data*/) const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   Allocator* getPinnedMemoryAllocator() const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   bool isPinnedPtr(const void* data) const override { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   bool hasPrimaryContext(DeviceIndex device_index) const override { | ||||||
|  |     TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | struct TORCH_API XLAHooksArgs {}; | ||||||
|  |  | ||||||
|  | TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs); | ||||||
|  | #define REGISTER_XLA_HOOKS(clsname) \ | ||||||
|  |   C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname) | ||||||
|  |  | ||||||
|  | namespace detail { | ||||||
|  | TORCH_API const XLAHooksInterface& getXLAHooks(); | ||||||
|  | } // namespace detail | ||||||
|  | } // namespace at | ||||||
|  | C10_DIAGNOSTIC_POP() | ||||||
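On the torch_xla side, registration would look roughly like this (hypothetical sketch; the concrete class lives in torch_xla and is looked up by name in getXLAHooks()):

  struct XLAHooks : public at::XLAHooksInterface {
    explicit XLAHooks(const at::XLAHooksArgs&) {}
    bool hasXLA() const override { return true; }
  };
  // Registration (via REGISTER_XLA_HOOKS or C10_REGISTER_CLASS) must use the same
  // key string that getXLAHooks() passes to Create(), i.e. "torch_xla::detail::XLAHooks".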
| @ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto | |||||||
|               "pixel_shuffle expects a positive upscale_factor, but got ", |               "pixel_shuffle expects a positive upscale_factor, but got ", | ||||||
|               upscale_factor); |               upscale_factor); | ||||||
|   int64_t c = self.size(-3); |   int64_t c = self.size(-3); | ||||||
|  |   TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor, | ||||||
|  |         "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor); | ||||||
|   int64_t upscale_factor_squared = upscale_factor * upscale_factor; |   int64_t upscale_factor_squared = upscale_factor * upscale_factor; | ||||||
|   TORCH_CHECK(c % upscale_factor_squared == 0, |   TORCH_CHECK(c % upscale_factor_squared == 0, | ||||||
|               "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " |               "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " | ||||||
|  | |||||||
| @ -259,11 +259,20 @@ inline void winograd_f2k3_input_transform_inplace__rvv( | |||||||
|   const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4); |   const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4); | ||||||
|   const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4); |   const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4); | ||||||
|   const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4); |   const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4); | ||||||
|  |   /* GCC 14.2 (RISC-V RVV) ICE workaround: | ||||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0); |    * Avoid single-statement read-modify-write on MEM_REF like: | ||||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1); |    *   *input_tile_val = | ||||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2); |    *     __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val); | ||||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3); |    * This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin) | ||||||
|  |    * with -march=rv64gcv. Use a temporary then write back. | ||||||
|  |    * Do NOT refactor into the single-statement form. Clang is unaffected. | ||||||
|  |    */ | ||||||
|  |   vfloat32m1x4_t tmp_input_tile_val = *input_tile_val; | ||||||
|  |   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0); | ||||||
|  |   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1); | ||||||
|  |   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2); | ||||||
|  |   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3); | ||||||
|  |   *input_tile_val = tmp_input_tile_val; | ||||||
| } | } | ||||||
|  |  | ||||||
| inline void winograd_f2k3_output_transform_inplace__rvv( | inline void winograd_f2k3_output_transform_inplace__rvv( | ||||||
| @ -277,9 +286,15 @@ inline void winograd_f2k3_output_transform_inplace__rvv( | |||||||
|   const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4); |   const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4); | ||||||
|   const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4); |   const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4); | ||||||
|   const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4); |   const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4); | ||||||
|  |   /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above. | ||||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0); |    * Keep the temporary + write-back pattern to avoid ICE. | ||||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1); |    * Do NOT rewrite into: | ||||||
|  |    *   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val); | ||||||
|  |    */ | ||||||
|  |   vfloat32m1x4_t tmp_output_tile_val = *input_tile_val; | ||||||
|  |   tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0); | ||||||
|  |   tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1); | ||||||
|  |   *input_tile_val = tmp_output_tile_val; | ||||||
| } | } | ||||||
|  |  | ||||||
| inline vfloat32m1_t | inline vfloat32m1_t | ||||||
| @ -300,11 +315,17 @@ inline void winograd_f2k3_kernel_transform__rvv( | |||||||
|   const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4); |   const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4); | ||||||
|   const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4); |   const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4); | ||||||
|   vfloat32m1_t half_g0_plus_g2 =  __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4); |   vfloat32m1_t half_g0_plus_g2 =  __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4); | ||||||
|  |   /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above. | ||||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0); |    * Keep the temporary + write-back pattern to avoid ICE. | ||||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1)); |    * Do NOT rewrite into: | ||||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1)); |    *   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val); | ||||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2); |    */ | ||||||
|  |   vfloat32m1x4_t tmp_transform = *transform; | ||||||
|  |   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0); | ||||||
|  |   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1)); | ||||||
|  |   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1)); | ||||||
|  |   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2); | ||||||
|  |   *transform = tmp_transform; | ||||||
| } | } | ||||||
|  |  | ||||||
| inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) { | inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) { | ||||||
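The workaround pattern in isolation, as a minimal sketch: this assumes GCC 14.2 targeting -march=rv64gcv and uses only intrinsics that already appear in the patch; it illustrates the shape of the fix and is not part of the change itself.

    #include <riscv_vector.h>

    void set_lane0(vfloat32m1x4_t* tuple, vfloat32m1_t v) {
      // ICE form (single-statement read-modify-write through the pointer), do not use:
      //   *tuple = __riscv_vset_v_f32m1_f32m1x4(*tuple, 0, v);
      vfloat32m1x4_t tmp = *tuple;                    // load the whole tuple once
      tmp = __riscv_vset_v_f32m1_f32m1x4(tmp, 0, v);  // modify the local copy
      *tuple = tmp;                                   // single store back
    }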
[diff suppressed: file too large to display]

aten/src/ATen/native/cuda/GroupedBlas.cpp (new file, +574 lines)
| @@ -0,0 +1,574 @@ | |||||||
|  | #include <cstdint> | ||||||
|  | #include <c10/util/typeid.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #include <c10/util/SmallVector.h> | ||||||
|  | #include <c10/core/Scalar.h> | ||||||
|  | #include <c10/core/ScalarType.h> | ||||||
|  | #include <c10/util/Exception.h> | ||||||
|  | #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||||
|  | #include <ATen/core/Tensor.h> | ||||||
|  | #include <ATen/core/NamedTensor.h> | ||||||
|  | #include <ATen/Dispatch.h> | ||||||
|  | #include <ATen/ExpandUtils.h> | ||||||
|  | #include <ATen/OpMathType.h> | ||||||
|  | #include <ATen/TensorUtils.h> | ||||||
|  | #include <ATen/cuda/CUDABlas.h> | ||||||
|  | #include <ATen/cuda/CUDAScaledBlas.h> | ||||||
|  | #include <ATen/cuda/tunable/Tunable.h> | ||||||
|  | #include <ATen/cuda/tunable/TunableGemm.h> | ||||||
|  | #include <ATen/native/Resize.h> | ||||||
|  | #include <c10/util/MaybeOwned.h> | ||||||
|  | #include <ATen/native/GroupedMMUtils.h> | ||||||
|  | #include <ATen/native/cuda/RowwiseScaledMM.h> | ||||||
|  | #include <ATen/native/cuda/ScaledGroupMM.h> | ||||||
|  | #include <ATen/native/cuda/GroupMM.h> | ||||||
|  | #include <ATen/ceil_div.h> | ||||||
|  |  | ||||||
|  | #ifdef USE_FBGEMM_GENAI | ||||||
|  | #include <fbgemm_gpu/torch_ops.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
|  | #include <ATen/Functions.h> | ||||||
|  | #include <ATen/NativeFunctions.h> | ||||||
|  | #else | ||||||
|  | #include <ATen/ops/_addmm_activation_native.h> | ||||||
|  | #include <ATen/ops/_efficientzerotensor.h> | ||||||
|  | #include <ATen/ops/_scaled_mm_native.h> | ||||||
|  | #include <ATen/ops/_unsafe_view_native.h> | ||||||
|  | #include <ATen/ops/abs.h> | ||||||
|  | #include <ATen/ops/addmm_native.h> | ||||||
|  | #include <ATen/ops/addmv_native.h> | ||||||
|  | #include <ATen/ops/baddbmm_native.h> | ||||||
|  | #include <ATen/ops/bmm_native.h> | ||||||
|  | #include <ATen/ops/copy_native.h> | ||||||
|  | #include <ATen/ops/dot_native.h> | ||||||
|  | #include <ATen/ops/empty.h> | ||||||
|  | #include <ATen/ops/empty_strided.h> | ||||||
|  | #include <ATen/ops/gelu.h> | ||||||
|  | #include <ATen/ops/max.h> | ||||||
|  | #include <ATen/ops/mm_native.h> | ||||||
|  | #include <ATen/ops/mul.h> | ||||||
|  | #include <ATen/ops/relu.h> | ||||||
|  | #include <ATen/ops/ones.h> | ||||||
|  | #include <ATen/ops/scalar_tensor_native.h> | ||||||
|  | #include <ATen/ops/vdot_native.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | using at::blas::ScalingType; | ||||||
|  | using at::blas::SwizzleType; | ||||||
|  |  | ||||||
|  | namespace scaled_blas = at::cuda::scaled; | ||||||
|  | using scaled_blas::ScaledGemmImplementation; | ||||||
|  | using scaled_blas::convert_int_to_enum; | ||||||
|  | using scaled_blas::_scaled_mm_allowed_device; | ||||||
|  |  | ||||||
|  | namespace at::native { | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
|  | // 2d-2d and 2d-3d | ||||||
|  | // scaling=MXFP8 | ||||||
|  | // CUDA-only | ||||||
|  | Tensor& | ||||||
|  | _mx8_mx8_bf16_grouped_mm_fbgemm( | ||||||
|  |         const Tensor& mat_a, | ||||||
|  |         const Tensor& mat_b, | ||||||
|  |         const Tensor& scale_a, | ||||||
|  |         const SwizzleType& swizzle_a, | ||||||
|  |         const Tensor& scale_b, | ||||||
|  |         const SwizzleType& swizzle_b, | ||||||
|  |         const std::optional<at::Tensor>& offs, | ||||||
|  |         Tensor& out) { | ||||||
|  |     const bool a_is_2d = mat_a.dim() == 2; | ||||||
|  |     const bool b_is_2d = mat_b.dim() == 2; | ||||||
|  |     bool b_is_3d = mat_b.dim() == 3; | ||||||
|  |     bool is_2d_2d = a_is_2d && b_is_2d; | ||||||
|  |     bool is_2d_3d = a_is_2d && b_is_3d; | ||||||
|  |     TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases"); | ||||||
|  |     TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets"); | ||||||
|  |     TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm"); | ||||||
|  |     // MXFP8 expects float8_e8m0fnu scales. | ||||||
|  |     TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu, | ||||||
|  |         "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors."); | ||||||
|  | #ifdef USE_ROCM | ||||||
|  |     TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE, | ||||||
|  |         "For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE"); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4, | ||||||
|  |         "For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4"); | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM) | ||||||
|  |     fbgemm_gpu::mx8mx8bf16_grouped_mm( | ||||||
|  |         mat_a, | ||||||
|  |         mat_b, | ||||||
|  |         scale_a, | ||||||
|  |         scale_b, | ||||||
|  |         offs.value(), | ||||||
|  |         out); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI"); | ||||||
|  | #endif | ||||||
|  |     return out; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // 2d-2d and 2d-3d cases | ||||||
|  | // scaling=rowwise | ||||||
|  | // CUDA-only | ||||||
|  | Tensor& | ||||||
|  | _f8_f8_bf16_rowwise_grouped_mm_cuda( | ||||||
|  |           const Tensor& mat_a, | ||||||
|  |           const Tensor& mat_b, | ||||||
|  |           const Tensor& scale_a, | ||||||
|  |           const Tensor& scale_b, | ||||||
|  |           const std::optional<Tensor>& offs, | ||||||
|  |           const std::optional<Tensor>& bias, | ||||||
|  |           const bool use_fast_accum, | ||||||
|  |           Tensor& out) { | ||||||
|  |   TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type()); | ||||||
|  |   TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_b to be Float8_e4m3 matrix got ", mat_b.scalar_type()); | ||||||
|  |  | ||||||
|  |   at::cuda::detail::f8f8bf16_grouped_mm( | ||||||
|  |       mat_a, | ||||||
|  |       mat_b, | ||||||
|  |       scale_a, | ||||||
|  |       scale_b, | ||||||
|  |       offs, | ||||||
|  |       bias, | ||||||
|  |       use_fast_accum, | ||||||
|  |       out); | ||||||
|  |     return out; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // 2d-2d and 2d-3d cases | ||||||
|  | // scaling=rowwise | ||||||
|  | // only being called for rocm | ||||||
|  | Tensor& | ||||||
|  | _f8_f8_bf16_rowwise_grouped_mm_rocm( | ||||||
|  |       const Tensor& mat_a, | ||||||
|  |       const Tensor& mat_b, | ||||||
|  |       const Tensor& scale_a, | ||||||
|  |       const Tensor& scale_b, | ||||||
|  |       const std::optional<Tensor>& offs, | ||||||
|  |       Tensor& out) { | ||||||
|  |   TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type()); | ||||||
|  |   TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_b to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type()); | ||||||
|  |  | ||||||
|  | #if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM) | ||||||
|  |   fbgemm_gpu::f8f8bf16_rowwise_grouped_mm( | ||||||
|  |       mat_a, | ||||||
|  |       // FBGEMM expects B matrix shape to be (.., N, K) | ||||||
|  |       mat_b.transpose(-2, -1), | ||||||
|  |       scale_a, | ||||||
|  |       scale_b, | ||||||
|  |       offs, | ||||||
|  |       out); | ||||||
|  | #else | ||||||
|  |   TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM") | ||||||
|  | #endif | ||||||
|  |   return out; | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda | ||||||
|  | Tensor& | ||||||
|  | _f8_f8_bf16_rowwise_grouped_mm( | ||||||
|  |       const Tensor& mat_a, | ||||||
|  |       const Tensor& mat_b, | ||||||
|  |       const Tensor& scale_a, | ||||||
|  |       const Tensor& scale_b, | ||||||
|  |       const std::optional<Tensor>& offs, | ||||||
|  |       const std::optional<Tensor>& bias, | ||||||
|  |       bool use_fast_accum, | ||||||
|  |       Tensor& out) { | ||||||
|  |   // FP8 per-tensor and per-row scaling expect fp32 scales. | ||||||
|  |   TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, | ||||||
|  |       "For grouped FP8 rowwise, both scales must be float32 tensors"); | ||||||
|  | #ifndef USE_ROCM | ||||||
|  |   return _f8_f8_bf16_rowwise_grouped_mm_cuda( | ||||||
|  |       mat_a, | ||||||
|  |       mat_b, | ||||||
|  |       scale_a, | ||||||
|  |       scale_b, | ||||||
|  |       offs, | ||||||
|  |       bias, | ||||||
|  |       use_fast_accum, | ||||||
|  |       out); | ||||||
|  | #else | ||||||
|  |   // NOTE: ignore use_fast_accum | ||||||
|  |   TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias"); | ||||||
|  |   return _f8_f8_bf16_rowwise_grouped_mm_rocm( | ||||||
|  |       mat_a, | ||||||
|  |       mat_b, | ||||||
|  |       scale_a, | ||||||
|  |       scale_b, | ||||||
|  |       offs, | ||||||
|  |       out); | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) { | ||||||
|  |   // Checks scales for 2d or 3d target tensors (`mat`). | ||||||
|  |   if (mat.dim() == 2) { | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         scale.dim() == 1, | ||||||
|  |         "scale must be a 1D tensor, but got ", | ||||||
|  |         scale.dim(), | ||||||
|  |         "D, arg ", | ||||||
|  |         arg_idx); | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx); | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         scale.size(0) == mat.size(dim) * scale_multiplier, | ||||||
|  |         "scale must have the same length as mat for arg ", | ||||||
|  |         arg_idx); | ||||||
|  |   } else { | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         scale.dim() == 2, | ||||||
|  |         "scale must be a 2D tensor, but got ", | ||||||
|  |         scale.dim(), | ||||||
|  |         "D for arg ", | ||||||
|  |         arg_idx); | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         scale.stride(1) == 1, | ||||||
|  |         "scale must be contiguous in the last dimension for arg ", | ||||||
|  |         arg_idx); | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         scale.size(0) == mat.size(0), | ||||||
|  |         "scale must have the same batch dimension as mat for arg ", | ||||||
|  |         arg_idx); | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         scale.size(1) == mat.size(1 + dim), | ||||||
|  |         "scale must have the same first dimension as mat for arg ", | ||||||
|  |         arg_idx); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
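A worked instance of the 2d rowwise check above (assumed sizes, for illustration only):

    mat_a: (M, K_total) = (64, 8192), G = offs->size(0) = 4   =>  scale_multiplier = 4
    required: scale_a is a contiguous 1-D float32 tensor with
      scale_a.size(0) == mat_a.size(0) * scale_multiplier == 64 * 4 == 256
    i.e. one rowwise scale per (group, row) pair in the 2d-2d case.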
|  | void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) { | ||||||
|  |   // Checks scales for 2d or 3d target tensors (`mat`). | ||||||
|  |   if (mat.dim() == 2) { | ||||||
|  |     // For MXFP8, 2d tensors have variable size groups represented as subtensors, | ||||||
|  |     // that are converted to blocked padded format individually, | ||||||
|  |     // so we can't check the scale sizes without doing a d2h sync to get the group sizes here. | ||||||
|  |     TORCH_CHECK( | ||||||
|  |       scale.dim() == mat.dim(), | ||||||
|  |       "for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx); | ||||||
|  |  | ||||||
|  |     // LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4)) | ||||||
|  |     // RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4)) | ||||||
|  |     //   * weight is transposed prior to the call, scale stays non-transposed. | ||||||
|  |     bool LHS = arg_idx == 0; | ||||||
|  |     int scale_dim_to_check = 0; | ||||||
|  |     int mat_dim_to_check = LHS ? 0 : 1; | ||||||
|  |     TORCH_CHECK( | ||||||
|  |         scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check), | ||||||
|  |         "for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ", | ||||||
|  |         "must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")"); | ||||||
|  |   } else { | ||||||
|  |     // For MXFP8, 3d tensors have static group sizes (stack of 2d tensors), | ||||||
|  |     // so we can check the exact expected scale sizes here without a d2h sync. | ||||||
|  |     auto round_up = [](auto x, auto y) { | ||||||
|  |         return ((x + y - 1) / y) * y; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     // TODO: this is for 3d tensor in 2d-3d case specifically. | ||||||
|  |     // We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them. | ||||||
|  |     int64_t G = mat.size(0); | ||||||
|  |     int64_t K = mat.size(1); | ||||||
|  |     int64_t N = mat.size(2); | ||||||
|  |     int64_t blocked_scale_K = round_up(K/32, 4); | ||||||
|  |     int64_t blocked_scale_N = round_up(N, 128); | ||||||
|  |  | ||||||
|  |     // fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N). | ||||||
|  |     TORCH_CHECK( | ||||||
|  |       scale.dim() == mat.dim() - 1, | ||||||
|  |       "for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx | ||||||
|  |     ); | ||||||
|  |     TORCH_CHECK( | ||||||
|  |       scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N, | ||||||
|  |       "for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
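To make the blocked-scale shape arithmetic above concrete, a worked example with assumed sizes (illustrative numbers, not from the patch):

    #include <cstdint>

    constexpr int64_t round_up(int64_t x, int64_t y) { return ((x + y - 1) / y) * y; }

    // Assumed 3d RHS of shape (G, K, N) = (8, 4096, 1024):
    static_assert(round_up(4096 / 32, 4) == 128);  // blocked_scale_K
    static_assert(round_up(1024, 128) == 1024);    // blocked_scale_N
    // _check_scales_mxfp8 therefore requires scale.sizes() == (8, 128 * 1024) = (8, 131072).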
|  | void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) { | ||||||
|  |   bool using_fp8_rowwise = scale.scalar_type() == kFloat; | ||||||
|  |   bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu; | ||||||
|  |   if (using_fp8_rowwise) { | ||||||
|  |     _check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier); | ||||||
|  |   } else if (using_mxfp8) { | ||||||
|  |     _check_scales_mxfp8(mat, scale, dim, arg_idx); | ||||||
|  |   } else { | ||||||
|  |     TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype()); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace | ||||||
|  |  | ||||||
|  | Tensor | ||||||
|  | _scaled_grouped_mm_cuda( | ||||||
|  |         const Tensor& mat_a, | ||||||
|  |         const Tensor& mat_b, | ||||||
|  |         const Tensor& scale_a, | ||||||
|  |         const Tensor& scale_b, | ||||||
|  |         const std::optional<at::Tensor>& offs, | ||||||
|  |         const std::optional<at::Tensor>& bias, | ||||||
|  |         const std::optional<at::Tensor>& scale_result, | ||||||
|  |         std::optional<c10::ScalarType> out_dtype, | ||||||
|  |         bool use_fast_accum) { | ||||||
|  |   bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true); | ||||||
|  |   TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+"); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed"); | ||||||
|  |   TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed"); | ||||||
|  |   TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); | ||||||
|  |   TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); | ||||||
|  |   const bool a_is_2d = mat_a.dim() == 2; | ||||||
|  |   const bool b_is_2d = mat_b.dim() == 2; | ||||||
|  |  | ||||||
|  |   // NOTE(slayton): For sub-1B formats want contraction_dim argument? | ||||||
|  |   if (!a_is_2d || !b_is_2d) { | ||||||
|  |     TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match"); | ||||||
|  |   } | ||||||
|  |   TORCH_CHECK_VALUE( | ||||||
|  |     mat_a.size(-1) % 16 == 0, | ||||||
|  |     "Expected trailing dimension of mat_a to be divisible by 16 ", | ||||||
|  |     "but got mat1 shape: (", | ||||||
|  |     mat_a.sizes(), | ||||||
|  |     ")."); | ||||||
|  |   TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0, | ||||||
|  |     "Expected mat_b shape to be divisible by 16 ", | ||||||
|  |     "but got mat_b shape: (", | ||||||
|  |     mat_b.sizes(), | ||||||
|  |     ")."); | ||||||
|  |  | ||||||
|  |  | ||||||
|  |   TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet"); | ||||||
|  |   TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet"); | ||||||
|  |   TORCH_CHECK_VALUE(offs.has_value() ==  (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix"); | ||||||
|  |  | ||||||
|  |   // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present. | ||||||
|  |   //       for rowwise, no offsets implies 3d-3d and is handled by lower-level | ||||||
|  |   //       routines | ||||||
|  |   if (offs.has_value()) { | ||||||
|  |     TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D"); | ||||||
|  |     TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32"); | ||||||
|  |   } | ||||||
|  |   // FP8 per-tensor and per-row scaling expect fp32 scales. | ||||||
|  |   // MXFP8 expects float8_e8m0fnu scales. | ||||||
|  |   TORCH_CHECK_VALUE( | ||||||
|  |       (scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) || | ||||||
|  |       (scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu), | ||||||
|  |       "For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors."); | ||||||
|  |  | ||||||
|  |   const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1; | ||||||
|  |   check_scale(mat_a, scale_a, 0 ,0, scale_multiplier); | ||||||
|  |   check_scale(mat_b, scale_b, 1, 1, scale_multiplier); | ||||||
|  |  | ||||||
|  |   const auto out_dtype_ = out_dtype.value_or(kBFloat16); | ||||||
|  |   TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); | ||||||
|  |  | ||||||
|  |   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); | ||||||
|  |  | ||||||
|  | #if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM) | ||||||
|  |   // MXFP8 grouped GEMM dispatching | ||||||
|  |   bool is_mx8mx8bf16 = ( | ||||||
|  |     mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn && | ||||||
|  |     scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu | ||||||
|  |   ); | ||||||
|  | #else | ||||||
|  |   bool is_mx8mx8bf16 = false; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |   if (is_mx8mx8bf16) { | ||||||
|  |     // Note: Passing implied SwizzleType here, correctness of scale previously checked | ||||||
|  |     //       in `check_scale` call | ||||||
|  |     return _mx8_mx8_bf16_grouped_mm_fbgemm( | ||||||
|  |         mat_a, | ||||||
|  |         mat_b, | ||||||
|  |         scale_a, | ||||||
|  |         SwizzleType::SWIZZLE_32_4_4, | ||||||
|  |         scale_b, | ||||||
|  |         SwizzleType::SWIZZLE_32_4_4, | ||||||
|  |         offs.value(), | ||||||
|  |         out); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // If we're not MXFP8, then we're row-wise scaling. | ||||||
|  |   return _f8_f8_bf16_rowwise_grouped_mm( | ||||||
|  |       mat_a, | ||||||
|  |       mat_b, | ||||||
|  |       scale_a, | ||||||
|  |       scale_b, | ||||||
|  |       offs, | ||||||
|  |       bias, | ||||||
|  |       use_fast_accum, | ||||||
|  |       out); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
|  | using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>; | ||||||
|  |  | ||||||
|  | std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{ | ||||||
|  |   { "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE}, | ||||||
|  |   { "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}}; | ||||||
|  |  | ||||||
|  | } // anonymous namespace | ||||||
|  |  | ||||||
|  | Tensor | ||||||
|  | _scaled_grouped_mm_cuda_v2( | ||||||
|  |           const Tensor& mat_a, const Tensor& mat_b, | ||||||
|  |           ArrayRef<Tensor> scale_a, | ||||||
|  |           IntArrayRef scale_recipe_a, | ||||||
|  |           IntArrayRef swizzle_a, | ||||||
|  |           ArrayRef<Tensor> scale_b, | ||||||
|  |           IntArrayRef scale_recipe_b, | ||||||
|  |           IntArrayRef swizzle_b, | ||||||
|  |           const std::optional<Tensor>& offs, | ||||||
|  |           const std::optional<Tensor>& bias, | ||||||
|  |           const std::optional<c10::ScalarType> out_dtype, | ||||||
|  |           IntArrayRef contraction_dim, | ||||||
|  |           bool use_fast_accum) { | ||||||
|  |   bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true); | ||||||
|  |   TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+"); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed"); | ||||||
|  |   TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed"); | ||||||
|  |   TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); | ||||||
|  |   TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); | ||||||
|  |   const bool a_is_2d = mat_a.dim() == 2; | ||||||
|  |   const bool b_is_2d = mat_b.dim() == 2; | ||||||
|  |  | ||||||
|  |   // NOTE(slayton): For sub-1B formats want contraction_dim argument? | ||||||
|  |   if (!a_is_2d || !b_is_2d) { | ||||||
|  |     if (contraction_dim.size() > 0) { | ||||||
|  |       const int dim_a = contraction_dim[0], dim_b = contraction_dim[1]; | ||||||
|  |       TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b), | ||||||
|  |           "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ", | ||||||
|  |           mat_b.size(dim_b)); | ||||||
|  |       // Note: only (-1, -2) is currently supported | ||||||
|  |       TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only"); | ||||||
|  |     } else { | ||||||
|  |       TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match"); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   TORCH_CHECK_VALUE( | ||||||
|  |     mat_a.size(-1) % 16 == 0, | ||||||
|  |     "Expected trailing dimension of mat_a to be divisible by 16 ", | ||||||
|  |     "but got mat1 shape: (", | ||||||
|  |     mat_a.sizes(), | ||||||
|  |     ")."); | ||||||
|  |   TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0, | ||||||
|  |     "Expected mat_b shape to be divisible by 16 ", | ||||||
|  |     "but got mat_b shape: (", | ||||||
|  |     mat_b.sizes(), | ||||||
|  |     ")."); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet"); | ||||||
|  |   TORCH_CHECK_VALUE(offs.has_value() ==  (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix"); | ||||||
|  |  | ||||||
|  |   // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present. | ||||||
|  |   //       for rowwise, no offsets implies 3d-3d and is handled by lower-level | ||||||
|  |   //       routines | ||||||
|  |   if (offs.has_value()) { | ||||||
|  |     TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D"); | ||||||
|  |     TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32"); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   const auto out_dtype_ = out_dtype.value_or(kBFloat16); | ||||||
|  |   TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); | ||||||
|  |  | ||||||
|  |   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); | ||||||
|  |  | ||||||
|  |   // Conversion of implicitly-defined enums to explicit | ||||||
|  |   auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a); | ||||||
|  |   auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a); | ||||||
|  |   auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b); | ||||||
|  |   auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b); | ||||||
|  |  | ||||||
|  |   // at this point we can start working out what we want to be doing | ||||||
|  |   // Try to do as few steps as possible. | ||||||
|  |   // NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed. | ||||||
|  |   // Do this via a list of defined (name, acceptance, concrete_impl) tuples. | ||||||
|  |   ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE; | ||||||
|  |   for (const auto& fn_entry : scale_grouped_kernel_dispatch) { | ||||||
|  |     const auto [name, accept_fn, scaled_gemm_impl] = fn_entry; | ||||||
|  |     bool ok = accept_fn(mat_a.scalar_type(), | ||||||
|  |                         scale_recipe_a_enum, | ||||||
|  |                         scale_a, | ||||||
|  |                         mat_b.scalar_type(), | ||||||
|  |                         scale_recipe_b_enum, | ||||||
|  |                         scale_b); | ||||||
|  |     if (ok) { | ||||||
|  |       gemm_impl = scaled_gemm_impl; | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE, | ||||||
|  |       "No gemm implementation was found"); | ||||||
|  |  | ||||||
|  |   switch (gemm_impl) { | ||||||
|  |     case ScaledGemmImplementation::ROWWISE_ROWWISE: { | ||||||
|  |       const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1; | ||||||
|  |       _check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier); | ||||||
|  |       _check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier); | ||||||
|  |       return _f8_f8_bf16_rowwise_grouped_mm( | ||||||
|  |           mat_a, | ||||||
|  |           mat_b, | ||||||
|  |           scale_a[0], | ||||||
|  |           scale_b[0], | ||||||
|  |           offs, | ||||||
|  |           bias, | ||||||
|  |           use_fast_accum, | ||||||
|  |           out); | ||||||
|  |     } | ||||||
|  |     case ScaledGemmImplementation::MXFP8_MXFP8: { | ||||||
|  |       _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */); | ||||||
|  |       _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */); | ||||||
|  |       return _mx8_mx8_bf16_grouped_mm_fbgemm( | ||||||
|  |           mat_a, | ||||||
|  |           mat_b, | ||||||
|  |           scale_a[0], | ||||||
|  |           swizzle_a_enum[0], | ||||||
|  |           scale_b[0], | ||||||
|  |           swizzle_b_enum[0], | ||||||
|  |           offs.value(), | ||||||
|  |           out); | ||||||
|  |     } | ||||||
|  |     default: | ||||||
|  |       TORCH_CHECK_NOT_IMPLEMENTED(false, | ||||||
|  |           "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here"); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b, | ||||||
|  | const std::optional<at::Tensor>& offs, | ||||||
|  | const std::optional<at::Tensor>& bias, | ||||||
|  | std::optional<c10::ScalarType> out_dtype) { | ||||||
|  |   _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype); | ||||||
|  |   bool a_b_and_out_are_bf16 = ( | ||||||
|  |     mat_a.dtype() == at::kBFloat16 && | ||||||
|  |     mat_b.dtype() == at::kBFloat16 && | ||||||
|  |     out_dtype.value_or(at::kBFloat16) == at::kBFloat16 | ||||||
|  |   ); | ||||||
|  | #ifndef USE_ROCM | ||||||
|  |   bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16; | ||||||
|  | #else | ||||||
|  |   // NOTE: _scaled_mm_allowed_device gates the fast path here even though no scales are involved, | ||||||
|  |   //       which is not the right test on ROCm; _grouped_mm_fallback only calls regular mm/bmm, so it is safe on any ROCm GPU. | ||||||
|  |   bool use_fast_path = false; | ||||||
|  | #endif | ||||||
|  |   const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype); | ||||||
|  |   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); | ||||||
|  |   if (use_fast_path) { | ||||||
|  |     // fast path, no d2h sync needed | ||||||
|  |     at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); | ||||||
|  |   } else { | ||||||
|  |     _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out); | ||||||
|  |   } | ||||||
|  |   return out; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace at::native | ||||||
| @@ -6,7 +6,7 @@ | |||||||
| #endif | #endif | ||||||
|  |  | ||||||
| // ROCm 6.3 is planned to have these functions, but until then here they are. | // ROCm 6.3 is planned to have these functions, but until then here they are. | ||||||
| #if defined(USE_ROCM) && ROCM_VERSION >= 60201 | #if defined(USE_ROCM) | ||||||
| #include <device_functions.h> | #include <device_functions.h> | ||||||
| #include <hip/hip_fp16.h> | #include <hip/hip_fp16.h> | ||||||
| #include <hip/hip_bf16.h> | #include <hip/hip_bf16.h> | ||||||
| @@ -115,9 +115,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd( | |||||||
|     index_t index, |     index_t index, | ||||||
|     const index_t numel, |     const index_t numel, | ||||||
|     scalar_t value) { |     scalar_t value) { | ||||||
| #if (                      \ | #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)) | ||||||
|     (defined(USE_ROCM) && ROCM_VERSION < 60201) || \ |  | ||||||
|     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) |  | ||||||
|   gpuAtomicAddNoReturn( |   gpuAtomicAddNoReturn( | ||||||
|       reinterpret_cast<at::Half*>(tensor) + index, |       reinterpret_cast<at::Half*>(tensor) + index, | ||||||
|       static_cast<at::Half>(value)); |       static_cast<at::Half>(value)); | ||||||
| @@ -160,9 +158,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd( | |||||||
|     index_t index, |     index_t index, | ||||||
|     const index_t numel, |     const index_t numel, | ||||||
|     scalar_t value) { |     scalar_t value) { | ||||||
| #if (                      \ | #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) | ||||||
|     (defined(USE_ROCM) && ROCM_VERSION < 60201) || \ |  | ||||||
|     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) |  | ||||||
|   gpuAtomicAddNoReturn( |   gpuAtomicAddNoReturn( | ||||||
|       reinterpret_cast<at::BFloat16*>(tensor) + index, |       reinterpret_cast<at::BFloat16*>(tensor) + index, | ||||||
|       static_cast<at::BFloat16>(value)); |       static_cast<at::BFloat16>(value)); | ||||||
|  | |||||||
| @@ -1,18 +1,17 @@ | |||||||
| #pragma once | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/OpMathType.h> | ||||||
|  | #include <ATen/cuda/detail/OffsetCalculator.cuh> | ||||||
| #include <ATen/detail/FunctionTraits.h> | #include <ATen/detail/FunctionTraits.h> | ||||||
| #include <ATen/native/TensorIterator.h> | #include <ATen/native/TensorIterator.h> | ||||||
| #include <ATen/native/TensorIteratorDynamicCasting.h> | #include <ATen/native/TensorIteratorDynamicCasting.h> | ||||||
| #include <ATen/cuda/detail/OffsetCalculator.cuh> |  | ||||||
| #include <ATen/OpMathType.h> |  | ||||||
| #include <ATen/native/cuda/thread_constants.h> | #include <ATen/native/cuda/thread_constants.h> | ||||||
|  |  | ||||||
| #include <thrust/tuple.h> |  | ||||||
|  |  | ||||||
| #include <ATen/native/cuda/MemoryAccess.cuh> | #include <ATen/native/cuda/MemoryAccess.cuh> | ||||||
|  |  | ||||||
| #include <tuple> | #include <tuple> | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| namespace at::native { | namespace at::native { | ||||||
|  |  | ||||||
| template<int N> | template<int N> | ||||||
| @@ -62,7 +61,11 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { | |||||||
|   #pragma unroll |   #pragma unroll | ||||||
|   for (int i = 0; i < elems_per_thread; i++) { |   for (int i = 0; i < elems_per_thread; i++) { | ||||||
|     if (policy.check_inbounds(i)) { |     if (policy.check_inbounds(i)) { | ||||||
|  | #if defined(__HIP__) | ||||||
|       results[i] = c10::guts::apply(f, args[i]); |       results[i] = c10::guts::apply(f, args[i]); | ||||||
|  | #else | ||||||
|  |       results[i] = std::apply(f, args[i]); | ||||||
|  | #endif | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | |||||||
| @@ -23,7 +23,7 @@ namespace at::native { | |||||||
|  |  | ||||||
| // The maximum number of threads in a block | // The maximum number of threads in a block | ||||||
| #if defined(USE_ROCM) | #if defined(USE_ROCM) | ||||||
| constexpr int MAX_BLOCK_SIZE = 256; | constexpr int MAX_BLOCK_SIZE = 1024; | ||||||
| #else | #else | ||||||
| constexpr int MAX_BLOCK_SIZE = 512; | constexpr int MAX_BLOCK_SIZE = 512; | ||||||
| #endif | #endif | ||||||
| @@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u; | |||||||
| // Number of threads in a block given an input size up to MAX_BLOCK_SIZE | // Number of threads in a block given an input size up to MAX_BLOCK_SIZE | ||||||
| static int getNumThreads(int nElem) { | static int getNumThreads(int nElem) { | ||||||
| #if defined(USE_ROCM) | #if defined(USE_ROCM) | ||||||
|   int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; |   int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE }; | ||||||
| #else | #else | ||||||
|   int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; |   int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; | ||||||
| #endif | #endif | ||||||
| @@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) { | |||||||
|   // first the reductions each thread does separately |   // first the reductions each thread does separately | ||||||
|   scalar_t sum = static_cast<scalar_t>(0); |   scalar_t sum = static_cast<scalar_t>(0); | ||||||
|   for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { |   for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { | ||||||
|  | #if defined(USE_ROCM) | ||||||
|  |     constexpr int UNRL = 4; // unroll factor to de-serialize loads | ||||||
|  |     scalar_t tmp[UNRL]; | ||||||
|  |     for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) { | ||||||
|  | #pragma unroll | ||||||
|  |       for (int u = 0; u < UNRL; u++) | ||||||
|  |         tmp[u] = op(batch, plane, std::min((int)tensor.size(2)-1, (int)(x+u*blockDim.x))); | ||||||
|  | #pragma unroll | ||||||
|  |       for (int u = 0; u < UNRL; u++) | ||||||
|  |         if (x+u*blockDim.x < tensor.size(2)) | ||||||
|  |           sum += tmp[u]; | ||||||
|  |     } | ||||||
|  | #else | ||||||
|     for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { |     for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { | ||||||
|       sum += op(batch, plane, x); |       sum += op(batch, plane, x); | ||||||
|     } |     } | ||||||
|  | #endif | ||||||
|   } |   } | ||||||
|   __shared__ scalar_t shared[C10_WARP_SIZE]; |   __shared__ scalar_t shared[C10_WARP_SIZE]; | ||||||
|   SumReduceOp<scalar_t> reduce_op; |   SumReduceOp<scalar_t> reduce_op; | ||||||
| @@ -292,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel( | |||||||
|   stat_accscalar_t var_n = 0; |   stat_accscalar_t var_n = 0; | ||||||
|   int n = 0; |   int n = 0; | ||||||
|   for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { |   for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { | ||||||
|  | #if defined(USE_ROCM) | ||||||
|  |     constexpr int UNRL = 4; | ||||||
|  |     stat_accscalar_t v_[UNRL]; | ||||||
|  |     for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) { | ||||||
|  |       for (int u = 0; u < UNRL; u++) | ||||||
|  |         v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)]; | ||||||
|  |       for (int u = 0; u < UNRL; u++) { | ||||||
|  |         if (x+u*blockDim.x < input.size(2)) { | ||||||
|  |           stat_accscalar_t d1 = v_[u] - avg; | ||||||
|  |           n++; | ||||||
|  |           avg += d1 / n; | ||||||
|  |           var_n += d1 * (v_[u] - avg); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | #else | ||||||
|     for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { |     for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { | ||||||
|       stat_accscalar_t v = input[batch][plane][x]; |       stat_accscalar_t v = input[batch][plane][x]; | ||||||
|       stat_accscalar_t d1 = v - avg; |       stat_accscalar_t d1 = v - avg; | ||||||
| @@ -299,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel( | |||||||
|       avg += d1 / n; |       avg += d1 / n; | ||||||
|       var_n += d1 * (v - avg); |       var_n += d1 * (v - avg); | ||||||
|     } |     } | ||||||
|  | #endif | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // first warpSum to get one value per thread to |   // first warpSum to get one value per thread to | ||||||
|  | |||||||
| @@ -92,6 +92,16 @@ inline thrust::pair<int64_t, int64_t>  get_index_mapping2d( | |||||||
|     output_offset + output_y * output_dim_x + output_x); |     output_offset + output_y * output_dim_x + output_x); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | __device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) { | ||||||
|  |   const int64_t two = (len - 1) * 2; | ||||||
|  |   if (two <= 0) { | ||||||
|  |     return 0; | ||||||
|  |   } | ||||||
|  |   int64_t m = x % two; | ||||||
|  |   if (m < 0) m += two; | ||||||
|  |   return (m < len) ? m : (two - m); | ||||||
|  | } | ||||||
|  |  | ||||||
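A host-side illustration of reflect_index (same arithmetic as the device helper above, standalone harness assumed):

    #include <cassert>
    #include <cstdint>

    // Fold x into [0, len) by bouncing off both edges, with period 2 * (len - 1).
    int64_t reflect_index_host(int64_t x, int64_t len) {
      const int64_t two = (len - 1) * 2;
      if (two <= 0) return 0;
      int64_t m = x % two;
      if (m < 0) m += two;
      return (m < len) ? m : (two - m);
    }

    int main() {
      assert(reflect_index_host(-2, 5) == 2);  // left padding reflects inward
      assert(reflect_index_host( 4, 5) == 4);  // in-range index is unchanged
      assert(reflect_index_host( 6, 5) == 2);  // right padding reflects inward
      return 0;
    }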
| template<typename scalar_t> | template<typename scalar_t> | ||||||
| __global__ void reflection_pad1d_out_kernel( | __global__ void reflection_pad1d_out_kernel( | ||||||
|     const scalar_t * input, scalar_t * output, |     const scalar_t * input, scalar_t * output, | ||||||
| @@ -106,6 +116,28 @@ __global__ void reflection_pad1d_out_kernel( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | __global__ void reflection_pad1d_flat( | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     int64_t input_w, int64_t pad_l, int64_t pad_r, | ||||||
|  |     int64_t out_w, int64_t plane_count) { | ||||||
|  |  | ||||||
|  |   const int64_t bx = blockDim.x; | ||||||
|  |   const int64_t tx = threadIdx.x; | ||||||
|  |  | ||||||
|  |   const int64_t total = plane_count * out_w; | ||||||
|  |   const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x; | ||||||
|  |   int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx; | ||||||
|  |  | ||||||
|  |   for (; linear < total; linear += grid_stride) { | ||||||
|  |     const int64_t plane = linear / out_w; | ||||||
|  |     const int64_t x = linear - plane * out_w; | ||||||
|  |     const int64_t j = reflect_index(x - pad_l, input_w); | ||||||
|  |     output[plane * out_w + x] = input[plane * input_w + j]; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| template <typename scalar_t> | template <typename scalar_t> | ||||||
| __global__ void reflection_pad1d_backward_out_kernel( | __global__ void reflection_pad1d_backward_out_kernel( | ||||||
|     scalar_t * grad_input, const scalar_t * grad_output, |     scalar_t * grad_input, const scalar_t * grad_output, | ||||||
| @@ -710,25 +742,44 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda) | |||||||
|   int64_t input_w = input_.size(dim_w); |   int64_t input_w = input_.size(dim_w); | ||||||
|   int64_t output_w = input_w + pad_l + pad_r; |   int64_t output_w = input_w + pad_l + pad_r; | ||||||
|  |  | ||||||
|   dim3 block_size(output_w > 256 ? 256 : output_w); |  | ||||||
|   dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch); |  | ||||||
|  |  | ||||||
|   Tensor input = input_.contiguous(); |   Tensor input = input_.contiguous(); | ||||||
|  |  | ||||||
|   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( |   const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w))); | ||||||
|       kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] { |   const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); | ||||||
|         reflection_pad1d_out_kernel<<< |   const int max_x = prop->maxGridSize[0]; | ||||||
|             grid_size, |   const int max_y = prop->maxGridSize[1]; | ||||||
|             block_size, |   const int max_z = prop->maxGridSize[2]; | ||||||
|             0, |  | ||||||
|             at::cuda::getCurrentCUDAStream()>>>( |   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] { | ||||||
|             input.const_data_ptr<scalar_t>(), |     auto stream = at::cuda::getCurrentCUDAStream(); | ||||||
|             output.mutable_data_ptr<scalar_t>(), |  | ||||||
|             input_w, |     const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x)); | ||||||
|             pad_l, |  | ||||||
|             pad_r); |     const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x); | ||||||
|         C10_CUDA_KERNEL_LAUNCH_CHECK(); |  | ||||||
|       }); |     if (fits3d) { | ||||||
|  |       dim3 block(block_x, 1, 1); | ||||||
|  |       dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch)); | ||||||
|  |       reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||||
|  |           input.const_data_ptr<scalar_t>(), | ||||||
|  |           output.mutable_data_ptr<scalar_t>(), | ||||||
|  |           input_w, pad_l, pad_r); | ||||||
|  |     } else { | ||||||
|  |       dim3 block(block_x, 1, 1); | ||||||
|  |       const int64_t plane_count = nplane * nbatch; | ||||||
|  |       const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x)); | ||||||
|  |       const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks))); | ||||||
|  |       dim3 grid(grid_x, 1, 1); | ||||||
|  |  | ||||||
|  |       reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>( | ||||||
|  |           input.const_data_ptr<scalar_t>(), | ||||||
|  |           output.mutable_data_ptr<scalar_t>(), | ||||||
|  |           input_w, pad_l, pad_r, output_w, plane_count); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||
|  |   }); | ||||||
| } | } | ||||||
|  |  | ||||||
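Why the flat fallback exists, worked through with assumed shapes (and assuming typical device limits of maxGridSize = {2^31 - 1, 65535, 65535}):

    nbatch = 100000, nplane = 3, output_w = 1000           (assumed)
    3-D grid would be (ceil(1000 / 256), nplane, nbatch) = (4, 3, 100000)
    grid.z = 100000 > maxGridSize[2] = 65535               =>  fits3d == false
    flat path: plane_count = nbatch * nplane = 300000
    the grid-stride loop then covers plane_count * output_w = 3.0e8 elements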
| TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_, | TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_, | ||||||
aten/src/ATen/native/cuda/ScaledBlas.cpp (new file, +1284 lines; diff suppressed: too large to display)
| @@ -43,6 +43,12 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda( | |||||||
|   TORCH_CHECK(k >= 1 && k <= slicesize, |   TORCH_CHECK(k >= 1 && k <= slicesize, | ||||||
|               "kthvalue(): selected number k out of range for dimension ", dim); |               "kthvalue(): selected number k out of range for dimension ", dim); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK( | ||||||
|  |       slicesize <= std::numeric_limits<int32_t>::max(), | ||||||
|  |       "kthvalue(): dimension ", dim, " is too large (", slicesize, | ||||||
|  |       "). The current CUDA implementation supports dimension sizes up to ", | ||||||
|  |       std::numeric_limits<int32_t>::max()); | ||||||
|  |  | ||||||
|   at::assert_no_overlap(self, values); |   at::assert_no_overlap(self, values); | ||||||
|  |  | ||||||
|   _reduction_with_indices_allocate_or_resize_output( |   _reduction_with_indices_allocate_or_resize_output( | ||||||
| @@ -163,10 +169,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda( | |||||||
|     bool keepdim, |     bool keepdim, | ||||||
|     Tensor& values, |     Tensor& values, | ||||||
|     Tensor& indices) { |     Tensor& indices) { | ||||||
|   // See note [Writing Nondeterministic Operations] |  | ||||||
|   // If there are duplicate elements of the kth value, the procedure for choosing which |  | ||||||
|   // of the duplicates to use for the indices output is nondeterministic. |  | ||||||
|   at::globalContext().alertNotDeterministic("kthvalue CUDA"); |  | ||||||
|   auto result = [&]() { |   auto result = [&]() { | ||||||
|     NoNamesGuard guard; |     NoNamesGuard guard; | ||||||
|     // `kthvalue_out_impl_cuda` expects contiguous in input `self`. |     // `kthvalue_out_impl_cuda` expects contiguous in input `self`. | ||||||
|  | |||||||
| @@ -65,25 +65,34 @@ __global__ void gatherKthValue( | |||||||
|       &kValue); |       &kValue); | ||||||
|  |  | ||||||
|   // Find the index of the k-th highest element |   // Find the index of the k-th highest element | ||||||
|   index_t kValueIndex = 0; |   __shared__ int32_t minIndexFound; | ||||||
|   bool foundKValue = false; |  | ||||||
|  |   if (threadIdx.x == 0) { | ||||||
|  |       minIndexFound = static_cast<int32_t>(inputSliceSize); | ||||||
|  |   } | ||||||
|  |   __syncthreads(); | ||||||
|  |  | ||||||
|   for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { |   for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { | ||||||
|     bool inRange = (i < inputSliceSize); |       // Early exit based on best-so-far | ||||||
|     scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) |       if (i >= minIndexFound) { | ||||||
|                          : static_cast<scalar_t>(0); |           break; | ||||||
|     bool isKValue = inRange && |       } | ||||||
|         ((v == kValue) || (at::_isnan(v) && at::_isnan(kValue))); |  | ||||||
|     if (isKValue) { |       scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]); | ||||||
|       kValueIndex = i; |       bool isKValue = | ||||||
|       foundKValue = true; |           ((v == kValue) || (at::_isnan(v) && at::_isnan(kValue))); | ||||||
|       break; |  | ||||||
|     } |       if (isKValue) { | ||||||
|  |           atomicMin(&minIndexFound, static_cast<int32_t>(i)); | ||||||
|  |           break; | ||||||
|  |       } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (foundKValue) { |   __syncthreads(); | ||||||
|     kthValueSliceStart[0] = kValue; |  | ||||||
|     indicesSliceStart[0] = kValueIndex; |   if (threadIdx.x == 0) { | ||||||
|  |       indicesSliceStart[0] = static_cast<index_t>(minIndexFound); | ||||||
|  |       kthValueSliceStart[0] = kValue; | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
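The determinism device used above, in a reduced sketch (hypothetical kernel, not the patch): any thread that finds a match publishes its index through atomicMin on shared memory, so the block always reports the smallest matching index regardless of warp scheduling, which is what lets the nondeterminism alert be removed.

    __global__ void first_match_index(const float* data, int n, float target, int* out) {
      __shared__ int best;
      if (threadIdx.x == 0) best = n;  // sentinel: nothing found yet
      __syncthreads();
      for (int i = threadIdx.x; i < n; i += blockDim.x) {
        if (i >= best) break;          // this thread can no longer find a smaller index
        if (data[i] == target) {
          atomicMin(&best, i);         // keep the smallest matching index
          break;
        }
      }
      __syncthreads();
      if (threadIdx.x == 0) *out = best;
    }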
| @@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #ifdef USE_ROCM | ||||||
|  | // Helper function to compute output pixel range that can contribute to input pixel | ||||||
|  | template <typename accscalar_t> | ||||||
|  | __device__ __forceinline__ void compute_output_range( | ||||||
|  |     int input_pos, | ||||||
|  |     accscalar_t scale, | ||||||
|  |     int output_size, | ||||||
|  |     bool align_corners, | ||||||
|  |     int& min_output, | ||||||
|  |     int& max_output) { | ||||||
|  |   accscalar_t lo, hi; | ||||||
|  |   if (align_corners) { | ||||||
|  |       lo = static_cast<accscalar_t>(input_pos - 1) / scale; | ||||||
|  |       hi = static_cast<accscalar_t>(input_pos + 1) / scale; | ||||||
|  |   } else { | ||||||
|  |       lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5); | ||||||
|  |       hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5); | ||||||
|  |   } | ||||||
|  |   min_output = max(0, static_cast<int>(std::ceil(lo))); | ||||||
|  |   max_output = min(output_size - 1, static_cast<int>(std::floor(hi))); | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  |  | ||||||
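Where the lo/hi bounds come from (sketch for the align_corners == false branch; the true branch is the same derivation without the 0.5 shifts):

    // An output pixel h2 samples input rows floor(h1r) and floor(h1r) + 1, where
    //   h1r = scale * (h2 + 0.5) - 0.5.
    // Input row h1 therefore receives gradient from the h2 satisfying h1 - 1 <= h1r < h1 + 1;
    // substituting and solving for h2:
    //   h1 - 1 <= scale * (h2 + 0.5) - 0.5   =>   h2 >= (h1 - 0.5) / scale - 0.5   (lo)
    //   scale * (h2 + 0.5) - 0.5 <  h1 + 1   =>   h2 <  (h1 + 1.5) / scale - 0.5   (hi)
    // compute_output_range clamps [ceil(lo), floor(hi)] to [0, output_size). The range is a
    // slightly conservative superset; the weight > 0 test in the loop below filters the rest.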
| // Backward (adjoint) operation 1 <- 2 (accumulates) | // Backward (adjoint) operation 1 <- 2 (accumulates) | ||||||
| template <typename scalar_t, typename accscalar_t> | template <typename scalar_t, typename accscalar_t> | ||||||
| C10_LAUNCH_BOUNDS_1(1024) | C10_LAUNCH_BOUNDS_1(1024) | ||||||
| @@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame( | |||||||
|     const bool align_corners, |     const bool align_corners, | ||||||
|     scalar_t* __restrict__ idata, |     scalar_t* __restrict__ idata, | ||||||
|     const scalar_t* __restrict__ odata) { |     const scalar_t* __restrict__ odata) { | ||||||
|   const size_t o_numel = nc * width2 * height2; |   // The ROCm path below is input-centric (one thread per grad_input element), so only i_numel is needed up front. | ||||||
|   const size_t i_numel = nc * width1 * height1; |   const size_t i_numel = nc * width1 * height1; | ||||||
|  | #ifdef USE_ROCM | ||||||
|  |   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel; | ||||||
|  |        index += blockDim.x * gridDim.x) { | ||||||
|  |     // Decode input pixel coordinates | ||||||
|  |     size_t index_temp = index; | ||||||
|  |     const int w1 = index_temp % width1; | ||||||
|  |     index_temp /= width1; | ||||||
|  |     const int h1 = index_temp % height1; | ||||||
|  |     const size_t nc_idx = index_temp / height1; | ||||||
|  |  | ||||||
|  |     accscalar_t grad_sum = 0; | ||||||
|  |  | ||||||
|  |     // Find range of output pixels that could interpolate from this input pixel | ||||||
|  |     int h2_min, h2_max, w2_min, w2_max; | ||||||
|  |     compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max); | ||||||
|  |     compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max); | ||||||
|  |  | ||||||
|  |     // Iterate over potential output pixels | ||||||
|  |     for (int h2 = h2_min; h2 <= h2_max; h2++) { | ||||||
|  |       for (int w2 = w2_min; w2 <= w2_max; w2++) { | ||||||
|  |         // Compute source coordinates for this output pixel | ||||||
|  |         const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>( | ||||||
|  |             rheight, h2, align_corners, /*cubic=*/false); | ||||||
|  |         const int h1_base = (int)h1r; | ||||||
|  |         const int h1p = (h1_base < height1 - 1) ? 1 : 0; | ||||||
|  |         const accscalar_t h1lambda = h1r - h1_base; | ||||||
|  |         const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda; | ||||||
|  |  | ||||||
|  |         const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>( | ||||||
|  |             rwidth, w2, align_corners, /*cubic=*/false); | ||||||
|  |         const int w1_base = (int)w1r; | ||||||
|  |         const int w1p = (w1_base < width1 - 1) ? 1 : 0; | ||||||
|  |         const accscalar_t w1lambda = w1r - w1_base; | ||||||
|  |         const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda; | ||||||
|  |  | ||||||
|  |         // Check if our input pixel participates in this interpolation and accumulate all weights | ||||||
|  |         // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse | ||||||
|  |         // to the same pixel, so we need to accumulate weights from all matching positions | ||||||
|  |         accscalar_t weight = 0; | ||||||
|  |  | ||||||
|  |         // Check all four interpolation positions and accumulate weights | ||||||
|  |         if (h1 == h1_base && w1 == w1_base) { | ||||||
|  |           weight += h0lambda * w0lambda;  // top-left | ||||||
|  |         } | ||||||
|  |         if (h1 == h1_base && w1 == w1_base + w1p) { | ||||||
|  |           weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0) | ||||||
|  |         } | ||||||
|  |         if (h1 == h1_base + h1p && w1 == w1_base) { | ||||||
|  |           weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0) | ||||||
|  |         } | ||||||
|  |         if (h1 == h1_base + h1p && w1 == w1_base + w1p) { | ||||||
|  |           weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions) | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         if (weight > 0) { | ||||||
|  |           const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2; | ||||||
|  |           grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // Write accumulated gradient (no atomics needed) | ||||||
|  |     idata[index] = static_cast<scalar_t>(grad_sum); | ||||||
|  |   } | ||||||
|  | #else | ||||||
|  |   const size_t o_numel = nc * width2 * height2; | ||||||
|   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; |   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; | ||||||
|        index += blockDim.x * gridDim.x) { |        index += blockDim.x * gridDim.x) { | ||||||
|     size_t index_temp = index; |     size_t index_temp = index; | ||||||
| @ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame( | |||||||
|         static_cast<scalar_t>(h1lambda * w1lambda * d2val), |         static_cast<scalar_t>(h1lambda * w1lambda * d2val), | ||||||
|         true); |         true); | ||||||
|   } |   } | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| template <typename scalar_t, typename accscalar_t> | template <typename scalar_t, typename accscalar_t> | ||||||
| @ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template( | |||||||
|   // threads are not covering the whole input tensor. |   // threads are not covering the whole input tensor. | ||||||
|   grad_input.zero_(); |   grad_input.zero_(); | ||||||
|  |  | ||||||
|   const size_t num_kernels = nbatch * channels * output_height * output_width; |  | ||||||
|   const int num_threads = std::min( |   const int num_threads = std::min( | ||||||
|       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); |       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); | ||||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
| @ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template( | |||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | #ifdef USE_ROCM | ||||||
|  |   constexpr bool use_input = true; | ||||||
|  | #else | ||||||
|  |   constexpr bool use_input = false; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|   AT_DISPATCH_FLOATING_TYPES_AND2( |   AT_DISPATCH_FLOATING_TYPES_AND2( | ||||||
|       at::ScalarType::Half, at::ScalarType::BFloat16, |       at::ScalarType::Half, at::ScalarType::BFloat16, | ||||||
|       grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] { |       grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] { | ||||||
| @ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template( | |||||||
|       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( |       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( | ||||||
|           input_width, output_width, align_corners, scales_w); |           input_width, output_width, align_corners, scales_w); | ||||||
|  |  | ||||||
|  |       const size_t num_kernels = nbatch * channels * output_height * output_width; | ||||||
|  |  | ||||||
|       upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t> |       upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t> | ||||||
|           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>( |           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>( | ||||||
|               input_height, |               input_height, | ||||||
| @ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template( | |||||||
|       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( |       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( | ||||||
|           input_width, output_width, align_corners, scales_w); |           input_width, output_width, align_corners, scales_w); | ||||||
|  |  | ||||||
|  |       const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width); | ||||||
|  |  | ||||||
|       upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t> |       upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t> | ||||||
|           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), |           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), | ||||||
|              num_threads, |              num_threads, | ||||||
|  | |||||||
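For intuition about the ROCm branch above: the backward is computed as a gather, where each input pixel sums the bilinear weights of every output pixel whose two-tap stencil touches it, so no atomics are needed. Below is a minimal host-side C++ sketch of that weight accumulation in 1-D. It is a hypothetical stand-alone helper, not the kernel itself; it assumes align_corners=false (src = (o + 0.5) * r - 0.5 with r = in_size / out_size, clamped at 0) and scans every output index, whereas the kernel narrows the scan with compute_output_range.

#include <algorithm>
#include <vector>

// 1-D sketch of the gather-based bilinear backward: for input index i1,
// accumulate the interpolation weight of every output index o whose stencil
// {base, base + step} contains i1. At the right edge step == 0, so both taps
// land on the same pixel and their weights (lambda0 + lambda1 == 1) add up,
// mirroring the weight-collapse handling in the kernel above.
float gather_grad_1d(int i1, int in_size, int out_size,
                     const std::vector<float>& grad_out) {
  const float r = static_cast<float>(in_size) / out_size; // align_corners=false
  float sum = 0.f;
  for (int o = 0; o < out_size; ++o) {
    const float src = std::max((o + 0.5f) * r - 0.5f, 0.f);
    const int base = static_cast<int>(src);
    const int step = (base < in_size - 1) ? 1 : 0;
    const float lambda1 = src - base;
    const float lambda0 = 1.f - lambda1;
    float w = 0.f;
    if (i1 == base)        w += lambda0; // left tap
    if (i1 == base + step) w += lambda1; // right tap (same pixel if step == 0)
    sum += w * grad_out[o];
  }
  return sum;
}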
aten/src/ATen/native/cuda/cuBlasCommonArgs.h (new file, 171 lines)
							| @ -0,0 +1,171 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/core/Tensor.h> | ||||||
|  |  | ||||||
|  | namespace at::native { | ||||||
|  |  | ||||||
|  | using at::blas::ScalingType; | ||||||
|  | using at::blas::SwizzleType; | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
|  | // TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492 | ||||||
|  | c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) { | ||||||
|  |   if (resolve_conj && tensor.is_conj()) { | ||||||
|  |     return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj()); | ||||||
|  |   } else { | ||||||
|  |     return c10::MaybeOwned<Tensor>::borrowed(tensor); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) { | ||||||
|  |   if (tensor.is_non_overlapping_and_dense()) { // common case | ||||||
|  |       transpose_tensor = tensor.is_contiguous(); | ||||||
|  |       return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor); | ||||||
|  |   } | ||||||
|  |   IntArrayRef tensor_strides = tensor.strides(); | ||||||
|  |   IntArrayRef tensor_sizes = tensor.sizes(); | ||||||
|  |   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) { | ||||||
|  |     transpose_tensor = false; | ||||||
|  |     return resolve_conj_if_indicated(tensor, !transpose_result); | ||||||
|  |   } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) { | ||||||
|  |     transpose_tensor = true; | ||||||
|  |     return resolve_conj_if_indicated(tensor, transpose_result); | ||||||
|  |   } else { | ||||||
|  |     transpose_tensor = true; | ||||||
|  |     return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous)); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) { | ||||||
|  |   if (tensor.is_non_overlapping_and_dense()) { // common case | ||||||
|  |       transpose_tensor = tensor.is_contiguous(); | ||||||
|  |       return resolve_conj_if_indicated(tensor, true); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   IntArrayRef tensor_strides = tensor.strides(); | ||||||
|  |   IntArrayRef tensor_sizes = tensor.sizes(); | ||||||
|  |   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) { | ||||||
|  |     transpose_tensor = false; | ||||||
|  |     return resolve_conj_if_indicated(tensor, true); | ||||||
|  |   } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) { | ||||||
|  |     transpose_tensor = true; | ||||||
|  |     return resolve_conj_if_indicated(tensor, true); | ||||||
|  |   } else { | ||||||
|  |     transpose_tensor = true; | ||||||
|  |     return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous)); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * @brief Prepares matrices for a cuBLAS operation | ||||||
|  |  * | ||||||
|  |  * This constructor prepares tensors for cuBLAS. | ||||||
|  |  * The main difference is that PyTorch defaults to row-major storage while | ||||||
|  |  * cuBLAS expects column-major. | ||||||
|  |  * | ||||||
|  |  * @details | ||||||
|  |  * To obtain row-major output while using cuBLAS, | ||||||
|  |  * we use the mathematical identity (A × B)^T = B^T × A^T. | ||||||
|  |  * | ||||||
|  |  * "Transpose" here follows cuBLAS's (Fortran) convention: | ||||||
|  |  * T = row-major, N = column-major. | ||||||
|  |  * | ||||||
|  |  * Example: | ||||||
|  |  * For matrices A (M×K, row-major) and B (K×N, row-major): | ||||||
|  |  *   - Standard multiplication: A × B = (M×K) × (K×N) = M×N result (row-major) | ||||||
|  |  *   - Using the transpose trick: B^T × A^T = (N×K)(T) × (K×M)(T) = N×M (N) | ||||||
|  |  *   - Since cuBLAS returns its output column-major, that N×M column-major | ||||||
|  |  *     result is exactly the expected M×N row-major output. | ||||||
|  |  * | ||||||
|  |  * The transpose flags are derived from the layouts of the passed in tensors | ||||||
|  |  * | ||||||
|  |  * If the operands are in packed float4 format, `k`, `lda` and `ldb` are adjusted | ||||||
|  |  * to their unpacked values to match what cuBLAS expects. | ||||||
|  |  * | ||||||
|  |  * @param mat1 First input matrix | ||||||
|  |  * @param mat2 Second input matrix | ||||||
|  |  * @param c Output matrix (result) | ||||||
|  |  * @param scale_a Optional scaling factor for first matrix | ||||||
|  |  * @param scale_b Optional scaling factor for second matrix | ||||||
|  |  * @param scale_result Optional scaling factor for result | ||||||
|  |  */ | ||||||
|  | struct cublasCommonArgs { | ||||||
|  |   cublasCommonArgs( | ||||||
|  |       const Tensor& mat1, | ||||||
|  |       const Tensor& mat2, | ||||||
|  |       Tensor& c, | ||||||
|  |       const std::optional<Tensor>& scale_a = std::nullopt, | ||||||
|  |       const std::optional<Tensor>& scale_b = std::nullopt, | ||||||
|  |       const std::optional<Tensor>& scale_result = std::nullopt, | ||||||
|  |       const std::optional<ScalingType>& scaling_choice_a = std::nullopt, | ||||||
|  |       const std::optional<ScalingType>& scaling_choice_b = std::nullopt) { | ||||||
|  |     bool transpose_result = false, transpose_a = false, transpose_b = false; | ||||||
|  |     result = prepare_matrix_for_cublas(c, transpose_result); | ||||||
|  |     mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); | ||||||
|  |     matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_b, transpose_result); | ||||||
|  |  | ||||||
|  |     // Handle scale tensors if provided | ||||||
|  |     if (scale_a && scale_b) { | ||||||
|  |       // Because we return a row-major result, the gemm runs as B^T @ A^T by | ||||||
|  |       // default; transpose_result tells us whether to swap the scales to match. | ||||||
|  |       scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); | ||||||
|  |       scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); | ||||||
|  |       scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a; | ||||||
|  |       scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); | ||||||
|  |       scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); | ||||||
|  |       scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (scale_result) { | ||||||
|  |       scale_result_ptr = scale_result->data_ptr(); | ||||||
|  |       scale_result_dtype = scale_result->scalar_type(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // Update transpose flags | ||||||
|  |     if (transpose_result) { | ||||||
|  |       transpose_a = !transpose_a; | ||||||
|  |       transpose_b = !transpose_b; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     auto sizes_a = mata->sizes(); | ||||||
|  |     auto sizes_b = matb->sizes(); | ||||||
|  |  | ||||||
|  |     m = sizes_a[transpose_result ? 1 : 0]; | ||||||
|  |     k = sizes_a[transpose_result ? 0 : 1]; | ||||||
|  |     n = sizes_b[transpose_result ? 0 : 1]; | ||||||
|  |     lda = mata->stride((transpose_a == transpose_result) ? 1 : 0); | ||||||
|  |     ldb = matb->stride((transpose_b == transpose_result) ? 1 : 0); | ||||||
|  |     result_ld = result->stride(transpose_result ? 0 : 1); | ||||||
|  |     transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n'; | ||||||
|  |     transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n'; | ||||||
|  |  | ||||||
|  |     // cuBLAS expects unpacked values of `k`, `lda` and `ldb`, adjust for 4x2 packing | ||||||
|  |     // if the gemm operands are in packed float4 | ||||||
|  |     if (mat1.dtype() == at::kFloat4_e2m1fn_x2 && mat2.dtype() == at::kFloat4_e2m1fn_x2) { | ||||||
|  |       k = k * 2; | ||||||
|  |       lda = lda * 2; | ||||||
|  |       ldb = ldb * 2; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Matrix members | ||||||
|  |   char transa, transb; | ||||||
|  |   int64_t m, n, k; | ||||||
|  |   int64_t lda, ldb, result_ld; | ||||||
|  |   c10::MaybeOwned<Tensor> mata, matb, result; | ||||||
|  |  | ||||||
|  |   // Scale members | ||||||
|  |   void* scale_mata_ptr = nullptr; | ||||||
|  |   void* scale_matb_ptr = nullptr; | ||||||
|  |   void* scale_result_ptr = nullptr; | ||||||
|  |   std::optional<c10::ScalarType> scale_mata_dtype; | ||||||
|  |   std::optional<ScalingType> scaling_mata_type; | ||||||
|  |   std::optional<c10::ScalarType> scale_matb_dtype; | ||||||
|  |   std::optional<ScalingType> scaling_matb_type; | ||||||
|  |   std::optional<c10::ScalarType> scale_result_dtype; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | } // namespace at::native | ||||||
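To make the transpose trick concrete, here is a self-contained C++ sketch (plain CPU code; the naive gemm_colmajor below stands in for a cuBLAS call with transa = transb = 'n'). It shows that handing row-major buffers to a column-major GEMM with the operands swapped produces the row-major product directly, which is the identity the doc comment above relies on:

#include <cstdio>

// Naive column-major GEMM: C[r + c*ldc] = sum_i A[r + i*lda] * B[i + c*ldb].
// This mimics what cuBLAS computes for transa = transb = 'n'.
static void gemm_colmajor(int m, int n, int k, const float* A, int lda,
                          const float* B, int ldb, float* C, int ldc) {
  for (int c = 0; c < n; ++c) {
    for (int r = 0; r < m; ++r) {
      float acc = 0.f;
      for (int i = 0; i < k; ++i) {
        acc += A[r + i * lda] * B[i + c * ldb];
      }
      C[r + c * ldc] = acc;
    }
  }
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C = A x B (2x2).
  const float A[6] = {1, 2, 3, 4, 5, 6};
  const float B[6] = {7, 8, 9, 10, 11, 12};
  float C[4];
  // A row-major MxK buffer reads as a column-major KxM matrix (its transpose),
  // so computing B^T x A^T column-major writes C^T in column-major order,
  // which is bit-for-bit C in row-major order.
  gemm_colmajor(/*m=*/2, /*n=*/2, /*k=*/3,
                /*A=*/B, /*lda=*/2,   // B^T: 2x3 column-major view of B
                /*B=*/A, /*ldb=*/3,   // A^T: 3x2 column-major view of A
                C, /*ldc=*/2);
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // 58 64 / 139 154
  return 0;
}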
| @ -52,7 +52,7 @@ struct FusedAdagradMathFunctor { | |||||||
|   using opmath_t = at::opmath_type<scalar_t>; |   using opmath_t = at::opmath_type<scalar_t>; | ||||||
|  |  | ||||||
|   C10_DEVICE __forceinline__ void operator()( |   C10_DEVICE __forceinline__ void operator()( | ||||||
|       int chunk_size, |       int64_t chunk_size, | ||||||
|       FusedOptimizerTensorListMetadata<3>& tl, |       FusedOptimizerTensorListMetadata<3>& tl, | ||||||
|       const float* lr_ptr, |       const float* lr_ptr, | ||||||
|       const double& lr, |       const double& lr, | ||||||
| @ -133,4 +133,4 @@ struct FusedAdagradMathFunctor { | |||||||
|  |  | ||||||
| } // namespace | } // namespace | ||||||
|  |  | ||||||
| } // namespace at::native | } // namespace at::native | ||||||
|  | |||||||
| @ -1,4 +1,4 @@ | |||||||
| #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) | #if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) | ||||||
| #include <cuda_bf16.h> | #include <cuda_bf16.h> | ||||||
| #include <cuda_fp16.h> | #include <cuda_fp16.h> | ||||||
| #include <cuda_runtime.h> | #include <cuda_runtime.h> | ||||||
| @ -133,7 +133,7 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { | |||||||
| #define CDNA2_OR_LATER 0 | #define CDNA2_OR_LATER 0 | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) | #if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) | ||||||
|  |  | ||||||
| #if defined(USE_ROCM) | #if defined(USE_ROCM) | ||||||
| // TODO: Support RDNA | // TODO: Support RDNA | ||||||
| @ -1161,7 +1161,7 @@ at::Tensor _weight_int4pack_mm_cuda( | |||||||
|   auto C_final = at::empty( |   auto C_final = at::empty( | ||||||
|       {m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device())); |       {m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device())); | ||||||
|  |  | ||||||
| #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) | #if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) | ||||||
|   auto stream = at::cuda::getCurrentCUDAStream(); |   auto stream = at::cuda::getCurrentCUDAStream(); | ||||||
| #define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ | #define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ | ||||||
|   do {                                                               \ |   do {                                                               \ | ||||||
| @ -1327,7 +1327,7 @@ at::Tensor _convert_weight_to_int4pack_cuda( | |||||||
|       {nTilesTensor, kSuperTiles, 32, innerKTiles / 2}, |       {nTilesTensor, kSuperTiles, 32, innerKTiles / 2}, | ||||||
|       at::TensorOptions().dtype(at::kInt).device(in.device())); |       at::TensorOptions().dtype(at::kInt).device(in.device())); | ||||||
|  |  | ||||||
| #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) | #if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) | ||||||
|   auto stream = at::cuda::getCurrentCUDAStream(); |   auto stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   dim3 grid(kSuperTiles, nTiles); |   dim3 grid(kSuperTiles, nTiles); | ||||||
|  |  | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ | |||||||
| #include <ATen/WrapDimUtilsMulti.h> | #include <ATen/WrapDimUtilsMulti.h> | ||||||
| #include <ATen/native/Resize.h> | #include <ATen/native/Resize.h> | ||||||
| #include <ATen/native/mkldnn/xpu/detail/oneDNN.h> | #include <ATen/native/mkldnn/xpu/detail/oneDNN.h> | ||||||
|  | #include <ATen/native/xpu/Blas.h> | ||||||
| #include <torch/library.h> | #include <torch/library.h> | ||||||
| #ifndef AT_PER_OPERATOR_HEADERS | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
|  |  | ||||||
| @ -50,9 +51,13 @@ Tensor& addmm_out( | |||||||
|       mat1.dtype(), |       mat1.dtype(), | ||||||
|       " != ", |       " != ", | ||||||
|       mat2.dtype()) |       mat2.dtype()) | ||||||
|  |  | ||||||
|   // complex case |   // complex case | ||||||
|   TORCH_CHECK( |   if (self.is_complex()) { | ||||||
|       !mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN"); |     at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result); | ||||||
|  |  | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)}; |   std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)}; | ||||||
|   result.resize_(result_shape); |   result.resize_(result_shape); | ||||||
| @ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) { | |||||||
|     return result; |     return result; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   TORCH_CHECK( |   if (self.is_complex()) { | ||||||
|       !self.is_complex(), "Complex datatype matmul is not supported in oneDNN"); |     at::native::mm_complex_out_xpu(self, mat2, result); | ||||||
|  |  | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr()); |   onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr()); | ||||||
|   return result; |   return result; | ||||||
| @ -208,9 +216,12 @@ Tensor& baddbmm_out( | |||||||
|       input.sizes()); |       input.sizes()); | ||||||
|  |  | ||||||
|   // complex case |   // complex case | ||||||
|   TORCH_CHECK( |   if (input.is_complex()) { | ||||||
|       !batch1.is_complex(), |     at::native::baddbmm_complex_out_xpu( | ||||||
|       "Complex datatype matmul is not supported in oneDNN"); |         input, batch1, batch2, beta, alpha, result); | ||||||
|  |  | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   // general case |   // general case | ||||||
|   onednn::Attr attr; |   onednn::Attr attr; | ||||||
| @ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) { | |||||||
|     return result; |     return result; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   TORCH_CHECK( |   // complex case | ||||||
|       !self.is_complex(), "Complex datatype matmul is not supported in oneDNN"); |   if (self.is_complex()) { | ||||||
|  |     at::native::bmm_complex_out_xpu(self, batch2, result); | ||||||
|  |  | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr()); |   onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr()); | ||||||
|   return result; |   return result; | ||||||
| } | } | ||||||
|  | |||||||
| @ -222,6 +222,13 @@ struct nextafter_functor { | |||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | struct hypot_functor { | ||||||
|  |   template <typename T> | ||||||
|  |   inline T operator()(const T a, const T b) { | ||||||
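|  |     // Accumulate in float so a*a + b*b neither overflows nor loses precision for half/bfloat inputs. | ||||||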
|  |     return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b)); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
| // Complex binary functors | // Complex binary functors | ||||||
| struct polar_functor { | struct polar_functor { | ||||||
|   template <typename U> |   template <typename U> | ||||||
| @ -362,6 +369,7 @@ struct igammac_functor { | |||||||
|   REGISTER_OPMATH_BINARY_OP(NAME, half, half);   \ |   REGISTER_OPMATH_BINARY_OP(NAME, half, half);   \ | ||||||
|   REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat) |   REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat) | ||||||
|  |  | ||||||
|  | REGISTER_FLOAT_BINARY_OP(hypot); | ||||||
| REGISTER_FLOAT_BINARY_OP(copysign); | REGISTER_FLOAT_BINARY_OP(copysign); | ||||||
| REGISTER_INT2FLOAT_BINARY_OP(copysign); | REGISTER_INT2FLOAT_BINARY_OP(copysign); | ||||||
| REGISTER_FLOAT_BINARY_OP(fmax); | REGISTER_FLOAT_BINARY_OP(fmax); | ||||||
|  | |||||||
aten/src/ATen/native/mps/kernels/LinearAlgebra.h (new file, 16 lines)
							| @ -0,0 +1,16 @@ | |||||||
|  | #pragma once | ||||||
|  | #include <c10/metal/common.h> | ||||||
|  |  | ||||||
|  | template <unsigned N = c10::metal::max_ndim> | ||||||
|  | struct OrgqrParams { | ||||||
|  |   int32_t num_batch_dims; | ||||||
|  |  | ||||||
|  |   uint32_t m; | ||||||
|  |   uint32_t n; | ||||||
|  |   uint32_t k; | ||||||
|  |  | ||||||
|  |   ::c10::metal::array<uint32_t, N> A_strides; | ||||||
|  |   ::c10::metal::array<uint32_t, N> tau_strides; | ||||||
|  |   ::c10::metal::array<uint32_t, N> H_strides; | ||||||
|  |   ::c10::metal::array<uint32_t, N> H_sizes; | ||||||
|  | }; | ||||||
| @ -1,3 +1,4 @@ | |||||||
|  | #include <ATen/native/mps/kernels/LinearAlgebra.h> | ||||||
| #include <c10/metal/utils.h> | #include <c10/metal/utils.h> | ||||||
| #include <metal_array> | #include <metal_array> | ||||||
| #include <metal_simdgroup> | #include <metal_simdgroup> | ||||||
| @ -640,6 +641,164 @@ kernel void applyPivots( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
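|  | // bool_to_float maps a predicate to T's one/zero; for the complex element types (half2/float2), "one" is 1 + 0i. | ||||||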
|  | template <typename T> | ||||||
|  | static T bool_to_float(bool b) { | ||||||
|  |   return static_cast<T>(b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | half2 bool_to_float(bool b) { | ||||||
|  |   return half2(b ? 1 : 0, 0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | float2 bool_to_float(bool b) { | ||||||
|  |   return float2(b ? 1 : 0, 0); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T> | ||||||
|  | static T calc_H_irc( | ||||||
|  |     device T* A, | ||||||
|  |     uint32_t A_stride_r, | ||||||
|  |     uint32_t A_stride_c, | ||||||
|  |     constant T* tau, | ||||||
|  |     uint32_t tau_stride, | ||||||
|  |     uint32_t r, | ||||||
|  |     uint32_t c, | ||||||
|  |     uint32_t i) { | ||||||
|  |   T I_val = bool_to_float<T>(r == c); | ||||||
|  |   T tau_val = tau[i * tau_stride]; | ||||||
|  |  | ||||||
|  |   T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]); | ||||||
|  |   T A_ri = A[r * A_stride_r + i * A_stride_c]; | ||||||
|  |  | ||||||
|  |   T c_eq_i = bool_to_float<T>(c == i); | ||||||
|  |   T r_eq_i = bool_to_float<T>(r == i); | ||||||
|  |  | ||||||
|  |   T A_ci_ = (c > i) ? A_ci : c_eq_i; | ||||||
|  |   T A_ri_ = (r > i) ? A_ri : r_eq_i; | ||||||
|  |  | ||||||
|  |   return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the | ||||||
|  | // result of matrix multiplying A and B together. A and B must be size m-by-m | ||||||
|  | // and have the same strides. The formula for this operation, written in Python | ||||||
|  | // syntax, is: | ||||||
|  | //   (A @ B)[r, c] = A[r, :].dot(B[:, c]) | ||||||
|  | template <typename T> | ||||||
|  | static T calc_matmul_rc( | ||||||
|  |     device T* A, | ||||||
|  |     device T* B, | ||||||
|  |     uint32_t stride_r, | ||||||
|  |     uint32_t stride_c, | ||||||
|  |     uint32_t m, | ||||||
|  |     uint32_t r, | ||||||
|  |     uint32_t c) { | ||||||
|  |   T AB_rc = 0; | ||||||
|  |   auto A_row_offset = r * stride_r; | ||||||
|  |   auto B_col_offset = c * stride_c; | ||||||
|  |  | ||||||
|  |   uint32_t A_col_offset = 0; | ||||||
|  |   uint32_t B_row_offset = 0; | ||||||
|  |  | ||||||
|  |   for (uint32_t j = 0; j < m; | ||||||
|  |        j++, A_col_offset += stride_c, B_row_offset += stride_r) { | ||||||
|  |     AB_rc += c10::metal::mul( | ||||||
|  |         A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]); | ||||||
|  |   } | ||||||
|  |   return AB_rc; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T> | ||||||
|  | kernel void orgqr( | ||||||
|  |     device T* A [[buffer(0)]], | ||||||
|  |     constant T* tau [[buffer(1)]], | ||||||
|  |     device T* H [[buffer(2)]], | ||||||
|  |     device T* H_prod [[buffer(3)]], | ||||||
|  |     constant OrgqrParams<>& params [[buffer(4)]], | ||||||
|  |     uint tid [[thread_position_in_grid]]) { | ||||||
|  |   constant auto& A_strides = params.A_strides; | ||||||
|  |   constant auto& tau_strides = params.tau_strides; | ||||||
|  |   constant auto& H_strides = params.H_strides; | ||||||
|  |   constant auto& H_sizes = params.H_sizes; | ||||||
|  |  | ||||||
|  |   auto num_batch_dims = params.num_batch_dims; | ||||||
|  |   auto m = params.m; | ||||||
|  |   auto n = params.n; | ||||||
|  |   auto k = params.k; | ||||||
|  |  | ||||||
|  |   auto m2 = m * m; | ||||||
|  |   auto batch_idx = tid / m2; | ||||||
|  |  | ||||||
|  |   // Find the matrices for this thread's batch index | ||||||
|  |   uint32_t A_offset = 0; | ||||||
|  |   uint32_t tau_offset = 0; | ||||||
|  |   uint32_t H_offset = 0; | ||||||
|  |  | ||||||
|  |   for (auto dim = num_batch_dims - 1; dim >= 0; dim--) { | ||||||
|  |     auto dim_size = H_sizes[dim]; | ||||||
|  |     auto dim_idx = batch_idx % dim_size; | ||||||
|  |  | ||||||
|  |     A_offset += dim_idx * A_strides[dim]; | ||||||
|  |     tau_offset += dim_idx * tau_strides[dim]; | ||||||
|  |     H_offset += dim_idx * H_strides[dim]; | ||||||
|  |  | ||||||
|  |     batch_idx /= dim_size; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   A += A_offset; | ||||||
|  |   tau += tau_offset; | ||||||
|  |   H += H_offset; | ||||||
|  |   H_prod += H_offset; | ||||||
|  |  | ||||||
|  |   auto matrix_idx = tid % m2; | ||||||
|  |   auto r = matrix_idx / m; | ||||||
|  |   auto c = matrix_idx % m; | ||||||
|  |   auto A_stride_r = A_strides[num_batch_dims]; | ||||||
|  |   auto A_stride_c = A_strides[num_batch_dims + 1]; | ||||||
|  |   auto tau_stride = tau_strides[num_batch_dims]; | ||||||
|  |   auto H_stride_r = H_strides[num_batch_dims]; | ||||||
|  |   auto H_stride_c = H_strides[num_batch_dims + 1]; | ||||||
|  |  | ||||||
|  |   // Find the element of H and H_prod that this thread will calculate | ||||||
|  |   device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c); | ||||||
|  |   device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c); | ||||||
|  |  | ||||||
|  |   for (uint32_t i = 0; i < k; i++) { | ||||||
|  |     // Calculate and write H_i | ||||||
|  |  | ||||||
|  |     T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i); | ||||||
|  |  | ||||||
|  |     // Calculate element [r, c] of prod(H_0, ..., H_i) | ||||||
|  |     if (i == 0) { | ||||||
|  |       *H_prod_elem_ptr = H_irc; | ||||||
|  |     } else { | ||||||
|  |       *H_elem_ptr = H_irc; | ||||||
|  |  | ||||||
|  |       // Need this sync because the below matmul requires all threads to finish | ||||||
|  |       // writing their entries to `H_prod` and `H`. | ||||||
|  |       threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |  | ||||||
|  |       T H_prod_0_to_i_rc = | ||||||
|  |           calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c); | ||||||
|  |  | ||||||
|  |       // Need this sync because the above matmul uses the current values in | ||||||
|  |       // `H_prod`, and we don't want to overwrite those until all threads are | ||||||
|  |       // finished using them. | ||||||
|  |       threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |  | ||||||
|  |       *H_prod_elem_ptr = H_prod_0_to_i_rc; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c); | ||||||
|  |  | ||||||
|  |   if (c < n) { | ||||||
|  |     *A_elem_ptr = *H_prod_elem_ptr; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| #define INSTANTIATE_MM_OPS(DTYPE)                                           \ | #define INSTANTIATE_MM_OPS(DTYPE)                                           \ | ||||||
|   template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>(       \ |   template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>(       \ | ||||||
|       constant DTYPE * mat1Data [[buffer(0)]],                              \ |       constant DTYPE * mat1Data [[buffer(0)]],                              \ | ||||||
| @ -679,3 +838,19 @@ INSTANTIATE_MM_OPS(int); | |||||||
| INSTANTIATE_MM_OPS(short); | INSTANTIATE_MM_OPS(short); | ||||||
| INSTANTIATE_MM_OPS(char); | INSTANTIATE_MM_OPS(char); | ||||||
| INSTANTIATE_MM_OPS(uchar); | INSTANTIATE_MM_OPS(uchar); | ||||||
|  |  | ||||||
|  | #define REGISTER_ORGQR(T)                            \ | ||||||
|  |   template [[host_name("orgqr_" #T)]]                \ | ||||||
|  |   kernel void orgqr<T>(                              \ | ||||||
|  |       device T * A [[buffer(0)]],                    \ | ||||||
|  |       constant T * tau [[buffer(1)]],                \ | ||||||
|  |       device T * H [[buffer(2)]],                    \ | ||||||
|  |       device T * H_prod [[buffer(3)]],               \ | ||||||
|  |       constant OrgqrParams<> & params [[buffer(4)]], \ | ||||||
|  |       uint tid [[thread_position_in_grid]]); | ||||||
|  |  | ||||||
|  | REGISTER_ORGQR(float); | ||||||
|  | REGISTER_ORGQR(half); | ||||||
|  | REGISTER_ORGQR(bfloat); | ||||||
|  | REGISTER_ORGQR(float2); | ||||||
|  | REGISTER_ORGQR(half2); | ||||||
|  | |||||||
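For reference, calc_H_irc above materializes the standard Householder factors that orgqr multiplies together. Sketching the math in LAPACK's convention, with v_i taken from column i of A:

  Q = H_0 H_1 \cdots H_{k-1}, \qquad
  H_i = I - \tau_i \, v_i v_i^{H}, \qquad
  (v_i)_r = \begin{cases} 0 & r < i \\ 1 & r = i \\ A_{r,i} & r > i \end{cases}

Entrywise, (H_i)_{r,c} = \delta_{rc} - \tau_i (v_i)_r \overline{(v_i)_c}, which is exactly the value returned by calc_H_irc. The kernel then folds the factors in one at a time, with the two barriers separating each write of H_i from the read-modify-write of the running product.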
| @ -5,6 +5,21 @@ | |||||||
| using namespace metal; | using namespace metal; | ||||||
| using namespace c10::metal; | using namespace c10::metal; | ||||||
|  |  | ||||||
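|  | // angle(x): the argument of complex x; for real x, pi when negative and 0 otherwise (the float overload propagates NaN). | ||||||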
|  | struct angle_functor { | ||||||
|  |   template <typename T, enable_if_t<is_complex_v<T>, bool> = true> | ||||||
|  |   inline T operator()(const T x) { | ||||||
|  |     return T(atan2(x.y, x.x), 0); | ||||||
|  |   } | ||||||
|  |   template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true> | ||||||
|  |   inline T operator()(const T x) { | ||||||
|  |     return T(isnan(x) ? x : x < 0 ? M_PI_F : 0.0); | ||||||
|  |   } | ||||||
|  |   template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true> | ||||||
|  |   inline float operator()(const T x) { | ||||||
|  |     return x < 0 ? M_PI_F : 0.0; | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
| // Implement exp wrapper for both real and complex types | // Implement exp wrapper for both real and complex types | ||||||
| template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true> | template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true> | ||||||
| inline T exp_(const T x) { | inline T exp_(const T x) { | ||||||
| @ -545,6 +560,7 @@ REGISTER_UNARY_OP(abs, float, float); | |||||||
| REGISTER_UNARY_OP(abs, half, half); | REGISTER_UNARY_OP(abs, half, half); | ||||||
|  |  | ||||||
| #define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \ | #define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \ | ||||||
|  |   REGISTER_UNARY_OP(angle, DTYPE1, DTYPE0);        \ | ||||||
|   REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0);          \ |   REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0);          \ | ||||||
|   REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0);         \ |   REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0);         \ | ||||||
|   REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0);       \ |   REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0);       \ | ||||||
| @ -583,6 +599,7 @@ INSTANTIATE_UNARY_KERNELS2(float, int); | |||||||
| INSTANTIATE_UNARY_KERNELS2(float, long); | INSTANTIATE_UNARY_KERNELS2(float, long); | ||||||
|  |  | ||||||
| #define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE)     \ | #define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE)     \ | ||||||
|  |   REGISTER_UNARY_OP(angle, DTYPE##2, DTYPE##2);   \ | ||||||
|   REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2);     \ |   REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2);     \ | ||||||
|   REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2);     \ |   REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2);     \ | ||||||
|   REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2);   \ |   REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2);   \ | ||||||
|  | |||||||
| @ -92,13 +92,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query, | |||||||
|           } |           } | ||||||
|  |  | ||||||
|           // upcasting to float32 if needed to improve precision when multiplying by the scale factor |           // upcasting to float32 if needed to improve precision when multiplying by the scale factor | ||||||
|           if ([maskedMM dataType] != MPSDataTypeFloat32) { |           maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32); | ||||||
|             maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil]; |  | ||||||
|           } |  | ||||||
|           maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil]; |           maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil]; | ||||||
|           if ([maskedMM dataType] != qTensor.dataType) { |  | ||||||
|             maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil]; |  | ||||||
|           } |  | ||||||
|  |  | ||||||
|           if (is_causal) { |           if (is_causal) { | ||||||
|             auto causalMask = [mpsGraph constantWithScalar:1.0f |             auto causalMask = [mpsGraph constantWithScalar:1.0f | ||||||
| @ -112,7 +107,9 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query, | |||||||
|                                                       name:nil]; |                                                       name:nil]; | ||||||
|           } else if (attn_mask) { |           } else if (attn_mask) { | ||||||
|             graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask); |             graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask); | ||||||
|             maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil]; |             maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM | ||||||
|  |                                            secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType) | ||||||
|  |                                                       name:nil]; | ||||||
|           } |           } | ||||||
|  |  | ||||||
|           // Account for case where all values were masked causing division by 0 in softmax (issue:#156707) |           // Account for case where all values were masked causing division by 0 in softmax (issue:#156707) | ||||||
| @ -133,8 +130,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query, | |||||||
|           graph->qTensor = qTensor; |           graph->qTensor = qTensor; | ||||||
|           graph->kTensor = kTensor; |           graph->kTensor = kTensor; | ||||||
|           graph->vTensor = vTensor; |           graph->vTensor = vTensor; | ||||||
|           graph->outputTensor = output; |           graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType); | ||||||
|           graph->attnTensor = sm; |           graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType); | ||||||
|         }); |         }); | ||||||
|     auto qPlaceholder = Placeholder(cachedGraph->qTensor, query); |     auto qPlaceholder = Placeholder(cachedGraph->qTensor, query); | ||||||
|     auto kPlaceholder = Placeholder(cachedGraph->kTensor, key); |     auto kPlaceholder = Placeholder(cachedGraph->kTensor, key); | ||||||
|  | |||||||
| @ -202,6 +202,10 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) { | |||||||
|   lib.exec_binary_kernel(iter, "igammac"); |   lib.exec_binary_kernel(iter, "igammac"); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static void hypot_mps_kernel(TensorIteratorBase& iter) { | ||||||
|  |   lib.exec_binary_kernel(iter, "hypot"); | ||||||
|  | } | ||||||
|  |  | ||||||
| REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel) | REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel) | ||||||
| REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel) | REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel) | ||||||
| REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel) | REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel) | ||||||
| @ -229,4 +233,5 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel) | |||||||
| REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel) | REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel) | ||||||
| REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel) | REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel) | ||||||
| REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel) | REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel) | ||||||
|  | REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel) | ||||||
| } // namespace at::native | } // namespace at::native | ||||||
|  | |||||||
| @ -16,7 +16,6 @@ | |||||||
| #include <ATen/ops/eq_native.h> | #include <ATen/ops/eq_native.h> | ||||||
| #include <ATen/ops/ge_native.h> | #include <ATen/ops/ge_native.h> | ||||||
| #include <ATen/ops/gt_native.h> | #include <ATen/ops/gt_native.h> | ||||||
| #include <ATen/ops/hypot_native.h> |  | ||||||
| #include <ATen/ops/le_native.h> | #include <ATen/ops/le_native.h> | ||||||
| #include <ATen/ops/logaddexp2_native.h> | #include <ATen/ops/logaddexp2_native.h> | ||||||
| #include <ATen/ops/logaddexp_native.h> | #include <ATen/ops/logaddexp_native.h> | ||||||
| @ -278,22 +277,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { |  | ||||||
|   mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { |  | ||||||
|     MPSGraph* mpsGraph = cachedGraph->graph(); |  | ||||||
|     MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType]; |  | ||||||
|     MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor |  | ||||||
|                                                                                      secondaryTensor:twoTensor |  | ||||||
|                                                                                                 name:nil] |  | ||||||
|                                                     secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor |  | ||||||
|                                                                                      secondaryTensor:twoTensor |  | ||||||
|                                                                                                 name:nil] |  | ||||||
|                                                                name:nil]; |  | ||||||
|     return [mpsGraph squareRootWithTensor:sumTensor name:nil]; |  | ||||||
|   }; |  | ||||||
|   mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { | TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { | ||||||
|   mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { |   mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { | ||||||
|     MPSGraph* mpsGraph = cachedGraph->graph(); |     MPSGraph* mpsGraph = cachedGraph->graph(); | ||||||
|  | |||||||
| @ -8,6 +8,9 @@ | |||||||
| #include <ATen/native/Resize.h> | #include <ATen/native/Resize.h> | ||||||
| #include <ATen/native/mps/MPSGraphSequoiaOps.h> | #include <ATen/native/mps/MPSGraphSequoiaOps.h> | ||||||
| #include <ATen/native/mps/OperationUtils.h> | #include <ATen/native/mps/OperationUtils.h> | ||||||
|  | #include <ATen/native/mps/kernels/LinearAlgebra.h> | ||||||
|  |  | ||||||
|  | #include <fmt/format.h> | ||||||
|  |  | ||||||
| #ifndef AT_PER_OPERATOR_HEADERS | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
| #include <ATen/Functions.h> | #include <ATen/Functions.h> | ||||||
| @ -28,6 +31,7 @@ | |||||||
| #include <ATen/ops/linalg_solve_triangular_native.h> | #include <ATen/ops/linalg_solve_triangular_native.h> | ||||||
| #include <ATen/ops/lu_unpack_native.h> | #include <ATen/ops/lu_unpack_native.h> | ||||||
| #include <ATen/ops/mm_native.h> | #include <ATen/ops/mm_native.h> | ||||||
|  | #include <ATen/ops/orgqr_native.h> | ||||||
| #include <ATen/ops/slice.h> | #include <ATen/ops/slice.h> | ||||||
| #include <ATen/ops/stack.h> | #include <ATen/ops/stack.h> | ||||||
| #include <ATen/ops/triangular_solve_native.h> | #include <ATen/ops/triangular_solve_native.h> | ||||||
| @ -338,6 +342,8 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, | |||||||
|           ". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details."); |           ". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details."); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   map_mps_decomposition_error_code_to_blas(info); | ||||||
| } | } | ||||||
|  |  | ||||||
| static void linalg_solve_out_mps_impl(const Tensor& A, | static void linalg_solve_out_mps_impl(const Tensor& A, | ||||||
| @ -1233,6 +1239,69 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) { | ||||||
|  |   if (self.numel() == 0) { | ||||||
|  |     return self; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   auto m = self.size(-2); | ||||||
|  |   auto n = self.size(-1); | ||||||
|  |   auto k = tau.size(-1); | ||||||
|  |  | ||||||
|  |   if (tau.numel() == 0) { | ||||||
|  |     auto I = eye(m, self.scalar_type(), std::nullopt, self.device()); | ||||||
|  |     return self.copy_(I.slice(-1, 0, n)); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   auto num_batch_dims = self.dim() - 2; | ||||||
|  |   auto batch_sizes = self.sizes().slice(0, num_batch_dims); | ||||||
|  |  | ||||||
|  |   std::vector<int64_t> H_sizes(num_batch_dims + 2); | ||||||
|  |   for (auto dim : c10::irange(num_batch_dims)) { | ||||||
|  |     H_sizes[dim] = self.size(dim); | ||||||
|  |   } | ||||||
|  |   H_sizes[num_batch_dims] = m; | ||||||
|  |   H_sizes[num_batch_dims + 1] = m; | ||||||
|  |  | ||||||
|  |   auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous)); | ||||||
|  |   auto H_prod = at::empty_like(H); | ||||||
|  |  | ||||||
|  |   OrgqrParams params; | ||||||
|  |  | ||||||
|  |   params.num_batch_dims = num_batch_dims; | ||||||
|  |   params.m = m; | ||||||
|  |   params.n = n; | ||||||
|  |   params.k = k; | ||||||
|  |  | ||||||
|  |   for (const auto dim : c10::irange(self.dim())) { | ||||||
|  |     params.A_strides[dim] = self.stride(dim); | ||||||
|  |  | ||||||
|  |     if (dim < tau.dim()) { | ||||||
|  |       params.tau_strides[dim] = tau.stride(dim); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     params.H_strides[dim] = H.stride(dim); | ||||||
|  |     params.H_sizes[dim] = H.size(dim); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   auto num_threads = H.numel(); | ||||||
|  |   MPSStream* stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|  |   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||||
|  |     @autoreleasepool { | ||||||
|  |       id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder(); | ||||||
|  |       auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self))); | ||||||
|  |       getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau}); | ||||||
|  |       [compute_encoder setComputePipelineState:pipeline_state]; | ||||||
|  |       mtl_setArgs(compute_encoder, self, tau, H, H_prod, params); | ||||||
|  |       mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads); | ||||||
|  |       getMPSProfiler().endProfileKernel(pipeline_state); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   return self; | ||||||
|  | } | ||||||
|  |  | ||||||
| } // namespace mps | } // namespace mps | ||||||
|  |  | ||||||
| Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) { | Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) { | ||||||
| @ -1448,20 +1517,6 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps) | |||||||
|   mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info); |   mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info); | ||||||
| } | } | ||||||
|  |  | ||||||
| std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) { |  | ||||||
|   Tensor info = at::empty({}, A.options().dtype(kInt)); |  | ||||||
|   mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false); |  | ||||||
|   return std::tie(LU, pivots); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) { |  | ||||||
|   Tensor LU = at::empty({0}, A.options()); |  | ||||||
|   Tensor pivots = at::empty({0}, A.options().dtype(kInt)); |  | ||||||
|   Tensor info = at::empty({}, A.options().dtype(kInt)); |  | ||||||
|   mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false); |  | ||||||
|   return std::make_tuple(std::move(LU), std::move(pivots)); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TORCH_IMPL_FUNC(lu_unpack_out_mps) | TORCH_IMPL_FUNC(lu_unpack_out_mps) | ||||||
| (const Tensor& LU_data, | (const Tensor& LU_data, | ||||||
|  const Tensor& LU_pivots, |  const Tensor& LU_pivots, | ||||||
| @ -1483,4 +1538,6 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const | |||||||
| } | } | ||||||
|  |  | ||||||
| REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl) | REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl) | ||||||
|  | REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl); | ||||||
|  |  | ||||||
| } // namespace at::native | } // namespace at::native | ||||||
|  | |||||||
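As a quick usage check for the new orgqr path (a hedged sketch, not part of the PR; it assumes an MPS-enabled build and computes the reflectors on CPU via geqrf before moving them to the device):

#include <ATen/ATen.h>

// Rebuild Q from geqrf-style Householder reflectors, running orgqr on MPS.
void householder_product_mps_demo() {
  auto A_cpu = at::randn({4, 3}, at::kFloat);
  auto [reflectors, tau] = at::geqrf(A_cpu);
  auto Q = at::linalg_householder_product(reflectors.to(at::kMPS),
                                          tau.to(at::kMPS));
  // Q has orthonormal columns: Q^T Q should equal I up to float tolerance.
}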
| @ -34,6 +34,7 @@ REGISTER_UNARY_TI_DISPATCH(sinc); | |||||||
| REGISTER_UNARY_TI_DISPATCH(sinh); | REGISTER_UNARY_TI_DISPATCH(sinh); | ||||||
| REGISTER_UNARY_TI_DISPATCH(cosh); | REGISTER_UNARY_TI_DISPATCH(cosh); | ||||||
| REGISTER_UNARY_TI_DISPATCH(tanh); | REGISTER_UNARY_TI_DISPATCH(tanh); | ||||||
|  | REGISTER_UNARY_TI_DISPATCH(angle); | ||||||
| REGISTER_UNARY_TI_DISPATCH(abs); | REGISTER_UNARY_TI_DISPATCH(abs); | ||||||
| REGISTER_UNARY_TI_DISPATCH(sin); | REGISTER_UNARY_TI_DISPATCH(sin); | ||||||
| REGISTER_UNARY_TI_DISPATCH(cos); | REGISTER_UNARY_TI_DISPATCH(cos); | ||||||
|  | |||||||
| @ -12,7 +12,6 @@ | |||||||
| #include <ATen/ops/_copy_from_and_resize.h> | #include <ATen/ops/_copy_from_and_resize.h> | ||||||
| #include <ATen/ops/acos_native.h> | #include <ATen/ops/acos_native.h> | ||||||
| #include <ATen/ops/acosh_native.h> | #include <ATen/ops/acosh_native.h> | ||||||
| #include <ATen/ops/angle_native.h> |  | ||||||
| #include <ATen/ops/asin_native.h> | #include <ATen/ops/asin_native.h> | ||||||
| #include <ATen/ops/asinh_native.h> | #include <ATen/ops/asinh_native.h> | ||||||
| #include <ATen/ops/atan_native.h> | #include <ATen/ops/atan_native.h> | ||||||
| @ -204,23 +203,6 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) { | |||||||
|   return output; |   return output; | ||||||
| } | } | ||||||
|  |  | ||||||
| Tensor& angle_out_mps(const Tensor& self, Tensor& output) { |  | ||||||
|   mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { |  | ||||||
|     auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil]; |  | ||||||
|     auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil]; |  | ||||||
|     return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil]; |  | ||||||
|   }); |  | ||||||
|   return output; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| Tensor angle_mps(const Tensor& self) { |  | ||||||
|   const auto float_type = c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) |  | ||||||
|       ? c10::typeMetaToScalarType(c10::get_default_dtype()) |  | ||||||
|       : c10::toRealValueType(self.scalar_type()); |  | ||||||
|   Tensor result = at::empty({0}, self.options().dtype(float_type)); |  | ||||||
|   return angle_out_mps(self, result); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) { | TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) { | ||||||
|   TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types"); |   TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types"); | ||||||
|   mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { |   mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { | ||||||
|  | |||||||
| @ -403,16 +403,14 @@ | |||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: angle |     CPU, CUDA, MPS: angle | ||||||
|     MPS: angle_mps |  | ||||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr |     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr | ||||||
|   tags: pointwise |   tags: pointwise | ||||||
|  |  | ||||||
| - func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | - func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: angle_out |     CPU, CUDA, MPS: angle_out | ||||||
|     MPS: angle_out_mps |  | ||||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr_out |     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr_out | ||||||
|   tags: pointwise |   tags: pointwise | ||||||
|  |  | ||||||
| @ -10042,8 +10040,7 @@ | |||||||
|   structured: True |   structured: True | ||||||
|   structured_inherits: TensorIteratorBase |   structured_inherits: TensorIteratorBase | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: hypot_out |     CPU, CUDA, MPS: hypot_out | ||||||
|     MPS: hypot_out_mps |  | ||||||
|   tags: pointwise |   tags: pointwise | ||||||
|  |  | ||||||
| - func: hypot(Tensor self, Tensor other) -> Tensor | - func: hypot(Tensor self, Tensor other) -> Tensor | ||||||
| @ -14157,16 +14154,10 @@ | |||||||
| - func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) | - func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) | ||||||
|   python_module: linalg |   python_module: linalg | ||||||
|   variants: function |   variants: function | ||||||
|   dispatch: |  | ||||||
|     CompositeImplicitAutograd: linalg_lu_factor |  | ||||||
|     MPS: linalg_lu_factor_mps |  | ||||||
|  |  | ||||||
| - func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) | - func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) | ||||||
|   python_module: linalg |   python_module: linalg | ||||||
|   variants: function |   variants: function | ||||||
|   dispatch: |  | ||||||
|     CompositeImplicitAutograd: linalg_lu_factor_out |  | ||||||
|     MPS: linalg_lu_factor_out_mps |  | ||||||
|  |  | ||||||
| - func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) | - func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) | ||||||
|   python_module: linalg |   python_module: linalg | ||||||
| @ -14368,12 +14359,12 @@ | |||||||
|   python_module: linalg |   python_module: linalg | ||||||
|   variants: function |   variants: function | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: linalg_householder_product |     CPU, CUDA, MPS: linalg_householder_product | ||||||
|  |  | ||||||
| - func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!) | - func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   python_module: linalg |   python_module: linalg | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: linalg_householder_product_out |     CPU, CUDA, MPS: linalg_householder_product_out | ||||||
|  |  | ||||||
| - func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info) | - func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info) | ||||||
|   python_module: linalg |   python_module: linalg | ||||||
|  | |||||||
| @ -575,24 +575,9 @@ void spmm( | |||||||
|   cusparseOperation_t opB = transpose_B ? CUSPARSE_OPERATION_TRANSPOSE |   cusparseOperation_t opB = transpose_B ? CUSPARSE_OPERATION_TRANSPOSE | ||||||
|                                         : CUSPARSE_OPERATION_NON_TRANSPOSE; |                                         : CUSPARSE_OPERATION_NON_TRANSPOSE; | ||||||
|  |  | ||||||
|   // CUDA < 11.0 doesn't support 64-bit indices and doesn't raise an error about this |  | ||||||
|   // silently returning incorrect results |  | ||||||
| #if defined(USE_ROCM) && (ROCM_VERSION < 60300) |  | ||||||
|   auto mat1_32 = at::native::_sparse_csr_tensor_unsafe( |  | ||||||
|       mat1.crow_indices().to(kInt), |  | ||||||
|       mat1.col_indices().to(kInt), |  | ||||||
|       mat1.values(), |  | ||||||
|       mat1.sizes(), |  | ||||||
|       mat1.scalar_type(), |  | ||||||
|       mat1.layout(), |  | ||||||
|       mat1.device()); |  | ||||||
|   auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1_32); |  | ||||||
|   auto algorithm = CUSPARSE_MM_ALG_DEFAULT; |  | ||||||
| #else // defined(USE_ROCM) && (ROCM_VERSION < 60300) |  | ||||||
|   // TODO: update this to support COO sparse layout |   // TODO: update this to support COO sparse layout | ||||||
|   auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1); |   auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1); | ||||||
|   auto algorithm = CUSPARSE_SPMM_CSR_ALG2; |   auto algorithm = CUSPARSE_SPMM_CSR_ALG2; | ||||||
| #endif // defined(USE_ROCM) && (ROCM_VERSION < 60300) |  | ||||||
|  |  | ||||||
|   auto descB = at::cuda::sparse::CuSparseConstDnMatDescriptor( |   auto descB = at::cuda::sparse::CuSparseConstDnMatDescriptor( | ||||||
|       transpose_B ? mat2_->mT() : *mat2_); |       transpose_B ? mat2_->mT() : *mat2_); | ||||||
|  | |||||||
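The deleted branch worked around old ROCm releases by rebuilding the CSR operand with 32-bit indices before creating the cuSPARSE descriptor. Roughly, in Python terms (a sketch of the index-narrowing step, not of the internal `_sparse_csr_tensor_unsafe` call):

```python
import torch

m = torch.randn(4, 4).relu().to_sparse_csr()
m32 = torch.sparse_csr_tensor(
    m.crow_indices().to(torch.int32),   # narrow row pointers to int32
    m.col_indices().to(torch.int32),    # narrow column indices to int32
    m.values(),
    size=tuple(m.shape),
)
print(m32.crow_indices().dtype)  # torch.int32
```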
| @@ -40,15 +40,7 @@ | |||||||
| #include <thrust/iterator/discard_iterator.h> | #include <thrust/iterator/discard_iterator.h> | ||||||
|  |  | ||||||
|  |  | ||||||
| #if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300)) |  | ||||||
| #define IS_CUSPARSE11_AVAILABLE() 1 |  | ||||||
| #else |  | ||||||
| #define IS_CUSPARSE11_AVAILABLE() 0 |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #if IS_CUSPARSE11_AVAILABLE() |  | ||||||
| #include <library_types.h> | #include <library_types.h> | ||||||
| #endif |  | ||||||
|  |  | ||||||
| namespace at::native { | namespace at::native { | ||||||
|  |  | ||||||
| @@ -103,17 +95,9 @@ struct csrMatrixRef { | |||||||
|   int nnz_{0}; |   int nnz_{0}; | ||||||
|   std::vector<int> size_{}; |   std::vector<int> size_{}; | ||||||
|  |  | ||||||
|   #if IS_CUSPARSE11_AVAILABLE() |   cusparseSpMatDescr_t description_{0}; | ||||||
|     cusparseSpMatDescr_t description_{0}; |  | ||||||
|   #else |  | ||||||
|     cusparseMatDescr_t description_{0}; |  | ||||||
|   #endif |  | ||||||
|  |  | ||||||
|   csrMatrixRef() { |   csrMatrixRef() = default; | ||||||
|     #if !IS_CUSPARSE11_AVAILABLE() |  | ||||||
|       create_general_description_(description_); |  | ||||||
|     #endif |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   csrMatrixRef( |   csrMatrixRef( | ||||||
|       int* csr_indices, |       int* csr_indices, | ||||||
| @@ -126,7 +110,6 @@ struct csrMatrixRef { | |||||||
|         csr_values_{csr_values}, |         csr_values_{csr_values}, | ||||||
|         nnz_{nnz}, |         nnz_{nnz}, | ||||||
|         size_{size} { |         size_{size} { | ||||||
|     #if IS_CUSPARSE11_AVAILABLE() |  | ||||||
|       cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>(); |       cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>(); | ||||||
|       TORCH_CUDASPARSE_CHECK(cusparseCreateCsr( |       TORCH_CUDASPARSE_CHECK(cusparseCreateCsr( | ||||||
|         &description_, |         &description_, | ||||||
| @@ -140,17 +123,10 @@ | |||||||
|         CUSPARSE_INDEX_32I, |         CUSPARSE_INDEX_32I, | ||||||
|         CUSPARSE_INDEX_BASE_ZERO, |         CUSPARSE_INDEX_BASE_ZERO, | ||||||
|         cuda_data_type)); |         cuda_data_type)); | ||||||
|     #else |  | ||||||
|       create_general_description_(description_); |  | ||||||
|     #endif |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   ~csrMatrixRef() { |   ~csrMatrixRef() { | ||||||
|     #if IS_CUSPARSE11_AVAILABLE() |     cusparseDestroySpMat(description_); | ||||||
|       cusparseDestroySpMat(description_); |  | ||||||
|     #else |  | ||||||
|       cusparseDestroyMatDescr(description_); |  | ||||||
|     #endif |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   int size(int index) const { |   int size(int index) const { | ||||||
| @@ -196,8 +172,6 @@ struct csrOutput { | |||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| #if IS_CUSPARSE11_AVAILABLE() |  | ||||||
|  |  | ||||||
| // RAII guard helps to support cuSparse 11 API for `A @ B` operation | // RAII guard helps to support cuSparse 11 API for `A @ B` operation | ||||||
| // This generic template exists because with cuSparse the `scalar_t` type could be a double or float | // This generic template exists because with cuSparse the `scalar_t` type could be a double or float | ||||||
| template <class scalar_t> | template <class scalar_t> | ||||||
| @@ -396,284 +370,6 @@ template struct CusparseMatrixMultiplyOp<float>; | |||||||
|  |  | ||||||
| template struct CusparseMatrixMultiplyOp<double>; | template struct CusparseMatrixMultiplyOp<double>; | ||||||
|  |  | ||||||
| #else // if not IS_CUSPARSE11_AVAILABLE() |  | ||||||
|  |  | ||||||
| using DcsrMatrixRef = csrMatrixRef<double>; |  | ||||||
| using ScsrMatrixRef = csrMatrixRef<float>; |  | ||||||
|  |  | ||||||
| // RAII guard helps to support cuSparse 10 API for `A @ B` operation |  | ||||||
| // This generic template exists because with cuSparse the `scalar_t` type could be a double or float |  | ||||||
| template <class scalar_t> |  | ||||||
| struct CusparseMatrixMultiplyOp { |  | ||||||
|   csrOutput operator()( |  | ||||||
|       const csrMatrixRef<scalar_t>& lhs, |  | ||||||
|       const csrMatrixRef<scalar_t>& rhs, |  | ||||||
|       Tensor &output_values, |  | ||||||
|       Tensor &output_indices) |  | ||||||
|   { |  | ||||||
|     static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double."); |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| // Specialization for `A @ B` operation for double values with cuSparse |  | ||||||
| template<> struct CusparseMatrixMultiplyOp<double> { |  | ||||||
|   csrgemm2Info_t gemm2Info_; |  | ||||||
|  |  | ||||||
|   CusparseMatrixMultiplyOp() { |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_)); |  | ||||||
|   } |  | ||||||
|   ~CusparseMatrixMultiplyOp() { |  | ||||||
|     cusparseDestroyCsrgemm2Info(gemm2Info_); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   csrOutput operator ()( |  | ||||||
|       const DcsrMatrixRef& lhs, |  | ||||||
|       const DcsrMatrixRef& rhs, |  | ||||||
|       Tensor &output_values, |  | ||||||
|       Tensor &output_indices) { |  | ||||||
|     double alpha = 1.0; |  | ||||||
|     DcsrMatrixRef empty; |  | ||||||
|     return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   csrOutput Dgemm2( |  | ||||||
|       const DcsrMatrixRef& A, |  | ||||||
|       const DcsrMatrixRef& B, |  | ||||||
|       const DcsrMatrixRef& C, |  | ||||||
|       const double* alpha, |  | ||||||
|       const double* beta, |  | ||||||
|       Tensor &output_values, |  | ||||||
|       Tensor &output_indices) { |  | ||||||
|     void* buffer_{nullptr}; |  | ||||||
|     cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle(); |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST)); |  | ||||||
|  |  | ||||||
|     csrOutput out({A.size(0), B.size(1)}); |  | ||||||
|     int innerSize = confirm_mult_size(A.size_, B.size_); |  | ||||||
|     out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt)); |  | ||||||
|  |  | ||||||
|     // Compute needed buffer size |  | ||||||
|     size_t new_bubber_sz; |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt( |  | ||||||
|         cusparseHandle_, |  | ||||||
|         out.size(0), |  | ||||||
|         out.size(1), |  | ||||||
|         innerSize, |  | ||||||
|         alpha, |  | ||||||
|         A.description_, |  | ||||||
|         A.nnz_, |  | ||||||
|         A.csr_pointers_, |  | ||||||
|         A.csr_indices_, |  | ||||||
|         B.description_, |  | ||||||
|         B.nnz_, |  | ||||||
|         B.csr_pointers_, |  | ||||||
|         B.csr_indices_, |  | ||||||
|         beta, |  | ||||||
|         C.description_, |  | ||||||
|         C.nnz_, |  | ||||||
|         C.csr_pointers_, |  | ||||||
|         C.csr_indices_, |  | ||||||
|         gemm2Info_, |  | ||||||
|         &new_bubber_sz)); |  | ||||||
|  |  | ||||||
|     // (Re)allocate buffer if needed |  | ||||||
|     auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); |  | ||||||
|     at::DataPtr data_ptr = allocator.allocate(new_bubber_sz); |  | ||||||
|     buffer_ = data_ptr.get(); |  | ||||||
|  |  | ||||||
|     // Find the resulting non-zero pattern. |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz( |  | ||||||
|         cusparseHandle_, |  | ||||||
|         out.size(0), |  | ||||||
|         out.size(1), |  | ||||||
|         innerSize, |  | ||||||
|         A.description_, |  | ||||||
|         A.nnz_, |  | ||||||
|         A.csr_pointers_, |  | ||||||
|         A.csr_indices_, |  | ||||||
|         B.description_, |  | ||||||
|         B.nnz_, |  | ||||||
|         B.csr_pointers_, |  | ||||||
|         B.csr_indices_, |  | ||||||
|         C.description_, |  | ||||||
|         C.nnz_, |  | ||||||
|         C.csr_pointers_, |  | ||||||
|         C.csr_indices_, |  | ||||||
|         out.description_, |  | ||||||
|         out.csr_pointers_.data_ptr<int>(), |  | ||||||
|         &out.nnz_, |  | ||||||
|         gemm2Info_, |  | ||||||
|         buffer_)); |  | ||||||
|  |  | ||||||
|     out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt)); |  | ||||||
|     out.csr_values_ = at::empty({out.nnz_}, output_values.options()); |  | ||||||
|  |  | ||||||
|     // Perform the gemm2 operation for doubles |  | ||||||
|     // out = alpha ∗ A ∗ B + beta ∗ C |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2( |  | ||||||
|         cusparseHandle_, |  | ||||||
|         out.size(0), |  | ||||||
|         out.size(1), |  | ||||||
|         innerSize, |  | ||||||
|         alpha, |  | ||||||
|         A.description_, |  | ||||||
|         A.nnz_, |  | ||||||
|         A.csr_values_, |  | ||||||
|         A.csr_pointers_, |  | ||||||
|         A.csr_indices_, |  | ||||||
|         B.description_, |  | ||||||
|         B.nnz_, |  | ||||||
|         B.csr_values_, |  | ||||||
|         B.csr_pointers_, |  | ||||||
|         B.csr_indices_, |  | ||||||
|         beta, |  | ||||||
|         C.description_, |  | ||||||
|         C.nnz_, |  | ||||||
|         C.csr_values_, |  | ||||||
|         C.csr_pointers_, |  | ||||||
|         C.csr_indices_, |  | ||||||
|         out.description_, |  | ||||||
|         out.csr_values_.data_ptr<double>(), |  | ||||||
|         out.csr_pointers_.data_ptr<int>(), |  | ||||||
|         out.csr_indices_.data_ptr<int>(), |  | ||||||
|         gemm2Info_, |  | ||||||
|         buffer_)); |  | ||||||
|     return out; |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| // Specialization for `A @ B` operation for float values with cuSparse |  | ||||||
| template<> struct CusparseMatrixMultiplyOp<float> { |  | ||||||
|   csrgemm2Info_t gemm2Info_; |  | ||||||
|  |  | ||||||
|   CusparseMatrixMultiplyOp() { |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_)); |  | ||||||
|  |  | ||||||
|   } |  | ||||||
|   ~CusparseMatrixMultiplyOp() { |  | ||||||
|     cusparseDestroyCsrgemm2Info(gemm2Info_); |  | ||||||
|   } |  | ||||||
|   csrOutput operator()( |  | ||||||
|       const ScsrMatrixRef& lhs, |  | ||||||
|       const ScsrMatrixRef& rhs, |  | ||||||
|       Tensor &output_values, |  | ||||||
|       Tensor &output_indices) { |  | ||||||
|     float alpha = 1.0; |  | ||||||
|     ScsrMatrixRef empty; |  | ||||||
|     return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   csrOutput Sgemm2( |  | ||||||
|       const ScsrMatrixRef& A, |  | ||||||
|       const ScsrMatrixRef& B, |  | ||||||
|       const ScsrMatrixRef& C, |  | ||||||
|       const float* alpha, |  | ||||||
|       const float* beta, |  | ||||||
|       Tensor &output_values, |  | ||||||
|       Tensor &output_indices) { |  | ||||||
|     void* buffer_{nullptr}; |  | ||||||
|     cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle(); |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST)); |  | ||||||
|  |  | ||||||
|     csrOutput out({A.size(0), B.size(1)}); |  | ||||||
|  |  | ||||||
|     int innerSize = confirm_mult_size(A.size_, B.size_); |  | ||||||
|  |  | ||||||
|     out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt)); |  | ||||||
|  |  | ||||||
|     // Compute needed buffer size |  | ||||||
|     size_t new_bubber_sz; |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt( |  | ||||||
|         cusparseHandle_, |  | ||||||
|         out.size(0), |  | ||||||
|         out.size(1), |  | ||||||
|         innerSize, |  | ||||||
|         alpha, |  | ||||||
|         A.description_, |  | ||||||
|         A.nnz_, |  | ||||||
|         A.csr_pointers_, |  | ||||||
|         A.csr_indices_, |  | ||||||
|         B.description_, |  | ||||||
|         B.nnz_, |  | ||||||
|         B.csr_pointers_, |  | ||||||
|         B.csr_indices_, |  | ||||||
|         beta, |  | ||||||
|         C.description_, |  | ||||||
|         C.nnz_, |  | ||||||
|         C.csr_pointers_, |  | ||||||
|         C.csr_indices_, |  | ||||||
|         gemm2Info_, |  | ||||||
|         &new_bubber_sz)); |  | ||||||
|  |  | ||||||
|     auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); |  | ||||||
|     at::DataPtr data_ptr = allocator.allocate(new_bubber_sz); |  | ||||||
|     buffer_ = data_ptr.get(); |  | ||||||
|  |  | ||||||
|     // Find the resulting non-zero pattern. |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz( |  | ||||||
|         cusparseHandle_, |  | ||||||
|         out.size(0), |  | ||||||
|         out.size(1), |  | ||||||
|         innerSize, |  | ||||||
|         A.description_, |  | ||||||
|         A.nnz_, |  | ||||||
|         A.csr_pointers_, |  | ||||||
|         A.csr_indices_, |  | ||||||
|         B.description_, |  | ||||||
|         B.nnz_, |  | ||||||
|         B.csr_pointers_, |  | ||||||
|         B.csr_indices_, |  | ||||||
|         C.description_, |  | ||||||
|         C.nnz_, |  | ||||||
|         C.csr_pointers_, |  | ||||||
|         C.csr_indices_, |  | ||||||
|         out.description_, |  | ||||||
|         out.csr_pointers_.data_ptr<int>(), |  | ||||||
|         &out.nnz_, |  | ||||||
|         gemm2Info_, |  | ||||||
|         buffer_)); |  | ||||||
|  |  | ||||||
|     out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt)); |  | ||||||
|     out.csr_values_ = at::empty({out.nnz_}, output_values.options()); |  | ||||||
|  |  | ||||||
|     // Perform the gemm2 operation for floats |  | ||||||
|     // out = alpha ∗ A ∗ B + beta ∗ C |  | ||||||
|     TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2( |  | ||||||
|         cusparseHandle_, |  | ||||||
|         out.size(0), |  | ||||||
|         out.size(1), |  | ||||||
|         innerSize, |  | ||||||
|         alpha, |  | ||||||
|         A.description_, |  | ||||||
|         A.nnz_, |  | ||||||
|         A.csr_values_, |  | ||||||
|         A.csr_pointers_, |  | ||||||
|         A.csr_indices_, |  | ||||||
|         B.description_, |  | ||||||
|         B.nnz_, |  | ||||||
|         B.csr_values_, |  | ||||||
|         B.csr_pointers_, |  | ||||||
|         B.csr_indices_, |  | ||||||
|         beta, |  | ||||||
|         C.description_, |  | ||||||
|         C.nnz_, |  | ||||||
|         C.csr_values_, |  | ||||||
|         C.csr_pointers_, |  | ||||||
|         C.csr_indices_, |  | ||||||
|         out.description_, |  | ||||||
|         out.csr_values_.data_ptr<float>(), |  | ||||||
|         out.csr_pointers_.data_ptr<int>(), |  | ||||||
|         out.csr_indices_.data_ptr<int>(), |  | ||||||
|         gemm2Info_, |  | ||||||
|         buffer_)); |  | ||||||
|     return out; |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| #endif // IS_CUSPARSE11_AVAILABLE() |  | ||||||
|  |  | ||||||
| template <typename scalar_t> | template <typename scalar_t> | ||||||
| void sparse_sparse_matmul_cuda_kernel( | void sparse_sparse_matmul_cuda_kernel( | ||||||
|     Tensor& result, |     Tensor& result, | ||||||
| @@ -815,19 +511,15 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) { | |||||||
|   auto output = at::native::empty_like(mat1_); |   auto output = at::native::empty_like(mat1_); | ||||||
|   output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0); |   output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0); | ||||||
|  |  | ||||||
| #if IS_CUSPARSE11_AVAILABLE() && !defined(USE_ROCM) | #if !defined(USE_ROCM) | ||||||
|   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] { |   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] { | ||||||
|       sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce()); |       sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce()); | ||||||
|   }); |   }); | ||||||
| #elif IS_CUSPARSE11_AVAILABLE() && defined(USE_ROCM) | #else | ||||||
|   // ROCm does not support half and bfloat16 types for sparse_matmul |   // ROCm does not support half and bfloat16 types for sparse_matmul | ||||||
|   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] { |   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] { | ||||||
|       sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce()); |       sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce()); | ||||||
|   }); |   }); | ||||||
| #else |  | ||||||
|   AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] { |  | ||||||
|     sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce()); |  | ||||||
|   }); |  | ||||||
| #endif | #endif | ||||||
|   return output; |   return output; | ||||||
| } | } | ||||||
|  | |||||||
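The net effect for users, sketched below (assumes a CUDA build): COO sparse-sparse matmul now always takes the cuSPARSE 11 code path, with half/bfloat16 dispatched on CUDA and excluded only on ROCm.

```python
import torch

if torch.cuda.is_available():
    a = torch.randn(8, 8, device="cuda").relu().to_sparse()
    b = torch.randn(8, 8, device="cuda").relu().to_sparse()
    c = torch.sparse.mm(a, b)             # sparse @ sparse -> sparse result
    ref = a.to_dense() @ b.to_dense()
    print(torch.allclose(c.to_dense(), ref, atol=1e-5))
```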
| @@ -33,7 +33,7 @@ using namespace mps; | |||||||
| #ifndef PYTORCH_JIT_COMPILE_SHADERS | #ifndef PYTORCH_JIT_COMPILE_SHADERS | ||||||
| static auto& lib = MetalShaderLibrary::getBundledLibrary(); | static auto& lib = MetalShaderLibrary::getBundledLibrary(); | ||||||
| #else | #else | ||||||
| #include <ATen/native/mps/Mul_metallib.h> | #include <ATen/native/mps/SparseTensorMath_metallib.h> | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| static Tensor& s_addmm_out_sparse_dense_mps( | static Tensor& s_addmm_out_sparse_dense_mps( | ||||||
| @@ -369,12 +369,7 @@ static SparseTensor& mul_out_dense_sparse_mps( | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (scalar_like) { |   if (scalar_like) { | ||||||
|     auto scalar = dense; |     auto out_vals = values.mul(dense.to(values.options())); | ||||||
|     if (dense.numel() == 1 && dense.dim() > 0) { |  | ||||||
|       scalar = dense.view({}); |  | ||||||
|     } |  | ||||||
|     scalar = scalar.to(values.options()); |  | ||||||
|     auto out_vals = values.mul(scalar); |  | ||||||
|     if (out.scalar_type() != commonDtype) { |     if (out.scalar_type() != commonDtype) { | ||||||
|       out_vals = out_vals.to(out.scalar_type()); |       out_vals = out_vals.to(out.scalar_type()); | ||||||
|     } |     } | ||||||
| @@ -508,14 +503,14 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen | |||||||
|   const auto device = r_.device(); |   const auto device = r_.device(); | ||||||
|   auto stream = getCurrentMPSStream(); |   auto stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   auto lhs_indices = lhs._indices(); |   auto lhs_indices = lhs._indices().contiguous(); | ||||||
|   auto rhs_indices = rhs._indices(); |   auto rhs_indices = rhs._indices().contiguous(); | ||||||
|   auto lhs_values  = lhs._values().to(commonDtype); |   auto lhs_values  = lhs._values().to(commonDtype).contiguous(); | ||||||
|   auto rhs_values  = rhs._values().to(commonDtype); |   auto rhs_values  = rhs._values().to(commonDtype).contiguous(); | ||||||
|  |  | ||||||
|   // Flatten sparse indices to keys |   // Flatten sparse indices to keys | ||||||
|   auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes()); |   auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i)); | ||||||
|   auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes()); |   auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i)); | ||||||
|  |  | ||||||
|   // Intersect sorted keys (search the shorter in the longer) |   // Intersect sorted keys (search the shorter in the longer) | ||||||
|   const bool A_is_lhs = (lhs_nnz <= rhs_nnz); |   const bool A_is_lhs = (lhs_nnz <= rhs_nnz); | ||||||
| @@ -546,35 +541,54 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen | |||||||
|   auto out_indices = at::empty({ndim_i, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong)); |   auto out_indices = at::empty({ndim_i, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong)); | ||||||
|   auto lhs_match = outA_idx.narrow(0, 0, M); |   auto lhs_match = outA_idx.narrow(0, 0, M); | ||||||
|   auto rhs_match = outB_idx.narrow(0, 0, M); |   auto rhs_match = outB_idx.narrow(0, 0, M); | ||||||
|   auto out_val_sizes = lhs_values.sizes().vec(); |   auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec(); | ||||||
|   out_val_sizes[0] = static_cast<int64_t>(M); |   int64_t cols64 = 1; | ||||||
|  |   for (auto s : dense_sizes_vec) cols64 *= s; | ||||||
|  |   const uint32_t cols = static_cast<uint32_t>(std::max<int64_t>(cols64, 1)); | ||||||
|  |  | ||||||
|  |   auto to2d = [&](Tensor t, int64_t nnz) -> Tensor { | ||||||
|  |     const int64_t t_cols = t.numel() / nnz; | ||||||
|  |     if (t_cols == cols64) { | ||||||
|  |       return t.view({nnz, cols64}); | ||||||
|  |     } | ||||||
|  |     return t.view({nnz, 1}).expand({nnz, cols64}).contiguous(); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   // Make both sides 2d [nnz, cols] buffers so the kernel can index them | ||||||
|  |   auto lhs_vals2d = to2d(lhs_values, lhs_nnz); | ||||||
|  |   auto rhs_vals2d = to2d(rhs_values, rhs_nnz); | ||||||
|  |  | ||||||
|  |   std::vector<int64_t> out_val_sizes; | ||||||
|  |   out_val_sizes.reserve(1 + dense_sizes_vec.size()); | ||||||
|  |   out_val_sizes.push_back(static_cast<int64_t>(M)); | ||||||
|  |   out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end()); | ||||||
|   auto out_values = at::empty(out_val_sizes, lhs_values.options()); |   auto out_values = at::empty(out_val_sizes, lhs_values.options()); | ||||||
|  |  | ||||||
|   const uint32_t cols = static_cast<uint32_t>( |   if (M > 0) { | ||||||
|       lhs_values.numel() / std::max<int64_t>(1, lhs_nnz)); |     dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||||
|  |       @autoreleasepool { | ||||||
|  |         auto pso = lib.getPipelineStateForFunc( | ||||||
|  |             "fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values)); | ||||||
|  |         auto enc = stream->commandEncoder(); | ||||||
|  |         [enc setComputePipelineState:pso]; | ||||||
|  |  | ||||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { |         const uint32_t tew = pso.threadExecutionWidth; | ||||||
|     @autoreleasepool { |         const uint32_t gridW = std::max<uint32_t>(cols, 1u); | ||||||
|       auto pso = lib.getPipelineStateForFunc( |         const uint32_t tgW = std::min(gridW, tew); | ||||||
|           "fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values)); |         MTLSize grid = MTLSizeMake(gridW, 1, M); | ||||||
|       auto enc = stream->commandEncoder(); |         MTLSize tgs  = MTLSizeMake(tgW, 1, 1); | ||||||
|       [enc setComputePipelineState:pso]; |  | ||||||
|  |  | ||||||
|       const uint32_t tew  = pso.threadExecutionWidth; |         mtl_setArgs(enc, | ||||||
|       uint32_t tgW = std::min(cols, tew); |                     lhs_vals2d, rhs_vals2d, | ||||||
|       MTLSize grid = MTLSizeMake(cols, 1, M); |                     lhs_match, rhs_match, | ||||||
|       MTLSize tgs  = MTLSizeMake(tgW, 1, 1); |                     lhs_indices, out_indices, | ||||||
|  |                     out_values, | ||||||
|       mtl_setArgs(enc, |                     std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)}, | ||||||
|                   lhs_values, rhs_values, |                     std::array<uint32_t, 2>{M, cols}); | ||||||
|                   lhs_match, rhs_match, |         [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||||
|                   lhs_indices, out_indices, |       } | ||||||
|                   out_values, |     }); | ||||||
|                   std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)}, |   } | ||||||
|                   std::array<uint32_t, 2>{M, cols}); |  | ||||||
|       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; |  | ||||||
|     } |  | ||||||
|   }); |  | ||||||
|  |  | ||||||
|   if (r_.scalar_type() != commonDtype) { |   if (r_.scalar_type() != commonDtype) { | ||||||
|     out_values = out_values.to(r_.scalar_type()); |     out_values = out_values.to(r_.scalar_type()); | ||||||
|  | |||||||
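The reworked host code above mainly targets hybrid sparse tensors, where each nnz entry carries trailing dense dimensions; the new `to2d()` helper flattens (or broadcasts) the values into an `[nnz, cols]` view for the kernel. A sketch of the input shape in question (hedged: assumes an MPS build with sparse support):

```python
import torch

if torch.backends.mps.is_available():
    idx = torch.tensor([[0, 2]])                  # sparse_dim = 1, nnz = 2
    a = torch.sparse_coo_tensor(idx, torch.randn(2, 3), (4, 3), device="mps")
    b = torch.sparse_coo_tensor(idx, torch.randn(2, 3), (4, 3), device="mps")
    out = a.coalesce() * b.coalesce()             # values multiplied per [nnz, 3] row
    print(out.to_dense().cpu().shape)             # torch.Size([4, 3])
```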
| @@ -62,7 +62,6 @@ kernel void build_row_ptr_from_sorted_rows_by_batch( | |||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
| kernel void spmm_bmm_coo_rows_grouped( | kernel void spmm_bmm_coo_rows_grouped( | ||||||
|     device const long*   rows      [[buffer(0)]], |  | ||||||
|     device const long*   cols      [[buffer(1)]], |     device const long*   cols      [[buffer(1)]], | ||||||
|     device const T*      vals      [[buffer(2)]], |     device const T*      vals      [[buffer(2)]], | ||||||
|     device const T*      dense     [[buffer(3)]], |     device const T*      dense     [[buffer(3)]], | ||||||
| @@ -73,7 +72,6 @@ kernel void spmm_bmm_coo_rows_grouped( | |||||||
|     uint3                ltid      [[thread_position_in_threadgroup]], |     uint3                ltid      [[thread_position_in_threadgroup]], | ||||||
|     uint3                tptg      [[threads_per_threadgroup]]) |     uint3                tptg      [[threads_per_threadgroup]]) | ||||||
| { | { | ||||||
|   const uint B = dims.x; |  | ||||||
|   const uint I = dims.y; |   const uint I = dims.y; | ||||||
|   const uint J = dims.z; |   const uint J = dims.z; | ||||||
|   const uint K = dims.w; |   const uint K = dims.w; | ||||||
| @@ -197,9 +195,9 @@ kernel void fused_gather_mul_kernel( | |||||||
|     const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col; |     const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col; | ||||||
|     const ulong offO = (ulong)k  * (ulong)view_cols + (ulong)col; |     const ulong offO = (ulong)k  * (ulong)view_cols + (ulong)col; | ||||||
|  |  | ||||||
|     const float a = (float)lhs_vals[offL]; |     const auto a = static_cast<accum_t<T>>(lhs_vals[offL]); | ||||||
|     const float b = (float)rhs_vals[offR]; |     const auto b = static_cast<accum_t<T>>(rhs_vals[offR]); | ||||||
|     out_vals[offO] = (T)(a * b); |     out_vals[offO] = static_cast<T>(mul(a, b)); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // One thread per match copies the indices column |   // One thread per match copies the indices column | ||||||
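The kernel change above swaps the hard-coded `float` accumulator for `accum_t<T>`, so double (and other wide) inputs are no longer squeezed through float32. The precision loss the old cast incurred, illustrated numerically:

```python
import torch

a = torch.tensor([1.0 / 3.0], dtype=torch.float64)
b = torch.tensor([1.0 / 3.0], dtype=torch.float64)
exact = a * b                               # multiply in float64
lossy = (a.float() * b.float()).double()    # old behavior: round-trip via float32
print((exact - lossy).abs().item())         # ~4e-9: error introduced by float32
```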
| @@ -321,7 +319,6 @@ INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL); | |||||||
| #define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE)                         \ | #define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE)                         \ | ||||||
|   template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void    \ |   template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void    \ | ||||||
|   spmm_bmm_coo_rows_grouped<DTYPE>(                                          \ |   spmm_bmm_coo_rows_grouped<DTYPE>(                                          \ | ||||||
|       device const long*   rows      [[buffer(0)]],                          \ |  | ||||||
|       device const long*   cols      [[buffer(1)]],                          \ |       device const long*   cols      [[buffer(1)]],                          \ | ||||||
|       device const DTYPE*  vals      [[buffer(2)]],                          \ |       device const DTYPE*  vals      [[buffer(2)]],                          \ | ||||||
|       device const DTYPE*  dense     [[buffer(3)]],                          \ |       device const DTYPE*  dense     [[buffer(3)]],                          \ | ||||||
| @@ -76,14 +76,21 @@ bool priority_order_init_ = false; | |||||||
| // TODO(eqy): more benchmarking to determine whether this should include sm86/89 | // TODO(eqy): more benchmarking to determine whether this should include sm86/89 | ||||||
| // Needs to be kept in sync with test_fused_choice in test_transformers.py | // Needs to be kept in sync with test_fused_choice in test_transformers.py | ||||||
| bool check_prefer_cudnn_attention() { | bool check_prefer_cudnn_attention() { | ||||||
|   static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") != false; |   static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_DEPRIORITIZED") != true; | ||||||
|   if (!prefer_cudnn) { |   if (!prefer_cudnn) { | ||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
| #if (defined(CUDNN_VERSION) && (CUDNN_VERSION >= 90900)) | #if (defined(CUDNN_VERSION) && (CUDNN_VERSION >= 90900)) | ||||||
|   auto dprops = at::cuda::getCurrentDeviceProperties(); |   try { | ||||||
|   auto major = dprops->major; |     auto dprops = at::cuda::getCurrentDeviceProperties(); | ||||||
|   return (major == 9 || major == 10) && !dprops->minor; |     auto major = dprops->major; | ||||||
|  |     return (major == 9 || major == 10) && !dprops->minor; | ||||||
|  |   } catch (c10::Error const& e) { | ||||||
|  | #ifdef DEBUG | ||||||
|  |     TORCH_WARN("check_prefer_cudnn_attention() caught exception ", e.what()); | ||||||
|  | #endif | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
| #else | #else | ||||||
|   return false; |   return false; | ||||||
| #endif | #endif | ||||||
|  | |||||||
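The renamed knob inverts the opt-in: cuDNN attention is now preferred by default on matching architectures (sm90/sm100 per the check above), and the environment variable only deprioritizes it. A hedged sketch of opting out; the flag is cached in a static, so it must be set before the first SDPA call:

```python
import os
os.environ["TORCH_CUDNN_SDPA_DEPRIORITIZED"] = "1"  # push cuDNN down the priority list

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([2, 8, 128, 64])
```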
| @@ -37,6 +37,10 @@ TEST(SingletonOrSharedTypePtr, Comparison) { | |||||||
|  |  | ||||||
|   EXPECT_NE(empty, p); |   EXPECT_NE(empty, p); | ||||||
|   EXPECT_NE(p, p2); |   EXPECT_NE(p, p2); | ||||||
|  |  | ||||||
|  |   EXPECT_EQ(empty, empty); | ||||||
|  |   EXPECT_EQ(p, p); | ||||||
|  |   EXPECT_EQ(p2, p2); | ||||||
| } | } | ||||||
|  |  | ||||||
| TEST(SingletonOrSharedTypePtr, SingletonComparison) { | TEST(SingletonOrSharedTypePtr, SingletonComparison) { | ||||||
| @@ -47,6 +51,8 @@ TEST(SingletonOrSharedTypePtr, SingletonComparison) { | |||||||
|   c10::TypePtr type = c10::NoneType::get(); |   c10::TypePtr type = c10::NoneType::get(); | ||||||
|   EXPECT_NE(type, c10::StringType::get()); |   EXPECT_NE(type, c10::StringType::get()); | ||||||
|   EXPECT_NE(type, c10::DeviceObjType::get()); |   EXPECT_NE(type, c10::DeviceObjType::get()); | ||||||
|  |   EXPECT_EQ(type, type); | ||||||
|  |   EXPECT_EQ(type, c10::NoneType::get()); | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @@ -526,6 +526,41 @@ namespace { | |||||||
|             [](const vec& v) { return v.expm1(); }, |             [](const vec& v) { return v.expm1(); }, | ||||||
|             createDefaultUnaryTestCase<vec>(TestSeed(), false, true)); |             createDefaultUnaryTestCase<vec>(TestSeed(), false, true)); | ||||||
|     } |     } | ||||||
|  |     TYPED_TEST(Exponents, ExpU20) { | ||||||
|  |         using vec = TypeParam; | ||||||
|  |         using VT = ValueType<TypeParam>; | ||||||
|  |         using UVT = UvalueType<TypeParam>; | ||||||
|  |  | ||||||
|  |         // Explicit edge values | ||||||
|  |         VT v_too_small = VT(-100.0); // much less than -87.3 | ||||||
|  |         VT exp_too_small = std::exp(v_too_small); | ||||||
|  |         VT v_neg_edge = VT(-0x1.5d5e2ap+6f);   // just at the edge | ||||||
|  |         VT exp_neg_edge = std::exp(v_neg_edge); | ||||||
|  |         VT v_zero = VT(0.0);         // middle, normal case | ||||||
|  |         VT exp_zero = std::exp(v_zero); | ||||||
|  |         VT v_pos_edge = VT(0x1.5d5e2ap+6f);    // just at the edge | ||||||
|  |         VT exp_pos_edge = std::exp(v_pos_edge); | ||||||
|  |         VT v_too_large = VT(100.0);  // much more than 87.3 | ||||||
|  |         VT exp_too_large = std::exp(v_too_large); | ||||||
|  |  | ||||||
|  |         auto test_case = TestingCase<vec>::getBuilder() | ||||||
|  |             // Random values in the normal range; the .addCustom() calls below guarantee we hit the special/fallback cases | ||||||
|  |             .addDomain(CheckWithinDomains<UVT>{{{-100, 100}}, false, getDefaultTolerance<UVT>()}) | ||||||
|  |             .addCustom({ {v_too_small}, exp_too_small }) | ||||||
|  |             .addCustom({ {v_neg_edge}, exp_neg_edge }) | ||||||
|  |             .addCustom({ {v_zero}, exp_zero }) | ||||||
|  |             .addCustom({ {v_pos_edge}, exp_pos_edge }) | ||||||
|  |             .addCustom({ {v_too_large}, exp_too_large }) | ||||||
|  |             .setTrialCount(65536) | ||||||
|  |             .setTestSeed(TestSeed()); | ||||||
|  |  | ||||||
|  |         test_unary<vec>( | ||||||
|  |             NAME_INFO(exp_u20_edge_cases), | ||||||
|  |             RESOLVE_OVERLOAD(std::exp), | ||||||
|  |             [](const vec& v) { return v.exp_u20(); }, | ||||||
|  |             test_case | ||||||
|  |         ); | ||||||
|  |     } | ||||||
|     TYPED_TEST(ErrorFunctions, Erf) { |     TYPED_TEST(ErrorFunctions, Erf) { | ||||||
|         using vec = TypeParam; |         using vec = TypeParam; | ||||||
|         test_unary<vec>( |         test_unary<vec>( | ||||||
|  | |||||||
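The ExpU20 constants above mark the float32 range where exp stays normal: 0x1.5d5e2ap+6 ≈ 87.3365, beyond which exp overflows (and below whose negation it underflows). Checking the tested values in plain Python:

```python
import math

edge = float.fromhex("0x1.5d5e2ap+6")   # 87.33654...
print(math.exp(edge))                   # ~8.5e37, still finite in float32
print(math.exp(-edge))                  # ~1.2e-38, near the smallest float32 normal
print(math.exp(100.0) > 3.4028235e38)   # True: exp(100) would overflow float32 to inf
```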
| @@ -58,8 +58,7 @@ def list_benchmarks(): | |||||||
|  |  | ||||||
| def run_benchmark( | def run_benchmark( | ||||||
|     benchmark_name: str, |     benchmark_name: str, | ||||||
|     should_visualize: bool = False, |     script_args, | ||||||
|     compile_mode: str = "max-autotune-no-cudagraphs", |  | ||||||
| ): | ): | ||||||
|     """Run a specific benchmark.""" |     """Run a specific benchmark.""" | ||||||
|     if benchmark_name not in BENCHMARK_REGISTRY: |     if benchmark_name not in BENCHMARK_REGISTRY: | ||||||
| @@ -68,29 +67,29 @@ def run_benchmark( | |||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     print(f"Running benchmark: {benchmark_name}") |     print(f"Running benchmark: {benchmark_name}") | ||||||
|     print(f"Torch compile mode: {compile_mode}") |     print(f"Torch compile mode: {script_args.compile_mode}") | ||||||
|     print("=" * 60) |     print("=" * 60) | ||||||
|  |  | ||||||
|     benchmark_class = BENCHMARK_REGISTRY[benchmark_name] |     benchmark_class = BENCHMARK_REGISTRY[benchmark_name] | ||||||
|     benchmark = benchmark_class(compile_mode) |     benchmark = benchmark_class(script_args) | ||||||
|     benchmark.benchmark() |     benchmark.benchmark() | ||||||
|     if should_visualize: |     if script_args.visualize: | ||||||
|         benchmark.visualize() |         benchmark.visualize() | ||||||
|  |  | ||||||
|     return True |     return True | ||||||
|  |  | ||||||
|  |  | ||||||
| def run_all_benchmarks(should_visualize: bool = False, compile_mode: str = "default"): | def run_all_benchmarks(script_args): | ||||||
|     """Run all available benchmarks.""" |     """Run all available benchmarks.""" | ||||||
|     print("Running all benchmarks...") |     print("Running all benchmarks...") | ||||||
|     print(f"Torch compile mode: {compile_mode}") |     print(f"Torch compile mode: {script_args.compile_mode}") | ||||||
|     print("=" * 60) |     print("=" * 60) | ||||||
|  |  | ||||||
|     for name, cls in BENCHMARK_REGISTRY.items(): |     for name, cls in BENCHMARK_REGISTRY.items(): | ||||||
|         print(f"\n{'=' * 20} {name.upper()} {'=' * 20}") |         print(f"\n{'=' * 20} {name.upper()} {'=' * 20}") | ||||||
|         benchmark = cls(compile_mode) |         benchmark = cls(script_args) | ||||||
|         benchmark.benchmark() |         benchmark.benchmark() | ||||||
|         if should_visualize: |         if script_args.visualize: | ||||||
|             benchmark.visualize() |             benchmark.visualize() | ||||||
|         print() |         print() | ||||||
|  |  | ||||||
| @@ -137,6 +136,19 @@ Examples: | |||||||
|         help="Torch compile mode to use (default: default)", |         help="Torch compile mode to use (default: default)", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--tolerance", | ||||||
|  |         type=float, | ||||||
|  |         default=None, | ||||||
|  |         help="Tolerance for the accuracy check", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--exit-on-accuracy-failure", | ||||||
|  |         action="store_true", | ||||||
|  |         help="Exit with an error message when an accuracy check fails", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     # Handle list option |     # Handle list option | ||||||
| @@ -146,7 +158,7 @@ Examples: | |||||||
|  |  | ||||||
|     # Handle all option |     # Handle all option | ||||||
|     if args.all: |     if args.all: | ||||||
|         run_all_benchmarks(args.visualize, args.compile_mode) |         run_all_benchmarks(args) | ||||||
|         return |         return | ||||||
|  |  | ||||||
|     # Handle specific benchmarks |     # Handle specific benchmarks | ||||||
| @@ -157,7 +169,7 @@ Examples: | |||||||
|         sys.exit(1) |         sys.exit(1) | ||||||
|  |  | ||||||
|     for benchmark_name in args.benchmarks: |     for benchmark_name in args.benchmarks: | ||||||
|         run_benchmark(benchmark_name, args.visualize, args.compile_mode) |         run_benchmark(benchmark_name, args) | ||||||
|         print()  # Add spacing between benchmarks |         print()  # Add spacing between benchmarks | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
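The refactor threads the whole argparse namespace (rather than individual flags) into each benchmark. A minimal sketch of the new flow, with the attribute names mirroring the flags defined above and the benchmark name purely illustrative:

```python
import argparse

# Hypothetical stand-in for parser.parse_args(); the fields mirror the
# CLI flags above (--compile-mode, --visualize, --tolerance, ...).
args = argparse.Namespace(
    compile_mode="max-autotune-no-cudagraphs",
    visualize=False,
    tolerance=1e-4,
    exit_on_accuracy_failure=True,
)
# run_benchmark("softmax_forward", args)  # name assumed; see BENCHMARK_REGISTRY
```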
| @@ -9,8 +9,8 @@ import torch.nn.functional as F | |||||||
|  |  | ||||||
|  |  | ||||||
| class CrossEntropyForward(BenchmarkKernel): | class CrossEntropyForward(BenchmarkKernel): | ||||||
|     def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): |     def __init__(self, script_args): | ||||||
|         super().__init__(compile_mode) |         super().__init__(script_args) | ||||||
|         self.available_backends = ["eager", "compiled", "quack", "liger"] |         self.available_backends = ["eager", "compiled", "quack", "liger"] | ||||||
|  |  | ||||||
|     def get_shapes(self) -> tuple[tuple[int, ...], ...]: |     def get_shapes(self) -> tuple[tuple[int, ...], ...]: | ||||||
| @@ -106,8 +106,8 @@ class CrossEntropyForward(BenchmarkKernel): | |||||||
|  |  | ||||||
|  |  | ||||||
| class CrossEntropyBackward(BenchmarkKernel): | class CrossEntropyBackward(BenchmarkKernel): | ||||||
|     def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): |     def __init__(self, script_args): | ||||||
|         super().__init__(compile_mode) |         super().__init__(script_args) | ||||||
|         self.available_backends = ["eager", "compiled", "quack", "liger"] |         self.available_backends = ["eager", "compiled", "quack", "liger"] | ||||||
|  |  | ||||||
|     def get_shapes(self) -> tuple[tuple[int, ...], ...]: |     def get_shapes(self) -> tuple[tuple[int, ...], ...]: | ||||||
| @@ -194,8 +194,8 @@ class CrossEntropyBackward(BenchmarkKernel): | |||||||
|  |  | ||||||
|  |  | ||||||
| class SoftmaxForward(BenchmarkKernel): | class SoftmaxForward(BenchmarkKernel): | ||||||
|     def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): |     def __init__(self, script_args): | ||||||
|         super().__init__(compile_mode) |         super().__init__(script_args) | ||||||
|         self.available_backends = ["eager", "compiled", "quack", "liger"] |         self.available_backends = ["eager", "compiled", "quack", "liger"] | ||||||
|  |  | ||||||
|     def get_shapes(self) -> tuple[tuple[int, ...], ...]: |     def get_shapes(self) -> tuple[tuple[int, ...], ...]: | ||||||
| @@ -259,8 +259,8 @@ class SoftmaxForward(BenchmarkKernel): | |||||||
|  |  | ||||||
|  |  | ||||||
| class SoftmaxBackward(BenchmarkKernel): | class SoftmaxBackward(BenchmarkKernel): | ||||||
|     def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): |     def __init__(self, script_args): | ||||||
|         super().__init__(compile_mode) |         super().__init__(script_args) | ||||||
|         self.available_backends = ["eager", "compiled", "quack", "liger"] |         self.available_backends = ["eager", "compiled", "quack", "liger"] | ||||||
|  |  | ||||||
|     def get_shapes(self) -> tuple[tuple[int, ...], ...]: |     def get_shapes(self) -> tuple[tuple[int, ...], ...]: | ||||||
| @@ -329,8 +329,8 @@ class SoftmaxBackward(BenchmarkKernel): | |||||||
|  |  | ||||||
|  |  | ||||||
| class RMSNormForward(BenchmarkKernel): | class RMSNormForward(BenchmarkKernel): | ||||||
|     def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): |     def __init__(self, script_args): | ||||||
|         super().__init__(compile_mode) |         super().__init__(script_args) | ||||||
|         self.available_backends = ["eager", "compiled", "quack", "liger"] |         self.available_backends = ["eager", "compiled", "quack", "liger"] | ||||||
|  |  | ||||||
|     def get_shapes(self) -> tuple[tuple[int, ...], ...]: |     def get_shapes(self) -> tuple[tuple[int, ...], ...]: | ||||||
| @@ -383,7 +383,22 @@ class RMSNormForward(BenchmarkKernel): | |||||||
|         from quack.rmsnorm import _rmsnorm_fwd |         from quack.rmsnorm import _rmsnorm_fwd | ||||||
|  |  | ||||||
|         x, w = args |         x, w = args | ||||||
|         return lambda: _rmsnorm_fwd(x, w, eps=1e-6) |         y = torch.empty_like(x) | ||||||
|  |  | ||||||
|  |         def quack_fwd(): | ||||||
|  |             _rmsnorm_fwd( | ||||||
|  |                 x, | ||||||
|  |                 w, | ||||||
|  |                 out=y, | ||||||
|  |                 bias=None, | ||||||
|  |                 rstd=None, | ||||||
|  |                 residual=None, | ||||||
|  |                 residual_out=None, | ||||||
|  |                 eps=1e-6, | ||||||
|  |             ) | ||||||
|  |             return y | ||||||
|  |  | ||||||
|  |         return quack_fwd | ||||||
|  |  | ||||||
|     def liger(self, args, kwargs) -> Any: |     def liger(self, args, kwargs) -> Any: | ||||||
|         from liger_kernel.transformers.rms_norm import LigerRMSNorm |         from liger_kernel.transformers.rms_norm import LigerRMSNorm | ||||||
| @@ -404,9 +419,14 @@ | |||||||
|  |  | ||||||
|  |  | ||||||
| class RMSNormBackward(BenchmarkKernel): | class RMSNormBackward(BenchmarkKernel): | ||||||
|     def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): |     def __init__(self, script_args): | ||||||
|         super().__init__(compile_mode) |         super().__init__(script_args) | ||||||
|         self.available_backends = ["eager", "compiled", "quack", "liger"] |         self.available_backends = [ | ||||||
|  |             "eager", | ||||||
|  |             "compiled", | ||||||
|  |             "quack", | ||||||
|  |             "liger", | ||||||
|  |         ] | ||||||
|  |  | ||||||
|     def get_shapes(self) -> tuple[tuple[int, ...], ...]: |     def get_shapes(self) -> tuple[tuple[int, ...], ...]: | ||||||
|         # TODO: OOM for (32768, 65536) on h100 |         # TODO: OOM for (32768, 65536) on h100 | ||||||
| @@ -454,8 +474,11 @@ class RMSNormBackward(BenchmarkKernel): | |||||||
|             y, [x, w], grad_outputs=dy, retain_graph=True |             y, [x, w], grad_outputs=dy, retain_graph=True | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def compute_rstd(self, x, eps): | ||||||
|  |         return torch.rsqrt(torch.mean(x.float().square(), dim=-1, keepdim=True) + eps) | ||||||
|  |  | ||||||
|     def quack(self, args, kwargs=None) -> Any: |     def quack(self, args, kwargs=None) -> Any: | ||||||
|         from quack.rmsnorm import _rmsnorm_backward |         from quack.rmsnorm import _get_sm_count, _rmsnorm_bwd | ||||||
|  |  | ||||||
|         ( |         ( | ||||||
|             x, |             x, | ||||||
| @@ -463,15 +486,40 @@ class RMSNormBackward(BenchmarkKernel): | |||||||
|             dy, |             dy, | ||||||
|         ) = args |         ) = args | ||||||
|         M, N = x.shape |         M, N = x.shape | ||||||
|         rstd = torch.randn(M, device="cuda", dtype=torch.float32) |  | ||||||
|         return lambda: _rmsnorm_backward(x, w, dy, rstd) |         rstd = self.compute_rstd(x, eps=1e-6) | ||||||
|  |         dx = torch.empty_like(x) | ||||||
|  |         sm_count = _get_sm_count(x.size(1), x.device) | ||||||
|  |         dw_partial = torch.empty( | ||||||
|  |             sm_count, x.size(1), device=x.device, dtype=torch.float32 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         def quack_bwd(): | ||||||
|  |             _rmsnorm_bwd( | ||||||
|  |                 x, | ||||||
|  |                 w, | ||||||
|  |                 dy, | ||||||
|  |                 rstd, | ||||||
|  |                 dx, | ||||||
|  |                 dw_partial, | ||||||
|  |                 db_partial=None, | ||||||
|  |                 dresidual_out=None, | ||||||
|  |                 dresidual=None, | ||||||
|  |                 sm_count=sm_count, | ||||||
|  |             ) | ||||||
|  |             dw = dw_partial.sum(dim=0).to(w.dtype) | ||||||
|  |             return dx, dw | ||||||
|  |  | ||||||
|  |         return quack_bwd | ||||||
|  |  | ||||||
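Previously the benchmark fed `torch.randn` noise into the backward kernel as `rstd`, so its output could not match eager. `compute_rstd` now derives the real statistic; spelled out below, with a cross-check against `F.rms_norm` (available in recent PyTorch, 2.4+):

```python
import torch
import torch.nn.functional as F

x = torch.randn(16, 64)
eps = 1e-6
rstd = torch.rsqrt(x.float().square().mean(dim=-1, keepdim=True) + eps)
ref = x * rstd                              # RMSNorm with unit weight
print(torch.allclose(ref, F.rms_norm(x, (64,), eps=eps), atol=1e-6))
```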
|     def liger(self, args, kwargs=None) -> Any: |     def liger(self, args, kwargs=None) -> Any: | ||||||
|         from liger_kernel.transformers.rms_norm import LigerRMSNorm |         from liger_kernel.transformers.rms_norm import LigerRMSNorm | ||||||
|  |  | ||||||
|         x, w, dy = args |         x, w, dy = args | ||||||
|         M, N = x.shape |         M, N = x.shape | ||||||
|         liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda() |         liger_rmsnorm = LigerRMSNorm( | ||||||
|  |             hidden_size=N, eps=1e-6, casting_mode="gemma" | ||||||
|  |         ).cuda() | ||||||
|         liger_rmsnorm.weight.data.copy_(w) |         liger_rmsnorm.weight.data.copy_(w) | ||||||
|         y = liger_rmsnorm(x) |         y = liger_rmsnorm(x) | ||||||
|         return lambda: torch.autograd.grad( |         return lambda: torch.autograd.grad( | ||||||
| @@ -489,8 +537,8 @@ class RMSNormBackward(BenchmarkKernel): | |||||||
|  |  | ||||||
|  |  | ||||||
| class LayerNormForward(BenchmarkKernel): | class LayerNormForward(BenchmarkKernel): | ||||||
|     def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): |     def __init__(self, script_args): | ||||||
|         super().__init__(compile_mode) |         super().__init__(script_args) | ||||||
|         self.available_backends = ["eager", "compiled", "quack", "liger"] |         self.available_backends = ["eager", "compiled", "quack", "liger"] | ||||||
|  |  | ||||||
|     def get_shapes(self) -> tuple[tuple[int, ...], ...]: |     def get_shapes(self) -> tuple[tuple[int, ...], ...]: | ||||||
| @@ -563,8 +611,8 @@ class LayerNormForward(BenchmarkKernel): | |||||||
|  |  | ||||||
|  |  | ||||||
| class LayerNormBackward(BenchmarkKernel): | class LayerNormBackward(BenchmarkKernel): | ||||||
|     def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): |     def __init__(self, script_args): | ||||||
|         super().__init__(compile_mode) |         super().__init__(script_args) | ||||||
|         self.available_backends = ["eager", "compiled", "liger"] |         self.available_backends = ["eager", "compiled", "liger"] | ||||||
|  |  | ||||||
|     def get_shapes(self) -> tuple[tuple[int, ...], ...]: |     def get_shapes(self) -> tuple[tuple[int, ...], ...]: | ||||||
| @@ -614,20 +662,31 @@ class LayerNormBackward(BenchmarkKernel): | |||||||
|             y, [x, w], grad_outputs=dy, retain_graph=True |             y, [x, w], grad_outputs=dy, retain_graph=True | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def compute_mean_rstd(self, x, eps): | ||||||
|  |         x = x.float() | ||||||
|  |  | ||||||
|  |         var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0) | ||||||
|  |         rstd = torch.rsqrt(var + eps) | ||||||
|  |         return mean, rstd | ||||||
|  |  | ||||||
|     def liger(self, args, kwargs) -> Any: |     def liger(self, args, kwargs) -> Any: | ||||||
|         from liger_kernel.transformers.layer_norm import LigerLayerNorm |         """ | ||||||
|  |         Call layer_norm_backward directly rather than calling | ||||||
|  |         liger_kernel.transformers.layer_norm.LigerLayerNorm and | ||||||
|  |         torch.autograd.grad. | ||||||
|  |  | ||||||
|  |         The latter approach saves mean/rstd in x.dtype, which can | ||||||
|  |         fail the accuracy test. We call layer_norm_backward with | ||||||
|  |         fp32 mean and rstd instead. | ||||||
|  |         """ | ||||||
|  |         from liger_kernel.ops.layer_norm import layer_norm_backward | ||||||
|  |  | ||||||
|         x, w, dy = args |         x, w, dy = args | ||||||
|  |         eps = 1e-6 | ||||||
|  |         mean, rstd = self.compute_mean_rstd(x, eps) | ||||||
|         M, N = x.shape |         M, N = x.shape | ||||||
|         liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda() |  | ||||||
|         liger_layernorm.weight.data.copy_(w) |         return lambda: layer_norm_backward(dy, x, w, None, mean, rstd)[0:2] | ||||||
|         liger_layernorm.bias.data.copy_( |  | ||||||
|             torch.zeros(N, device="cuda", dtype=torch.float32) |  | ||||||
|         ) |  | ||||||
|         y = liger_layernorm(x) |  | ||||||
|         return lambda: torch.autograd.grad( |  | ||||||
|             y, [x, liger_layernorm.weight], grad_outputs=dy, retain_graph=True |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
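`compute_mean_rstd` above deliberately produces fp32 statistics regardless of the input dtype, because half-precision mean/rstd is exactly what made the previous LigerLayerNorm-based path fail the accuracy check. The same computation in isolation:

```python
import torch

x = torch.randn(8, 32, dtype=torch.bfloat16)
eps = 1e-6
# correction=0 gives the biased (population) variance that layer norm uses
var, mean = torch.var_mean(x.float(), dim=-1, keepdim=True, correction=0)
rstd = torch.rsqrt(var + eps)
print(mean.dtype, rstd.dtype)  # torch.float32 torch.float32
```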
|     def benchmark(self): |     def benchmark(self): | ||||||
|         for M, N in self.get_shapes(): |         for M, N in self.get_shapes(): | ||||||
|  | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| import os | import os | ||||||
|  | import sys | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| from collections.abc import Callable | from collections.abc import Callable | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| @@ -43,10 +44,11 @@ class Performance: | |||||||
|  |  | ||||||
|  |  | ||||||
| class BenchmarkKernel: | class BenchmarkKernel: | ||||||
|     def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): |     def __init__(self, script_args): | ||||||
|  |         self.script_args = script_args | ||||||
|         self.name = self.__class__.__name__ |         self.name = self.__class__.__name__ | ||||||
|         self.available_backends: list[str] = [] |         self.available_backends: list[str] = [] | ||||||
|         self.compile_mode: str = compile_mode |         self.compile_mode: str = script_args.compile_mode | ||||||
|  |  | ||||||
|         # mapping from backend to list of performance results |         # mapping from backend to list of performance results | ||||||
|         self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list) |         self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list) | ||||||
| @@ -106,14 +108,21 @@ class BenchmarkKernel: | |||||||
|             args_ref, kwargs_ref = self.clone_inputs(args, kwargs) |             args_ref, kwargs_ref = self.clone_inputs(args, kwargs) | ||||||
|             res[backend] = getattr(self, backend)(args_ref, kwargs_ref)() |             res[backend] = getattr(self, backend)(args_ref, kwargs_ref)() | ||||||
|         gold = res["eager"] |         gold = res["eager"] | ||||||
|  |  | ||||||
|  |         tol = {} | ||||||
|  |         if self.script_args.tolerance: | ||||||
|  |             tol = { | ||||||
|  |                 "atol": self.script_args.tolerance, | ||||||
|  |                 "rtol": self.script_args.tolerance, | ||||||
|  |             } | ||||||
|         for backend in self.available_backends: |         for backend in self.available_backends: | ||||||
|             if backend == "eager": |             if backend == "eager": | ||||||
|                 continue |                 continue | ||||||
|             try: |             try: | ||||||
|                 torch.testing.assert_close(res[backend], gold) |                 torch.testing.assert_close(res[backend], gold, **tol) | ||||||
|                 for t, gold_t in zip(res[backend], gold): |                 for t, gold_t in zip(res[backend], gold): | ||||||
|                     if t.requires_grad: |                     if t.requires_grad: | ||||||
|                         torch.testing.assert_close(t.grad, gold_t.grad) |                         torch.testing.assert_close(t.grad, gold_t.grad, **tol) | ||||||
|                 print( |                 print( | ||||||
|                     f"Accuracy check \033[92m✓ succeeded\033[0m for {backend} backend on {self.name} kernel" |                     f"Accuracy check \033[92m✓ succeeded\033[0m for {backend} backend on {self.name} kernel" | ||||||
|                 ) |                 ) | ||||||
| @@ -121,6 +130,9 @@ | |||||||
|                 print( |                 print( | ||||||
|                     f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}" |                     f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}" | ||||||
|                 ) |                 ) | ||||||
|  |                 if self.script_args.exit_on_accuracy_failure: | ||||||
|  |                     print("Exiting immediately since --exit-on-accuracy-failure is set") | ||||||
|  |                     sys.exit(1) | ||||||
|  |  | ||||||
|     def benchmark_single_shape( |     def benchmark_single_shape( | ||||||
|         self, args, kwargs=None, should_check_accuracy=True, setting: str = "" |         self, args, kwargs=None, should_check_accuracy=True, setting: str = "" | ||||||
|  | |||||||
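For reference, how the optional `--tolerance` flag changes the accuracy gate: when set, it overrides `assert_close`'s dtype-based defaults with a uniform atol/rtol.

```python
import torch

actual = torch.tensor([1.0000, 2.0001])
expected = torch.tensor([1.0, 2.0])

# Default float32 tolerances (rtol=1.3e-6, atol=1e-5) would reject this pair:
# torch.testing.assert_close(actual, expected)   # raises AssertionError
torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3)  # passes
```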
| @@ -1,8 +1,8 @@ | |||||||
| add_loop_eager,compile_time_instruction_count,3070000000,0.1 | add_loop_eager,compile_time_instruction_count,3184000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1 | add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1 | basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000, | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1 | basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| update_hint_regression,compile_time_instruction_count,1719000000,0.1 | update_hint_regression,compile_time_instruction_count,1645000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1 | sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1 | aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1 | aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1 | aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1 | aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1 | aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1 | aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1 | mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1 | basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1 | basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1 | ||||||
|  | |||||||
| 
 | 
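Each row above is `benchmark,metric,expected_value,noise_threshold`. Assuming the final column is a relative band around the expected instruction count (an assumption from the file's shape, not from the harness code shown here), the pass/fail test reduces to a one-liner:

    def within_threshold(measured: int, expected: int, rel_tol: float) -> bool:
        # e.g. add_loop_eager: expected 3184000000 with rel_tol 0.1 accepts
        # anything in [2865600000, 3502400000]
        return abs(measured - expected) <= rel_tol * expected

    assert within_threshold(3_200_000_000, 3_184_000_000, 0.1)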
@@ -43,6 +43,7 @@ tolerance:
     - doctr_reco_predictor
     - drq
     - phlippe_resnet
+    - pytorch_CycleGAN_and_pix2pix

   higher_bf16:
     - doctr_reco_predictor
@@ -44,21 +44,101 @@ PyTorch,div_,div__M1_N1_K1_cpu_dtype_onetorch.float32_dtype_twotorch.float32,sho
 PyTorch,div_,div__M64_N64_K64_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.241161,0.000000
 PyTorch,div_,div__M64_N64_K128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.852816,0.000000
 PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,57.006677,0.000000
+PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,88.167000,0.000000
+PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.519000,0.000000
 PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,55.606088,0.000000
+PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,86.551000,0.000000
+PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.864088,0.000000
 PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
+PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,71.641000,0.000000
+PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,83.073000,0.000000
 PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
+PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,67.570000,0.000000
+PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.895000,0.000000
 PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
+PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
+PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
+PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
 PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
+PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
+PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
+PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
 PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
+PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
+PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
+PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
 PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
+PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.024000,0.000000
+PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
+PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
 PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
+PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
+PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
+PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
 PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
+PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
+PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
+PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
 PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
+PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
+PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
+PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
 PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
+PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
+PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
+PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
 PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
+PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
+PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
+PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
 PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
+PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
+PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
+PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
 PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
+PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
+PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
+PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
 PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
+PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
+PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
+PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
+PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
+PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
+PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
+PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
+PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
+PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
+PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
+PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
+PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
+PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
+PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
+PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
+PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
+PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
+PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
+PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
+PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
+PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
+PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
+PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
+PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
+PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
+PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
+PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
+PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
+PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
+PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
+PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
+PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
+PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
+PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
+PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
+PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
+PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
+PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
+PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
 PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
 PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
 PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
@@ -71,6 +151,9 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
 PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
 PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
 PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
+PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
+PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
+PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
 PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
 PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
 PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000
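A pattern worth noting in the new rows: every mixed int32/uint8 pairing lands around 105-110 µs where the matching same-dtype pairing is roughly 48-60 µs, which is consistent with an extra type-promotion pass (an interpretation of the numbers, not something the diff states). The promoted dtype is easy to confirm with the standard `torch.result_type` API:

    import torch

    a = torch.randint(0, 255, (64, 64), dtype=torch.uint8)
    b = torch.randint(0, 100, (64, 64), dtype=torch.int32)

    # uint8 + int32 promotes to int32, so the mixed-dtype benchmark measures a
    # conversion plus the int32 kernel rather than the uint8 kernel alone.
    print(torch.result_type(a, b))   # torch.int32
    print((a + b).dtype)             # torch.int32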
@@ -580,6 +580,9 @@ class BenchmarkRunner:
                 else "unknown"
             )

+            # Extract operator name from test_name
+            operator_name = test_name.split("_")[0]
+
             # Create the record
             @dataclass
             class BenchmarkInfo:
@@ -593,6 +596,7 @@ class BenchmarkRunner:
                 name: str
                 type: str
                 origins: list[str]
+                extra_info: dict[str, Any]

             @dataclass
             class MetricInfo:
@@ -618,10 +622,14 @@ class BenchmarkRunner:
                         "device": device,
                         "arch": device_arch,
                         "use_compile": use_compile,
+                        "operator_name": operator_name,
                     },
                 ),
                 model=ModelInfo(
-                    name=test_name, type="micro-benchmark", origins=["pytorch"]
+                    name=test_name,
+                    type="micro-benchmark",
+                    origins=["pytorch"],
+                    extra_info={"operator_name": operator_name},
                 ),
                 metric=MetricInfo(
                     name="latency",
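One caveat on the `operator_name = test_name.split("_")[0]` extraction above: it assumes the operator name itself contains no underscore. Against the test names in the CSV earlier in this diff, it behaves like this:

    for test_name in [
        "add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32",
        "div__M1_N1_K1_cpu_dtype_onetorch.float32_dtype_twotorch.float32",
        "logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool",
    ]:
        print(test_name.split("_")[0])
    # -> "add", "div" (the in-place div_ collapses to div),
    #    "logical" (multi-word operators like logical_and are truncated)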
@@ -25,7 +25,7 @@ binary_configs_broadcast = op_bench.config_list(
     ],
     cross_product_configs={
         "device": ["cpu"],
-        "dtype": [torch.float],
+        "dtype": [torch.float, torch.bfloat16, torch.float64],
     },
     tags=["short"],
 )
@@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
     ],
     cross_product_configs={
         "device": ["cpu", "cuda"],
-        "dtype_one": [torch.int32],
-        "dtype_two": [torch.int32],
+        "dtype_one": [torch.int32, torch.uint8],
+        "dtype_two": [torch.int32, torch.uint8],
     },
     tags=["short"],
 )
@@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
     N=[32, 64],
     K=[256, 512],
     device=["cpu", "cuda"],
-    dtype_one=[torch.int8, torch.int32],
-    dtype_two=[torch.int8, torch.int32],
+    dtype_one=[torch.int8, torch.int32, torch.uint8],
+    dtype_two=[torch.int8, torch.int32, torch.uint8],
     tags=["long"],
 )
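The effect of the config change above: each `dtype_one`/`dtype_two` axis doubles, so the short sweep grows from one dtype pair to four per shape and device. A quick sketch with plain `itertools`, standing in for op_bench's own cross-product machinery (whose internals are not shown here):

    import itertools

    devices = ["cpu", "cuda"]
    dtype_one = ["int32", "uint8"]
    dtype_two = ["int32", "uint8"]

    # 2 devices x 2 x 2 dtype pairs = 8 configs per (M, N, K) shape,
    # up from 2 before this change.
    configs = list(itertools.product(devices, dtype_one, dtype_two))
    print(len(configs))  # 8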
										
											
[File diff suppressed because it is too large]
@@ -176,8 +176,8 @@ THIRD_PARTY_LIBS = {
     "omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
     "pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
     "psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
-    "pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
-    "pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
+    "pthreadpool": ["fbsource//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
+    "pthreadpool_header": ["fbsource//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
     "moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
     "pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
     "rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
@@ -1729,8 +1729,10 @@ def define_buck_targets(
             "torch/csrc/jit/backends/backend_debug_info.cpp",
             "torch/csrc/jit/backends/backend_interface.cpp",
         ],
-        compiler_flags = get_pt_compiler_flags(),
-        fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
+        compiler_flags = get_pt_compiler_flags() + select({
+            "DEFAULT": [],
+            "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
+        }),
         # @lint-ignore BUCKLINT link_whole
         link_whole = True,
         linker_flags = get_no_as_needed_linker_flag(),
@@ -2023,6 +2025,9 @@ def define_buck_targets(
                 "ovr_config//os:android-x86_64": [
                     "-mssse3",
                 ],
+            }) + select({
+                "DEFAULT": [],
+                "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
             }),
             exported_preprocessor_flags = get_aten_preprocessor_flags(),
             exported_deps = [
@@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [
     "torch/csrc/cuda/Stream.cpp",
     "torch/csrc/cuda/Graph.cpp",
     "torch/csrc/cuda/MemPool.cpp",
+    "torch/csrc/cuda/GreenContext.cpp",
     "torch/csrc/cuda/shared/cudart.cpp",
     "torch/csrc/cuda/shared/nvtx.cpp",
     "torch/csrc/cuda/utils.cpp",
@@ -13,7 +13,17 @@
 namespace c10::CachingAllocator {

 // "large" allocations may be packed in 20 MiB blocks
-const size_t kLargeBuffer = 20971520;
+constexpr size_t kLargeBuffer = 20971520;
+// "small" allocations are packed in 2 MiB blocks
+constexpr size_t kSmallBuffer = 2097152;
+// all sizes are rounded to at least 512 bytes
+constexpr size_t kMinBlockSize = 512;
+// largest "small" allocation is 1 MiB
+constexpr size_t kSmallSize = 1048576;
+// allocations between 1 and 10 MiB may use kLargeBuffer
+constexpr size_t kMinLargeAlloc = 10485760;
+// round up large allocations to 2 MiB
+constexpr size_t kRoundLarge = 2097152;

 // A utility class for tokenizing allocator configuration strings into discrete
 // parts. For example, the config string:
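Read together, the new constants' comments describe the caching allocator's size-class policy. Below is a sketch of that policy as the comments state it — my reading of the comments, written in Python for brevity; the allocator's actual control flow lives elsewhere in c10 and is not shown in this diff:

    kMinBlockSize  = 512         # all sizes rounded to at least 512 bytes
    kSmallSize     = 1_048_576   # largest "small" allocation is 1 MiB
    kSmallBuffer   = 2_097_152   # "small" allocations packed in 2 MiB blocks
    kMinLargeAlloc = 10_485_760  # allocations between 1 and 10 MiB may use kLargeBuffer
    kLargeBuffer   = 20_971_520  # "large" allocations may be packed in 20 MiB blocks
    kRoundLarge    = 2_097_152   # round up large allocations to 2 MiB

    def segment_size(size: int) -> int:
        if size <= kSmallSize:
            return kSmallBuffer          # small pool: fixed 2 MiB segments
        if size < kMinLargeAlloc:
            return kLargeBuffer          # mid-size: one 20 MiB segment
        # truly large: round the request up to a 2 MiB multiple
        return kRoundLarge * ((size + kRoundLarge - 1) // kRoundLarge)

    print(segment_size(4_096), segment_size(5 * 1_048_576), segment_size(64 * 1_048_576))
    # -> 2097152 20971520 67108864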
@@ -223,7 +223,7 @@ inline DispatchKey backendToDispatchKey(Backend b) {
     case Backend::PrivateUse1:
       return DispatchKey::PrivateUse1;
     default:
-      throw std::runtime_error("Unknown backend");
+      TORCH_CHECK(false, "Unknown backend");
   }
 }
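The `throw std::runtime_error` to `TORCH_CHECK(false, ...)` swaps in this and the following hunks are behavior-preserving from the Python side: `TORCH_CHECK` raises `c10::Error`, which surfaces in Python as a `RuntimeError`, while additionally carrying the failing expression and a C++ backtrace. A loosely related illustration of how any such internal check appears from Python (the bad device string here is just an arbitrary way to trip one):

    import torch

    try:
        torch.zeros(1).to("not_a_device")  # any input that trips an internal check
    except RuntimeError as e:
        print("caught:", type(e).__name__)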
@@ -336,7 +336,7 @@ class C10_API Scalar {
     } else if (isBoolean()) {
       return ScalarType::Bool;
     } else {
-      throw std::runtime_error("Unknown scalar type.");
+      TORCH_CHECK(false, "Unknown scalar type.");
     }
   }
@@ -228,7 +228,7 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
     case c10::ScalarType::Float4_e2m1fn_x2:
       return std::make_pair("float4_e2m1fn_x2", "");
     default:
-      throw std::runtime_error("Unimplemented scalar type");
+      TORCH_CHECK(false, "Unimplemented scalar type");
   }
 }
@@ -52,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
 #undef DEFINE_CONSTANT

-inline const char* toString(ScalarType t) {
-#define DEFINE_CASE(_, name) \
-  case ScalarType::name:     \
-    return #name;
-
-  switch (t) {
-    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
-    default:
-      return "UNKNOWN_SCALAR";
-  }
-#undef DEFINE_CASE
-}
-
 inline size_t elementSize(ScalarType t) {
 #define CASE_ELEMENTSIZE_CASE(ctype, name) \
   case ScalarType::name:                   \
@@ -150,22 +137,6 @@ inline ScalarType toQIntType(ScalarType t) {
   }
 }

-inline ScalarType toUnderlying(ScalarType t) {
-  switch (t) {
-    case ScalarType::QUInt8:
-    case ScalarType::QUInt4x2:
-      [[fallthrough]];
-    case ScalarType::QUInt2x4:
-      return ScalarType::Byte;
-    case ScalarType::QInt8:
-      return ScalarType::Char;
-    case ScalarType::QInt32:
-      return ScalarType::Int;
-    default:
-      return t;
-  }
-}
-
 inline bool isSignedType(ScalarType t) {
 #define CASE_ISSIGNED(name)     \
   case ScalarType::name:        \
@@ -308,12 +279,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {

 C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);

-inline std::ostream& operator<<(
-    std::ostream& stream,
-    at::ScalarType scalar_type) {
-  return stream << toString(scalar_type);
-}
-
 // Returns a pair of strings representing the names for each dtype.
 // The returned pair is (name, legacy_name_if_applicable)
 C10_API std::pair<std::string, std::string> getDtypeNames(
@@ -87,9 +87,7 @@ bool ThreadPool::inThreadPool() const {
 }

 void ThreadPool::run(std::function<void()> func) {
-  if (threads_.empty()) {
-    throw std::runtime_error("No threads to run a task");
-  }
+  TORCH_CHECK(threads_.size() > 0, "No threads to run a task");
   std::unique_lock<std::mutex> lock(mutex_);

   // Set task and signal condition variable so that a worker thread will
[Some files were not shown because too many files have changed in this diff]