mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Compare commits
	
		
			180 Commits
		
	
	
		
			add_op_to_
			...
			sdpa-bs-ze
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| e3e06d9e4d | |||
| 7ce723d21c | |||
| 4295a9a158 | |||
| 90d7be35e9 | |||
| 8d4e48831e | |||
| 90b30ebf7e | |||
| 173bcda436 | |||
| 6530bc70fb | |||
| 4c38887346 | |||
| 81fa4a204c | |||
| 4e6afa8c07 | |||
| 79aa88cc5d | |||
| fa4cb91846 | |||
| c58d0ad85d | |||
| 000f49551b | |||
| 9940e894ea | |||
| 27302a4932 | |||
| 507614ba43 | |||
| 86f9f1d0ab | |||
| 154e4d36e9 | |||
| a2b6afeac5 | |||
| 262830d86c | |||
| e4c01011c2 | |||
| a60d9e1f6d | |||
| f863550192 | |||
| 84b14f3a10 | |||
| 5121499f6b | |||
| 8f80892359 | |||
| cdb60e44eb | |||
| 25909d2629 | |||
| c7eee49525 | |||
| 621ba05107 | |||
| 39a70cead1 | |||
| d97f6550a2 | |||
| 516e58965a | |||
| b55b779ad3 | |||
| 74e53d0761 | |||
| 798a6d2be1 | |||
| b0e9c86971 | |||
| 661a56002f | |||
| c9bc00f016 | |||
| ec51b139e1 | |||
| eb83c3ca23 | |||
| 7924e3aacf | |||
| 78bcfcf870 | |||
| 1e2e7cb18b | |||
| 003601a70d | |||
| 1d58d5fe25 | |||
| de7fdfe41a | |||
| b31bad1b8f | |||
| 2efcf3ca98 | |||
| 761f946043 | |||
| 8aa465f18e | |||
| 0a5d68d92d | |||
| 42bd210fff | |||
| 1d13c314b3 | |||
| 0c9763a5a0 | |||
| 79a4a9c02e | |||
| 9d0b77f4cd | |||
| d486eee234 | |||
| cddd5f74ab | |||
| dfdb68e51f | |||
| 98c818320a | |||
| cc20b7ad72 | |||
| bc11a42b3f | |||
| 4fc06f2e0a | |||
| 82473c3d59 | |||
| b6a4236e5d | |||
| b04173be9b | |||
| 32ac38f85d | |||
| c9b49e506e | |||
| 6038e476e8 | |||
| 2c851c16e5 | |||
| 31584f2d91 | |||
| 0442125362 | |||
| fdcf402d82 | |||
| 13cda9b89e | |||
| fa6d911dda | |||
| 0db6bcc015 | |||
| 60ac039998 | |||
| 380d440d1c | |||
| 9038a30cee | |||
| 690c8c13b9 | |||
| 28ee6b62ed | |||
| 81577bdb3f | |||
| e67e3d95f3 | |||
| 27af8480ea | |||
| 6494cdc40c | |||
| ac7074efa2 | |||
| 263901cec4 | |||
| c12293dcbe | |||
| 5a4997dcae | |||
| 47f638eae7 | |||
| 882b834082 | |||
| b146ea411e | |||
| 8625ffbd45 | |||
| 0977cc4474 | |||
| d9a55faccc | |||
| 75b8295868 | |||
| defb6a80d8 | |||
| f8fccb1e48 | |||
| 5aac4cfce4 | |||
| baf91bbbfc | |||
| cbcb4f7768 | |||
| 2b93d5b450 | |||
| 6b7cd48e7e | |||
| bf5aa9e42e | |||
| b1eb6dede5 | |||
| 673060beae | |||
| 2e8e9a59a8 | |||
| fb277a5916 | |||
| 73fa0d0c63 | |||
| 36c21cc84e | |||
| 0b68814b44 | |||
| e64a814ae7 | |||
| 0b58d87aec | |||
| 757975ad50 | |||
| 291712026b | |||
| 3e77a2b478 | |||
| 82ef1b5db3 | |||
| 5f370f5c42 | |||
| 05b2e02cb4 | |||
| 12f742941d | |||
| 35180fafee | |||
| c746feb86a | |||
| c5f26db5bf | |||
| 18e99b6d45 | |||
| ab9e466928 | |||
| af4ba78543 | |||
| 282f39a4bc | |||
| a479769488 | |||
| 26c7375477 | |||
| d01f15152c | |||
| 4fae6968b1 | |||
| f9953e0f61 | |||
| 34ed7a8f0d | |||
| 2fde10d914 | |||
| 0a93295da0 | |||
| 4b898b51b9 | |||
| 550e3e6efb | |||
| 715449ca76 | |||
| 84d8d06fc3 | |||
| 60992d98b2 | |||
| 59e015e3a1 | |||
| 8904a5a7c9 | |||
| f5df9ca03a | |||
| 2998abd777 | |||
| e13580e41c | |||
| f3b8e15f20 | |||
| 5211f4c108 | |||
| ad9027b80d | |||
| a1005427bf | |||
| 35153d0846 | |||
| 7773a22cdb | |||
| 7cb467a169 | |||
| 12aac12b8d | |||
| 2b748d0a56 | |||
| 16745a882a | |||
| 8daef35cf1 | |||
| 51319ca090 | |||
| d311a3d1dc | |||
| 04adfe5ba9 | |||
| 4be1e3bf92 | |||
| e7592f4005 | |||
| d334c3649d | |||
| 9f82535c5a | |||
| 5b35fc8777 | |||
| 2f38eece7c | |||
| 830e789a55 | |||
| ad4dc52bf6 | |||
| dac9ed9790 | |||
| 1c7fe8f861 | |||
| 4e643422f6 | |||
| 3c3b278872 | |||
| 0bd12c1168 | |||
| ce8a7764e2 | |||
| d1269a0434 | |||
| c87cf1be32 | |||
| 2fc5e45a41 | |||
| f9022ba93b | 
@ -19,7 +19,7 @@ pip_install \
 | 
			
		||||
  transformers==4.36.2
 | 
			
		||||
 | 
			
		||||
pip_install coloredlogs packaging
 | 
			
		||||
pip_install onnxruntime==1.23.0
 | 
			
		||||
pip_install onnxruntime==1.23.1
 | 
			
		||||
pip_install onnxscript==0.5.4
 | 
			
		||||
 | 
			
		||||
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 | 
			
		||||
 | 
			
		||||
@ -334,12 +334,12 @@ sympy==1.13.3
 | 
			
		||||
#Pinned versions:
 | 
			
		||||
#test that import:
 | 
			
		||||
 | 
			
		||||
onnx==1.18.0
 | 
			
		||||
onnx==1.19.1
 | 
			
		||||
#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
 | 
			
		||||
#Pinned versions:
 | 
			
		||||
#test that import:
 | 
			
		||||
 | 
			
		||||
onnxscript==0.5.3
 | 
			
		||||
onnxscript==0.5.4
 | 
			
		||||
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
 | 
			
		||||
#Pinned versions:
 | 
			
		||||
#test that import:
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ dependencies = [
 | 
			
		||||
    "GitPython==3.1.45",
 | 
			
		||||
    "docker==7.1.0",
 | 
			
		||||
    "pytest==7.3.2",
 | 
			
		||||
    "uv==0.8.6"
 | 
			
		||||
    "uv==0.9.5"
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[tool.setuptools]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										354
									
								
								.claude/skills/pytorch-docstring.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										354
									
								
								.claude/skills/pytorch-docstring.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,354 @@
 | 
			
		||||
# PyTorch Docstring Writing Guide
 | 
			
		||||
 | 
			
		||||
This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.
 | 
			
		||||
 | 
			
		||||
## General Principles
 | 
			
		||||
 | 
			
		||||
- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
 | 
			
		||||
- Follow **Sphinx/reStructuredText** (reST) format for documentation
 | 
			
		||||
- Be **concise but complete** - include all essential information
 | 
			
		||||
- Always include **examples** when possible
 | 
			
		||||
- Use **cross-references** to related functions/classes
 | 
			
		||||
 | 
			
		||||
## Docstring Structure
 | 
			
		||||
 | 
			
		||||
### 1. Function Signature (First Line)
 | 
			
		||||
 | 
			
		||||
Start with the function signature showing all parameters:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Notes:**
 | 
			
		||||
- Include the function name
 | 
			
		||||
- Show positional and keyword-only arguments (use `*` separator)
 | 
			
		||||
- Include default values
 | 
			
		||||
- Show return type annotation
 | 
			
		||||
- This line should NOT end with a period
 | 
			
		||||
 | 
			
		||||
### 2. Brief Description
 | 
			
		||||
 | 
			
		||||
Provide a one-line description of what the function does:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
 | 
			
		||||
 | 
			
		||||
Applies a 2D convolution over an input image composed of several input
 | 
			
		||||
planes.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 3. Mathematical Formulas (if applicable)
 | 
			
		||||
 | 
			
		||||
Use Sphinx math directives for mathematical expressions:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
.. math::
 | 
			
		||||
    \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Or inline math: `:math:\`x^2\``
 | 
			
		||||
 | 
			
		||||
### 4. Cross-References
 | 
			
		||||
 | 
			
		||||
Link to related classes and functions using Sphinx roles:
 | 
			
		||||
 | 
			
		||||
- `:class:\`~torch.nn.ModuleName\`` - Link to a class
 | 
			
		||||
- `:func:\`torch.function_name\`` - Link to a function
 | 
			
		||||
- `:meth:\`~Tensor.method_name\`` - Link to a method
 | 
			
		||||
- `:attr:\`attribute_name\`` - Reference an attribute
 | 
			
		||||
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)
 | 
			
		||||
 | 
			
		||||
**Example:**
 | 
			
		||||
```python
 | 
			
		||||
See :class:`~torch.nn.Conv2d` for details and output shape.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 5. Notes and Warnings
 | 
			
		||||
 | 
			
		||||
Use admonitions for important information:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
.. note::
 | 
			
		||||
    This function doesn't work directly with NLLLoss,
 | 
			
		||||
    which expects the Log to be computed between the Softmax and itself.
 | 
			
		||||
    Use log_softmax instead (it's faster and has better numerical properties).
 | 
			
		||||
 | 
			
		||||
.. warning::
 | 
			
		||||
    :func:`new_tensor` always copies :attr:`data`. If you have a Tensor
 | 
			
		||||
    ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
 | 
			
		||||
    or :func:`torch.Tensor.detach`.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 6. Args Section
 | 
			
		||||
 | 
			
		||||
Document all parameters with type annotations and descriptions:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
Args:
 | 
			
		||||
    input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
 | 
			
		||||
    weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)`
 | 
			
		||||
    bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
 | 
			
		||||
    stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
 | 
			
		||||
      tuple `(sH, sW)`. Default: 1
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Formatting rules:**
 | 
			
		||||
- Parameter name in **lowercase**
 | 
			
		||||
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
 | 
			
		||||
- Description follows the type
 | 
			
		||||
- For optional parameters, include "Default: ``value``" at the end
 | 
			
		||||
- Use double backticks for inline code: ``` ``None`` ```
 | 
			
		||||
- Indent continuation lines by 2 spaces
 | 
			
		||||
 | 
			
		||||
### 7. Keyword Args Section (if applicable)
 | 
			
		||||
 | 
			
		||||
Sometimes keyword arguments are documented separately:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
Keyword args:
 | 
			
		||||
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
 | 
			
		||||
        Default: if None, same :class:`torch.dtype` as this tensor.
 | 
			
		||||
    device (:class:`torch.device`, optional): the desired device of returned tensor.
 | 
			
		||||
        Default: if None, same :class:`torch.device` as this tensor.
 | 
			
		||||
    requires_grad (bool, optional): If autograd should record operations on the
 | 
			
		||||
        returned tensor. Default: ``False``.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 8. Returns Section (if needed)
 | 
			
		||||
 | 
			
		||||
Document the return value:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
Returns:
 | 
			
		||||
    Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
 | 
			
		||||
        If ``hard=True``, the returned samples will be one-hot, otherwise they will
 | 
			
		||||
        be probability distributions that sum to 1 across `dim`.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Or simply include it in the function signature line if obvious from context.
 | 
			
		||||
 | 
			
		||||
### 9. Examples Section
 | 
			
		||||
 | 
			
		||||
Always include examples when possible:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
Examples::
 | 
			
		||||
 | 
			
		||||
    >>> inputs = torch.randn(33, 16, 30)
 | 
			
		||||
    >>> filters = torch.randn(20, 16, 5)
 | 
			
		||||
    >>> F.conv1d(inputs, filters)
 | 
			
		||||
 | 
			
		||||
    >>> # With square kernels and equal stride
 | 
			
		||||
    >>> filters = torch.randn(8, 4, 3, 3)
 | 
			
		||||
    >>> inputs = torch.randn(1, 4, 5, 5)
 | 
			
		||||
    >>> F.conv2d(inputs, filters, padding=1)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Formatting rules:**
 | 
			
		||||
- Use `Examples::` with double colon
 | 
			
		||||
- Use `>>>` prompt for Python code
 | 
			
		||||
- Include comments with `#` when helpful
 | 
			
		||||
- Show actual output when it helps understanding (indent without `>>>`)
 | 
			
		||||
 | 
			
		||||
### 10. External References
 | 
			
		||||
 | 
			
		||||
Link to papers or external documentation:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
.. _Link Name:
 | 
			
		||||
    https://arxiv.org/abs/1611.00712
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Reference them in text: ```See `Link Name`_```
 | 
			
		||||
 | 
			
		||||
## Method Types
 | 
			
		||||
 | 
			
		||||
### Native Python Functions
 | 
			
		||||
 | 
			
		||||
For regular Python functions, use a standard docstring:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
def relu(input: Tensor, inplace: bool = False) -> Tensor:
 | 
			
		||||
    r"""relu(input, inplace=False) -> Tensor
 | 
			
		||||
 | 
			
		||||
    Applies the rectified linear unit function element-wise. See
 | 
			
		||||
    :class:`~torch.nn.ReLU` for more details.
 | 
			
		||||
    """
 | 
			
		||||
    # implementation
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### C-Bound Functions (using add_docstr)
 | 
			
		||||
 | 
			
		||||
For C-bound functions, use `_add_docstr`:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
conv1d = _add_docstr(
 | 
			
		||||
    torch.conv1d,
 | 
			
		||||
    r"""
 | 
			
		||||
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
 | 
			
		||||
 | 
			
		||||
Applies a 1D convolution over an input signal composed of several input
 | 
			
		||||
planes.
 | 
			
		||||
 | 
			
		||||
See :class:`~torch.nn.Conv1d` for details and output shape.
 | 
			
		||||
 | 
			
		||||
Args:
 | 
			
		||||
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
 | 
			
		||||
    weight: filters of shape :math:`(\text{out\_channels} , kW)`
 | 
			
		||||
    ...
 | 
			
		||||
""",
 | 
			
		||||
)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### In-Place Variants
 | 
			
		||||
 | 
			
		||||
For in-place operations (ending with `_`), reference the original:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
add_docstr_all(
 | 
			
		||||
    "abs_",
 | 
			
		||||
    r"""
 | 
			
		||||
abs_() -> Tensor
 | 
			
		||||
 | 
			
		||||
In-place version of :meth:`~Tensor.abs`
 | 
			
		||||
""",
 | 
			
		||||
)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Alias Functions
 | 
			
		||||
 | 
			
		||||
For aliases, simply reference the original:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
add_docstr_all(
 | 
			
		||||
    "absolute",
 | 
			
		||||
    r"""
 | 
			
		||||
absolute() -> Tensor
 | 
			
		||||
 | 
			
		||||
Alias for :func:`abs`
 | 
			
		||||
""",
 | 
			
		||||
)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Common Patterns
 | 
			
		||||
 | 
			
		||||
### Shape Documentation
 | 
			
		||||
 | 
			
		||||
Use LaTeX math notation for tensor shapes:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Reusable Argument Definitions
 | 
			
		||||
 | 
			
		||||
For commonly used arguments, define them once and reuse:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
common_args = parse_kwargs(
 | 
			
		||||
    """
 | 
			
		||||
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
 | 
			
		||||
        Default: if None, same as this tensor.
 | 
			
		||||
"""
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Then use with .format():
 | 
			
		||||
r"""
 | 
			
		||||
...
 | 
			
		||||
 | 
			
		||||
Keyword args:
 | 
			
		||||
    {dtype}
 | 
			
		||||
    {device}
 | 
			
		||||
""".format(**common_args)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Template Insertion
 | 
			
		||||
 | 
			
		||||
Insert reproducibility notes or other common text:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
r"""
 | 
			
		||||
{tf32_note}
 | 
			
		||||
 | 
			
		||||
{cudnn_reproducibility_note}
 | 
			
		||||
""".format(**reproducibility_notes, **tf32_notes)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Complete Example
 | 
			
		||||
 | 
			
		||||
Here's a complete example showing all elements:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
def gumbel_softmax(
 | 
			
		||||
    logits: Tensor,
 | 
			
		||||
    tau: float = 1,
 | 
			
		||||
    hard: bool = False,
 | 
			
		||||
    eps: float = 1e-10,
 | 
			
		||||
    dim: int = -1,
 | 
			
		||||
) -> Tensor:
 | 
			
		||||
    r"""
 | 
			
		||||
    Sample from the Gumbel-Softmax distribution and optionally discretize.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        logits (Tensor): `[..., num_features]` unnormalized log probabilities
 | 
			
		||||
        tau (float): non-negative scalar temperature
 | 
			
		||||
        hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
 | 
			
		||||
              but will be differentiated as if it is the soft sample in autograd. Default: ``False``
 | 
			
		||||
        dim (int): A dimension along which softmax will be computed. Default: -1
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
 | 
			
		||||
            If ``hard=True``, the returned samples will be one-hot, otherwise they will
 | 
			
		||||
            be probability distributions that sum to 1 across `dim`.
 | 
			
		||||
 | 
			
		||||
    .. note::
 | 
			
		||||
        This function is here for legacy reasons, may be removed from nn.Functional in the future.
 | 
			
		||||
 | 
			
		||||
    Examples::
 | 
			
		||||
        >>> logits = torch.randn(20, 32)
 | 
			
		||||
        >>> # Sample soft categorical using reparametrization trick:
 | 
			
		||||
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
 | 
			
		||||
        >>> # Sample hard categorical using "Straight-through" trick:
 | 
			
		||||
        >>> F.gumbel_softmax(logits, tau=1, hard=True)
 | 
			
		||||
 | 
			
		||||
    .. _Link 1:
 | 
			
		||||
        https://arxiv.org/abs/1611.00712
 | 
			
		||||
    """
 | 
			
		||||
    # implementation
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Quick Checklist
 | 
			
		||||
 | 
			
		||||
When writing a PyTorch docstring, ensure:
 | 
			
		||||
 | 
			
		||||
- [ ] Use raw string (`r"""`)
 | 
			
		||||
- [ ] Include function signature on first line
 | 
			
		||||
- [ ] Provide brief description
 | 
			
		||||
- [ ] Document all parameters in Args section with types
 | 
			
		||||
- [ ] Include default values for optional parameters
 | 
			
		||||
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
 | 
			
		||||
- [ ] Add mathematical formulas if applicable
 | 
			
		||||
- [ ] Include at least one example in Examples section
 | 
			
		||||
- [ ] Add warnings/notes for important caveats
 | 
			
		||||
- [ ] Link to related module class with `:class:`
 | 
			
		||||
- [ ] Use proper math notation for tensor shapes
 | 
			
		||||
- [ ] Follow consistent formatting and indentation
 | 
			
		||||
 | 
			
		||||
## Common Sphinx Roles Reference
 | 
			
		||||
 | 
			
		||||
- `:class:\`~torch.nn.Module\`` - Class reference
 | 
			
		||||
- `:func:\`torch.function\`` - Function reference
 | 
			
		||||
- `:meth:\`~Tensor.method\`` - Method reference
 | 
			
		||||
- `:attr:\`attribute\`` - Attribute reference
 | 
			
		||||
- `:math:\`equation\`` - Inline math
 | 
			
		||||
- `:ref:\`label\`` - Internal reference
 | 
			
		||||
- ``` ``code`` ``` - Inline code (use double backticks)
 | 
			
		||||
 | 
			
		||||
## Additional Notes
 | 
			
		||||
 | 
			
		||||
- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
 | 
			
		||||
- **Line length**: Try to keep lines under 100 characters when possible
 | 
			
		||||
- **Periods**: End sentences with periods, but not the signature line
 | 
			
		||||
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
 | 
			
		||||
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.
 | 
			
		||||
							
								
								
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							@ -124,3 +124,10 @@ runs:
 | 
			
		||||
      id: login-ecr
 | 
			
		||||
      continue-on-error: true
 | 
			
		||||
      uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
 | 
			
		||||
 | 
			
		||||
    - name: Preserve github env variables for use in docker
 | 
			
		||||
      shell: bash
 | 
			
		||||
      run: |
 | 
			
		||||
        env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
 | 
			
		||||
        env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
 | 
			
		||||
        env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							@ -1 +1 @@
 | 
			
		||||
faffd5cf673615583da6517275e361cb3dbc77e6
 | 
			
		||||
1752fe6809b74921644866275ab80244b96e80bc
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							@ -1 +1 @@
 | 
			
		||||
0fa6e3129e61143224663e1ec67980d12b7ec4eb
 | 
			
		||||
df6798dfb931ce7c7fe5bed2447cd1092a5981af
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										5
									
								
								.github/ci_configs/vllm/Dockerfile
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.github/ci_configs/vllm/Dockerfile
									
									
									
									
										vendored
									
									
								
							@ -283,6 +283,9 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
 | 
			
		||||
        uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
 | 
			
		||||
    fi
 | 
			
		||||
 | 
			
		||||
RUN --mount=type=cache,target=/root/.cache/uv \
 | 
			
		||||
    uv pip install --system --pre apache-tvm-ffi==0.1.0b15
 | 
			
		||||
 | 
			
		||||
# Install the vllm wheel from previous stage
 | 
			
		||||
RUN --mount=type=cache,target=/root/.cache/uv \
 | 
			
		||||
    uv pip install --system /wheels/vllm/*.whl --verbose
 | 
			
		||||
@ -295,6 +298,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 | 
			
		||||
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
 | 
			
		||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
 | 
			
		||||
 | 
			
		||||
# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
 | 
			
		||||
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
 | 
			
		||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
 | 
			
		||||
ARG FLASHINFER_GIT_REF="v0.2.14.post1"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										9
									
								
								.github/label_to_label.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								.github/label_to_label.yml
									
									
									
									
										vendored
									
									
								
							@ -15,6 +15,11 @@
 | 
			
		||||
  - "module: reinplacing"
 | 
			
		||||
  then:
 | 
			
		||||
  - "module: pt2-dispatcher"
 | 
			
		||||
- any:
 | 
			
		||||
  - "vllm-compile"
 | 
			
		||||
  then:
 | 
			
		||||
  - "module: vllm"
 | 
			
		||||
  - "oncall: pt2"
 | 
			
		||||
- any:
 | 
			
		||||
  - "module: vmap"
 | 
			
		||||
  then:
 | 
			
		||||
@ -27,10 +32,6 @@
 | 
			
		||||
  - "module: pt2 optimizer"
 | 
			
		||||
  then:
 | 
			
		||||
  - "module: dynamo"
 | 
			
		||||
- any:
 | 
			
		||||
  - "module: flex attention"
 | 
			
		||||
  then:
 | 
			
		||||
  - "module: higher order operators"
 | 
			
		||||
- any:
 | 
			
		||||
  - "module: aotinductor"
 | 
			
		||||
  then:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							@ -88,7 +88,6 @@ jobs:
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3_10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							@ -147,15 +147,16 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
 | 
			
		||||
      cuda-arch-list: 8.9
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							@ -347,7 +347,8 @@ jobs:
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      sync-tag: linux-xpu-n-build
 | 
			
		||||
      # This should sync with the build in xpu.yml but xpu uses a larger runner
 | 
			
		||||
      # sync-tag: linux-xpu-n-build
 | 
			
		||||
      runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
 | 
			
		||||
      build-environment: linux-jammy-xpu-n-py3.10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							@ -45,7 +45,6 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-noble-rocm-py3.12-mi300
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							@ -42,7 +42,6 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-noble-rocm-py3.12-mi355
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										12
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
								
							@ -26,11 +26,23 @@ jobs:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
 | 
			
		||||
  get-label-type:
 | 
			
		||||
    name: get-label-type
 | 
			
		||||
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
 | 
			
		||||
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
 | 
			
		||||
    with:
 | 
			
		||||
      triggering_actor: ${{ github.triggering_actor }}
 | 
			
		||||
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
 | 
			
		||||
      curr_branch: ${{ github.head_ref || github.ref_name }}
 | 
			
		||||
      curr_ref_type: ${{ github.ref_type }}
 | 
			
		||||
 | 
			
		||||
  linux-jammy-rocm-py3_10-build:
 | 
			
		||||
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
 | 
			
		||||
    name: linux-jammy-rocm-py3.10
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3.10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										12
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							@ -26,11 +26,23 @@ jobs:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
 | 
			
		||||
  get-label-type:
 | 
			
		||||
    name: get-label-type
 | 
			
		||||
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
 | 
			
		||||
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
 | 
			
		||||
    with:
 | 
			
		||||
      triggering_actor: ${{ github.triggering_actor }}
 | 
			
		||||
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
 | 
			
		||||
      curr_branch: ${{ github.head_ref || github.ref_name }}
 | 
			
		||||
      curr_ref_type: ${{ github.ref_type }}
 | 
			
		||||
 | 
			
		||||
  linux-jammy-rocm-py3_10-build:
 | 
			
		||||
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
 | 
			
		||||
    name: linux-jammy-rocm-py3.10
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3.10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										147
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										147
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							@ -58,8 +58,10 @@ jobs:
 | 
			
		||||
          else
 | 
			
		||||
            COMMIT_SHA="${{ github.sha }}"
 | 
			
		||||
          fi
 | 
			
		||||
          echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
          echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
          {
 | 
			
		||||
            echo "sha=${COMMIT_SHA}"
 | 
			
		||||
            echo "tag_name=trunk/${COMMIT_SHA}"
 | 
			
		||||
          } >> "${GITHUB_OUTPUT}"
 | 
			
		||||
 | 
			
		||||
      - name: Validate commit SHA
 | 
			
		||||
        run: |
 | 
			
		||||
@ -87,7 +89,7 @@ jobs:
 | 
			
		||||
            echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
      - name: Create and push tag with retry
 | 
			
		||||
      - name: Create and push tag(s) with retry
 | 
			
		||||
        id: check_tag
 | 
			
		||||
        env:
 | 
			
		||||
          TAG_NAME: ${{ steps.commit.outputs.tag_name }}
 | 
			
		||||
@ -112,14 +114,23 @@ jobs:
 | 
			
		||||
            return 1
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          # Exit early if tag already exists
 | 
			
		||||
          if check_tag_exists; then
 | 
			
		||||
            echo "✅ Tag already exists - no action needed"
 | 
			
		||||
            echo "exists=true" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
            exit 0
 | 
			
		||||
          fi
 | 
			
		||||
          # Counters for summary reporting
 | 
			
		||||
          created_count=0
 | 
			
		||||
          skipped_count=0
 | 
			
		||||
          failed_count=0
 | 
			
		||||
 | 
			
		||||
          echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
 | 
			
		||||
          # Always write outputs once on exit
 | 
			
		||||
          finish() {
 | 
			
		||||
            set +e
 | 
			
		||||
            if [ -n "${GITHUB_OUTPUT:-}" ]; then
 | 
			
		||||
              {
 | 
			
		||||
                echo "created_count=${created_count}"
 | 
			
		||||
                echo "skipped_count=${skipped_count}"
 | 
			
		||||
                echo "failed_count=${failed_count}"
 | 
			
		||||
              } >> "${GITHUB_OUTPUT}"
 | 
			
		||||
            fi
 | 
			
		||||
          }
 | 
			
		||||
          trap finish EXIT
 | 
			
		||||
 | 
			
		||||
          # Retry configuration
 | 
			
		||||
          MAX_RETRIES=5
 | 
			
		||||
@ -194,31 +205,111 @@ jobs:
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          # Execute with retry
 | 
			
		||||
          if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
 | 
			
		||||
            echo "exists=false" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
          # New behavior for push events: enumerate commits in the push and tag each one.
 | 
			
		||||
          # For workflow_dispatch, retain existing single-SHA behavior.
 | 
			
		||||
 | 
			
		||||
          # Always fetch tags once up front to improve idempotency in loops
 | 
			
		||||
          git fetch origin --tags --quiet || true
 | 
			
		||||
 | 
			
		||||
          if [ "${{ github.event_name }}" = "push" ]; then
 | 
			
		||||
            BEFORE_SHA="${{ github.event.before }}"
 | 
			
		||||
            AFTER_SHA="${{ github.sha }}"  # same as event.after
 | 
			
		||||
 | 
			
		||||
            # List commits introduced by this push (old..new), oldest first for stable ordering
 | 
			
		||||
            commits_file="$(mktemp)"
 | 
			
		||||
            git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            if [ ! -s "${commits_file}" ]; then
 | 
			
		||||
              echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
 | 
			
		||||
              rm -f "${commits_file}"
 | 
			
		||||
              exit 0
 | 
			
		||||
            fi
 | 
			
		||||
 | 
			
		||||
            commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
 | 
			
		||||
            echo "Found ${commit_count} commit(s) to tag for push:"
 | 
			
		||||
            while IFS= read -r sha; do
 | 
			
		||||
              printf '  %s\n' "${sha}"
 | 
			
		||||
            done < "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            while IFS= read -r sha; do
 | 
			
		||||
              TAG_NAME="trunk/${sha}"
 | 
			
		||||
              COMMIT_SHA="${sha}"
 | 
			
		||||
 | 
			
		||||
              # If tag already exists locally or remotely, skip (idempotent)
 | 
			
		||||
              if check_tag_exists; then
 | 
			
		||||
                echo "✅ Tag ${TAG_NAME} already exists - skipping"
 | 
			
		||||
                skipped_count=$((skipped_count + 1))
 | 
			
		||||
                continue
 | 
			
		||||
              fi
 | 
			
		||||
 | 
			
		||||
              echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
 | 
			
		||||
 | 
			
		||||
              if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
 | 
			
		||||
                created_count=$((created_count + 1))
 | 
			
		||||
              else
 | 
			
		||||
                echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
 | 
			
		||||
                failed_count=$((failed_count + 1))
 | 
			
		||||
              fi
 | 
			
		||||
            done < "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            rm -f "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            if [ "${failed_count}" -gt 0 ]; then
 | 
			
		||||
              exit 1
 | 
			
		||||
            fi
 | 
			
		||||
            exit 0
 | 
			
		||||
          else
 | 
			
		||||
            echo "Tag creation failed after all retry attempts"
 | 
			
		||||
            exit 1
 | 
			
		||||
            # workflow_dispatch path (single SHA tagging preserved)
 | 
			
		||||
 | 
			
		||||
            # Exit early if tag already exists
 | 
			
		||||
            if check_tag_exists; then
 | 
			
		||||
              echo "✅ Tag already exists - no action needed"
 | 
			
		||||
              skipped_count=1
 | 
			
		||||
              exit 0
 | 
			
		||||
            fi
 | 
			
		||||
 | 
			
		||||
            echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
 | 
			
		||||
 | 
			
		||||
            if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
 | 
			
		||||
              created_count=1
 | 
			
		||||
              exit 0
 | 
			
		||||
            else
 | 
			
		||||
              echo "Tag creation failed after all retry attempts"
 | 
			
		||||
              failed_count=1
 | 
			
		||||
              exit 1
 | 
			
		||||
            fi
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
      - name: Tag creation summary
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
 | 
			
		||||
            echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
 | 
			
		||||
          elif [ "${{ job.status }}" = "success" ]; then
 | 
			
		||||
            echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
          if [ "${{ github.event_name }}" = "push" ]; then
 | 
			
		||||
            echo "Trigger: push on main"
 | 
			
		||||
            echo "Created: ${{ steps.check_tag.outputs.created_count }}"
 | 
			
		||||
            echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
 | 
			
		||||
            echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
 | 
			
		||||
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
 | 
			
		||||
              echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
 | 
			
		||||
            else
 | 
			
		||||
              echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
 | 
			
		||||
            fi
 | 
			
		||||
          else
 | 
			
		||||
            echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
          fi
 | 
			
		||||
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
 | 
			
		||||
              if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
 | 
			
		||||
                echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
 | 
			
		||||
              else
 | 
			
		||||
                echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
              fi
 | 
			
		||||
            else
 | 
			
		||||
              echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
            fi
 | 
			
		||||
 | 
			
		||||
          echo ""
 | 
			
		||||
          echo "Tag details:"
 | 
			
		||||
          echo "  Name: ${{ steps.commit.outputs.tag_name }}"
 | 
			
		||||
          echo "  Commit: ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
          echo "  Trigger: ${{ github.event_name }}"
 | 
			
		||||
          if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
 | 
			
		||||
            echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
 | 
			
		||||
            echo ""
 | 
			
		||||
            echo "Tag details:"
 | 
			
		||||
            echo "  Name: ${{ steps.commit.outputs.tag_name }}"
 | 
			
		||||
            echo "  Commit: ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
            echo "  Trigger: ${{ github.event_name }}"
 | 
			
		||||
            if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
 | 
			
		||||
              echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
 | 
			
		||||
            fi
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
@ -833,8 +833,7 @@ exclude_patterns = [
 | 
			
		||||
command = [
 | 
			
		||||
    'python3',
 | 
			
		||||
    'tools/linter/adapters/grep_linter.py',
 | 
			
		||||
    '--pattern=cudaSetDevice(',
 | 
			
		||||
    '--pattern=cudaGetDevice(',
 | 
			
		||||
    '--pattern=(cudaSetDevice|cudaGetDevice)\\(',
 | 
			
		||||
    '--linter-name=RAWCUDADEVICE',
 | 
			
		||||
    '--error-name=raw CUDA API usage',
 | 
			
		||||
    """--error-description=\
 | 
			
		||||
@ -1138,11 +1137,8 @@ command = [
 | 
			
		||||
[[linter]]
 | 
			
		||||
code = 'WORKFLOWSYNC'
 | 
			
		||||
include_patterns = [
 | 
			
		||||
    '.github/workflows/pull.yml',
 | 
			
		||||
    '.github/workflows/trunk.yml',
 | 
			
		||||
    '.github/workflows/periodic.yml',
 | 
			
		||||
    '.github/workflows/mac-mps.yml',
 | 
			
		||||
    '.github/workflows/slow.yml',
 | 
			
		||||
    '.github/workflows/*.yml',
 | 
			
		||||
    '.github/workflows/*.yaml',
 | 
			
		||||
]
 | 
			
		||||
command = [
 | 
			
		||||
    'python3',
 | 
			
		||||
 | 
			
		||||
@ -289,14 +289,15 @@ IF(USE_FBGEMM_GENAI)
 | 
			
		||||
 | 
			
		||||
    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
 | 
			
		||||
 | 
			
		||||
    set(fbgemm_genai_mx8mx8bf16_grouped
 | 
			
		||||
    set(fbgemm_genai_cuh
 | 
			
		||||
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
 | 
			
		||||
      "${FBGEMM_GENAI_SRCS}/"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    target_include_directories(fbgemm_genai PRIVATE
 | 
			
		||||
      ${FBGEMM_THIRD_PARTY}/cutlass/include
 | 
			
		||||
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
 | 
			
		||||
      ${fbgemm_genai_mx8mx8bf16_grouped}
 | 
			
		||||
      ${fbgemm_genai_cuh}
 | 
			
		||||
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
 | 
			
		||||
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -19,6 +19,7 @@
 | 
			
		||||
#include <ATen/detail/MPSHooksInterface.h>
 | 
			
		||||
#include <ATen/detail/MTIAHooksInterface.h>
 | 
			
		||||
#include <ATen/detail/PrivateUse1HooksInterface.h>
 | 
			
		||||
#include <ATen/detail/XLAHooksInterface.h>
 | 
			
		||||
#include <ATen/detail/XPUHooksInterface.h>
 | 
			
		||||
#include <c10/core/QEngine.h>
 | 
			
		||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
 | 
			
		||||
@ -88,6 +89,8 @@ class TORCH_API Context {
 | 
			
		||||
      return at::detail::getHIPHooks();
 | 
			
		||||
    } else if (opt_device_type == at::kHPU) {
 | 
			
		||||
      return at::detail::getHPUHooks();
 | 
			
		||||
    } else if (opt_device_type == at::kXLA) {
 | 
			
		||||
      return at::detail::getXLAHooks();
 | 
			
		||||
    } else {
 | 
			
		||||
      TORCH_CHECK(
 | 
			
		||||
          false,
 | 
			
		||||
@ -196,7 +199,7 @@ class TORCH_API Context {
 | 
			
		||||
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
 | 
			
		||||
  }
 | 
			
		||||
  static bool hasXLA() {
 | 
			
		||||
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
 | 
			
		||||
    return detail::getXLAHooks().hasXLA();
 | 
			
		||||
  }
 | 
			
		||||
  static bool hasXPU() {
 | 
			
		||||
    return detail::getXPUHooks().hasXPU();
 | 
			
		||||
 | 
			
		||||
@ -59,9 +59,7 @@ struct TORCH_API Generator {
 | 
			
		||||
 | 
			
		||||
  explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
 | 
			
		||||
   : impl_(std::move(gen_impl)) {
 | 
			
		||||
    if (impl_.get() == nullptr) {
 | 
			
		||||
      throw std::runtime_error("GeneratorImpl with nullptr is not supported");
 | 
			
		||||
    }
 | 
			
		||||
    TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool operator==(const Generator& rhs) const {
 | 
			
		||||
 | 
			
		||||
@ -111,9 +111,7 @@ class TORCH_API TensorBase {
 | 
			
		||||
  explicit TensorBase(
 | 
			
		||||
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
 | 
			
		||||
      : impl_(std::move(tensor_impl)) {
 | 
			
		||||
    if (impl_.get() == nullptr) {
 | 
			
		||||
      throw std::runtime_error("TensorImpl with nullptr is not supported");
 | 
			
		||||
    }
 | 
			
		||||
    TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
 | 
			
		||||
  }
 | 
			
		||||
  TensorBase(const TensorBase&) = default;
 | 
			
		||||
  TensorBase(TensorBase&&) noexcept = default;
 | 
			
		||||
 | 
			
		||||
@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
 | 
			
		||||
  m.fallback(AUTOGRAD_FALLBACK);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
 | 
			
		||||
  m.fallback(AUTOGRAD_FALLBACK);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#undef AUTOGRAD_FALLBACK
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
 | 
			
		||||
 | 
			
		||||
  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
 | 
			
		||||
  TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
 | 
			
		||||
  // NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
 | 
			
		||||
  // refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
    !backendFallbackKernels_[idx].kernel.isValid(),
 | 
			
		||||
    "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
 | 
			
		||||
    backendFallbackKernels_[idx].debug, ", new registration ", debug
 | 
			
		||||
  );
 | 
			
		||||
      dispatchKey == DispatchKey::AutogradPrivateUse1 ||
 | 
			
		||||
          !backendFallbackKernels_[idx].kernel.isValid(),
 | 
			
		||||
      "Tried to register multiple backend fallbacks for the same dispatch key ",
 | 
			
		||||
      dispatchKey,
 | 
			
		||||
      "; previous registration ",
 | 
			
		||||
      backendFallbackKernels_[idx].debug,
 | 
			
		||||
      ", new registration ",
 | 
			
		||||
      debug);
 | 
			
		||||
  // NB: inferred function schema is always nullptr for fallbacks, as fallbacks
 | 
			
		||||
  // cannot be unboxed
 | 
			
		||||
  backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
 | 
			
		||||
 | 
			
		||||
@ -68,11 +68,7 @@ Symbol InternedStrings::_symbol(const std::string& s) {
 | 
			
		||||
    return it->second;
 | 
			
		||||
 | 
			
		||||
  auto pos = s.find("::");
 | 
			
		||||
  if (pos == std::string::npos) {
 | 
			
		||||
    std::stringstream ss;
 | 
			
		||||
    ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
 | 
			
		||||
    throw std::runtime_error(ss.str());
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
 | 
			
		||||
  Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
 | 
			
		||||
 | 
			
		||||
  Symbol sym(sym_to_info_.size());
 | 
			
		||||
@ -121,12 +117,7 @@ std::string Symbol::domainString() const {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
 | 
			
		||||
  if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
 | 
			
		||||
    std::ostringstream ss;
 | 
			
		||||
    ss << "Symbol: domain string is expected to be prefixed with '"
 | 
			
		||||
       << domain_prefix() << "', e.g. 'org.pytorch.aten'";
 | 
			
		||||
    throw std::runtime_error(ss.str());
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
 | 
			
		||||
  std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
 | 
			
		||||
  return fromQualString(qualString);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,7 @@
 | 
			
		||||
#include <ATen/core/jit_type.h>
 | 
			
		||||
#include <ATen/core/stack.h>
 | 
			
		||||
#include <ATen/core/type_factory.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/StringUtil.h>
 | 
			
		||||
#include <c10/util/hash.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
@ -412,7 +413,7 @@ size_t IValue::hash(const IValue& v) {
 | 
			
		||||
    case Tag::Enum:
 | 
			
		||||
    case Tag::Stream:
 | 
			
		||||
    case Tag::Uninitialized:
 | 
			
		||||
      throw std::runtime_error(
 | 
			
		||||
      TORCH_CHECK(false,
 | 
			
		||||
          "unhashable type: '" + v.type()->repr_str() + "'");
 | 
			
		||||
  }
 | 
			
		||||
  // the above switch should be exhaustive
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@
 | 
			
		||||
#include <ATen/core/type_factory.h>
 | 
			
		||||
#include <ATen/core/qualified_name.h>
 | 
			
		||||
#include <c10/util/TypeList.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <c10/core/SymFloat.h>
 | 
			
		||||
#include <c10/core/SymBool.h>
 | 
			
		||||
@ -116,10 +117,8 @@ struct SingleElementType : public SharedType {
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
 | 
			
		||||
    if (!this->elem) {
 | 
			
		||||
      throw std::runtime_error(c10::str(
 | 
			
		||||
    TORCH_CHECK(this->elem, c10::str(
 | 
			
		||||
            "Can not create ", typeKindToString(Kind), " with None type"));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
@ -416,16 +415,12 @@ struct TORCH_API SymbolicShape {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ShapeSymbol operator[](size_t i) const {
 | 
			
		||||
    if (!dims_) {
 | 
			
		||||
      throw std::runtime_error("Rank isn't fixed");
 | 
			
		||||
    }
 | 
			
		||||
    TORCH_CHECK(dims_, "Rank isn't fixed");
 | 
			
		||||
    return (*dims_).at(i);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ShapeSymbol at(size_t i) const {
 | 
			
		||||
    if (!dims_) {
 | 
			
		||||
      throw std::runtime_error("Rank isn't fixed");
 | 
			
		||||
    }
 | 
			
		||||
    TORCH_CHECK(dims_, "Rank isn't fixed");
 | 
			
		||||
    return (*dims_).at(i);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -520,9 +515,7 @@ struct VaryingShape {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const std::optional<T> &operator[](size_t i) const {
 | 
			
		||||
    if (!dims_) {
 | 
			
		||||
      throw std::runtime_error("Rank isn't fixed");
 | 
			
		||||
    }
 | 
			
		||||
    TORCH_CHECK(dims_, "Rank isn't fixed");
 | 
			
		||||
    return (*dims_).at(i);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -957,9 +950,7 @@ struct TORCH_API DictType : public SharedType {
 | 
			
		||||
 | 
			
		||||
  TypePtr createWithContained(
 | 
			
		||||
      std::vector<TypePtr> contained_types) const override {
 | 
			
		||||
    if (contained_types.size() != 2) {
 | 
			
		||||
      throw std::runtime_error("Expected 2 contained types");
 | 
			
		||||
    }
 | 
			
		||||
    TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
 | 
			
		||||
    return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -185,11 +185,11 @@ struct TORCH_API Type {
 | 
			
		||||
        : repr_(nullptr) {}
 | 
			
		||||
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
 | 
			
		||||
        : repr_(p) {}
 | 
			
		||||
        : repr_(makeSingletonSharedPtr(p.get())) {}
 | 
			
		||||
 | 
			
		||||
    template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
 | 
			
		||||
        : repr_(SingletonTypePtr<T>(p.get())) {}
 | 
			
		||||
        : repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    // We need to support construction from T* for pybind. The problem
 | 
			
		||||
@ -202,8 +202,8 @@ struct TORCH_API Type {
 | 
			
		||||
    // Case 2: if T is exactly Type, we need to do a dynamic_cast to
 | 
			
		||||
    // check if it's a SharedType and do the right thing.
 | 
			
		||||
    //
 | 
			
		||||
    // Case 3: Otherwise, T is not a SharedType. (debug-check this
 | 
			
		||||
    // assumption!) Use a singleton pointer.
 | 
			
		||||
    // Case 3: Otherwise, T is not a SharedType. Use a singleton
 | 
			
		||||
    // pointer.
 | 
			
		||||
 | 
			
		||||
    template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
 | 
			
		||||
@ -211,15 +211,15 @@ struct TORCH_API Type {
 | 
			
		||||
    template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(T* p) {
 | 
			
		||||
      if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
 | 
			
		||||
        repr_ = Repr(shared_p->shared_from_this());
 | 
			
		||||
        repr_ = shared_p->shared_from_this();
 | 
			
		||||
      } else {
 | 
			
		||||
        repr_ = Repr(p);
 | 
			
		||||
        repr_ = makeSingletonSharedPtr(p);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(T* p)
 | 
			
		||||
        : repr_(p) {
 | 
			
		||||
        : repr_(makeSingletonSharedPtr(p)) {
 | 
			
		||||
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -230,19 +230,19 @@ struct TORCH_API Type {
 | 
			
		||||
    ~SingletonOrSharedTypePtr() = default;
 | 
			
		||||
 | 
			
		||||
    T* get() const {
 | 
			
		||||
      return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
 | 
			
		||||
      return repr_.get();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    operator bool() const {
 | 
			
		||||
      return repr_.isNonNull();
 | 
			
		||||
      return repr_ != nullptr;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool operator==(std::nullptr_t) const {
 | 
			
		||||
      return !repr_.isNonNull();
 | 
			
		||||
      return repr_ == nullptr;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool operator!=(std::nullptr_t) const {
 | 
			
		||||
      return repr_.isNonNull();
 | 
			
		||||
      return repr_ != nullptr;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
 | 
			
		||||
@ -255,138 +255,14 @@ struct TORCH_API Type {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  private:
 | 
			
		||||
    // NOTE: SharedPtrWrapper exists to work around a baffling bug in
 | 
			
		||||
    // nvcc; see comment in destroy() below.
 | 
			
		||||
    struct SharedPtrWrapper {
 | 
			
		||||
      SharedPtrWrapper(std::shared_ptr<T> &&x)
 | 
			
		||||
          : repr_(std::move(x)) {}
 | 
			
		||||
      std::shared_ptr<T> repr_;
 | 
			
		||||
    };
 | 
			
		||||
    union Repr {
 | 
			
		||||
      Repr() : Repr(nullptr) {}
 | 
			
		||||
    // Use shared_ptr's aliasing constructor to create a non-owning pointer
 | 
			
		||||
    // to a singleton. The lifetime is tied to the null shared_ptr, so there's
 | 
			
		||||
    // no reference counting overhead for the singleton itself.
 | 
			
		||||
    static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) {
 | 
			
		||||
      return std::shared_ptr<T>(std::shared_ptr<T>(), ptr);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
      explicit Repr(std::shared_ptr<T> x)
 | 
			
		||||
          : shared_(std::move(x)) {}
 | 
			
		||||
 | 
			
		||||
      explicit Repr(std::nullptr_t)
 | 
			
		||||
          : singletonRepr_(nullptr) {}
 | 
			
		||||
 | 
			
		||||
      explicit Repr(SingletonTypePtr<T> p)
 | 
			
		||||
          : singletonRepr_(p.get()) {}
 | 
			
		||||
 | 
			
		||||
      ~Repr() {
 | 
			
		||||
        destroy();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // NOTE: the only non-UB way to access our null state is through
 | 
			
		||||
      // rawRepr(), because our copy operation doesn't preserve which
 | 
			
		||||
      // union member is active for null pointers.
 | 
			
		||||
      Repr(const Repr& rhs) {
 | 
			
		||||
        if (rhs.isSharedAndNonNull()) {
 | 
			
		||||
          new (&shared_) SharedPtrWrapper(rhs.shared_);
 | 
			
		||||
        } else {
 | 
			
		||||
          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
 | 
			
		||||
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
 | 
			
		||||
          singletonRepr_.unused_ = nullptr;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      Repr(Repr&& rhs) noexcept {
 | 
			
		||||
        if (rhs.isSharedAndNonNull()) {
 | 
			
		||||
          new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
 | 
			
		||||
        } else {
 | 
			
		||||
          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
 | 
			
		||||
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
 | 
			
		||||
          singletonRepr_.unused_ = nullptr;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      Repr& operator=(const Repr& rhs) {
 | 
			
		||||
        if (&rhs == this) {
 | 
			
		||||
          return *this;
 | 
			
		||||
        }
 | 
			
		||||
        if (rhs.isSharedAndNonNull()) {
 | 
			
		||||
          if (isSharedAndNonNull()) {
 | 
			
		||||
            shared_ = rhs.shared_;
 | 
			
		||||
          } else {
 | 
			
		||||
            new (&shared_) SharedPtrWrapper(rhs.shared_);
 | 
			
		||||
          }
 | 
			
		||||
        } else {
 | 
			
		||||
          if (isSharedAndNonNull()) {
 | 
			
		||||
            destroy();
 | 
			
		||||
          }
 | 
			
		||||
          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
 | 
			
		||||
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
 | 
			
		||||
          singletonRepr_.unused_ = nullptr;
 | 
			
		||||
        }
 | 
			
		||||
        return *this;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      Repr& operator=(Repr&& rhs) noexcept {
 | 
			
		||||
        if (&rhs == this) {
 | 
			
		||||
          return *this;
 | 
			
		||||
        }
 | 
			
		||||
        if (rhs.isSharedAndNonNull()) {
 | 
			
		||||
          if (isSharedAndNonNull()) {
 | 
			
		||||
            shared_ = std::move(rhs.shared_);
 | 
			
		||||
          } else {
 | 
			
		||||
            new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
 | 
			
		||||
          }
 | 
			
		||||
        } else {
 | 
			
		||||
          if (isSharedAndNonNull()) {
 | 
			
		||||
            destroy();
 | 
			
		||||
          }
 | 
			
		||||
          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
 | 
			
		||||
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
 | 
			
		||||
          singletonRepr_.unused_ = nullptr;
 | 
			
		||||
        }
 | 
			
		||||
        return *this;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      SharedPtrWrapper shared_;
 | 
			
		||||
 | 
			
		||||
      struct SingletonRepr {
 | 
			
		||||
        explicit SingletonRepr(T* s) : singleton_(s) {}
 | 
			
		||||
        T* singleton_;
 | 
			
		||||
        void* unused_ = nullptr;
 | 
			
		||||
      } singletonRepr_;
 | 
			
		||||
      struct RawRepr {
 | 
			
		||||
        void* first;
 | 
			
		||||
        void* nullIfSingleton_;
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
      // It is UB to read the singleton part of Repr if it was
 | 
			
		||||
      // constructed as a shared_ptr and vice versa, but memcpying out
 | 
			
		||||
      // the representation is always OK, so here's an accessor to obey
 | 
			
		||||
      // the letter of the law.
 | 
			
		||||
      RawRepr rawRepr() const {
 | 
			
		||||
        RawRepr repr{};
 | 
			
		||||
        memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
 | 
			
		||||
        return repr;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      bool isNonNull() const {
 | 
			
		||||
        auto repr = rawRepr();
 | 
			
		||||
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
 | 
			
		||||
        return repr.first != nullptr;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      bool isSharedAndNonNull() const {
 | 
			
		||||
        return rawRepr().nullIfSingleton_ != nullptr;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
     private:
 | 
			
		||||
      void destroy() {
 | 
			
		||||
        if (isSharedAndNonNull()) {
 | 
			
		||||
          // Without SharedPtrWrapper, this line would read
 | 
			
		||||
          // `shared_.~shared_ptr()` and nvcc would complain with
 | 
			
		||||
          // "error: expected primary-expression before '>' token"
 | 
			
		||||
          // referring to the "t" in "shared_ptr". SharedPtrWrapper
 | 
			
		||||
          // exists to work around this compiler bug.
 | 
			
		||||
          shared_.~SharedPtrWrapper();
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    } repr_;
 | 
			
		||||
    std::shared_ptr<T> repr_;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  using TypePtr = SingletonOrSharedTypePtr<Type>;
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@
 | 
			
		||||
#include <ATen/core/jit_type.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/env.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/flat_hash_map.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
#include <array>
 | 
			
		||||
@ -826,9 +827,7 @@ TupleType::TupleType(
 | 
			
		||||
    : NamedType(TypeKind::TupleType, std::move(name)),
 | 
			
		||||
      elements_(std::move(elements)),
 | 
			
		||||
      has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
 | 
			
		||||
        if (!v) {
 | 
			
		||||
          throw std::runtime_error("Can not create tuple with None type");
 | 
			
		||||
        }
 | 
			
		||||
        TORCH_CHECK(v, "Can not create tuple with None type");
 | 
			
		||||
        return v->hasFreeVariables();
 | 
			
		||||
      })), schema_(std::move(schema)) {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -104,71 +104,6 @@ class Vectorized<float> {
 | 
			
		||||
    }
 | 
			
		||||
    return b;
 | 
			
		||||
  }
 | 
			
		||||
  // Implementation is picked from
 | 
			
		||||
  // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
 | 
			
		||||
  inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
 | 
			
		||||
    const auto c1 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
 | 
			
		||||
    const auto c2 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
 | 
			
		||||
    const auto c3 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
 | 
			
		||||
    const auto c4 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
 | 
			
		||||
    const auto c5 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
 | 
			
		||||
    const auto shift = svreinterpret_f32_u32(
 | 
			
		||||
        svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
 | 
			
		||||
    const auto inv_ln2 = svreinterpret_f32_u32(
 | 
			
		||||
        svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
 | 
			
		||||
    const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
 | 
			
		||||
        0xbf317200)); // -ln(2) from bits  -1 to -19: -0x1.62e400p-1f
 | 
			
		||||
    const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
 | 
			
		||||
        0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
 | 
			
		||||
    const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
 | 
			
		||||
    const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
 | 
			
		||||
    const auto zero = svdup_n_f32(0.f);
 | 
			
		||||
    const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
 | 
			
		||||
    // Range reduction:
 | 
			
		||||
    //   e^x = 2^n * e^r
 | 
			
		||||
    // where:
 | 
			
		||||
    //   n = floor(x / ln(2))
 | 
			
		||||
    //   r = x - n * ln(2)
 | 
			
		||||
    //
 | 
			
		||||
    // By adding x / ln(2) with 2^23 + 127 (shift):
 | 
			
		||||
    //   * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
 | 
			
		||||
    //   forces decimal part
 | 
			
		||||
    //     of x / ln(2) out of the result. The integer part of x / ln(2) (i.e.
 | 
			
		||||
    //     n) + 127 will occupy the whole fraction part of z in FP32 format.
 | 
			
		||||
    //     Subtracting 2^23 + 127 (shift) from z will result in the integer part
 | 
			
		||||
    //     of x / ln(2) (i.e. n) because the decimal part has been pushed out
 | 
			
		||||
    //     and lost.
 | 
			
		||||
    //   * The addition of 127 makes the FP32 fraction part of z ready to be
 | 
			
		||||
    //   used as the exponent
 | 
			
		||||
    //     in FP32 format. Left shifting z by 23 bits will result in 2^n.
 | 
			
		||||
    const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
 | 
			
		||||
    const auto n = svsub_f32_z(pg, z, shift);
 | 
			
		||||
    const auto scale = svreinterpret_f32_u32(
 | 
			
		||||
        svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
 | 
			
		||||
    // The calculation of n * ln(2) is done using 2 steps to achieve accuracy
 | 
			
		||||
    // beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in
 | 
			
		||||
    // term of accuracy and performance.
 | 
			
		||||
    const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
 | 
			
		||||
    const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
 | 
			
		||||
    // Compute the truncated Taylor series of e^r.
 | 
			
		||||
    //   poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
 | 
			
		||||
    const auto r2 = svmul_f32_z(pg, r, r);
 | 
			
		||||
    const auto p1 = svmul_f32_z(pg, c1, r);
 | 
			
		||||
    const auto p23 = svmla_f32_z(pg, c2, c3, r);
 | 
			
		||||
    const auto p45 = svmla_f32_z(pg, c4, c5, r);
 | 
			
		||||
    const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
 | 
			
		||||
    const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
 | 
			
		||||
    auto poly = svmla_f32_z(pg, scale, p12345, scale);
 | 
			
		||||
    // Handle underflow and overflow.
 | 
			
		||||
    poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
 | 
			
		||||
    poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
 | 
			
		||||
    return poly;
 | 
			
		||||
  }
 | 
			
		||||
  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
 | 
			
		||||
    if (count == size())
 | 
			
		||||
      return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
 | 
			
		||||
@ -313,11 +248,41 @@ class Vectorized<float> {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
 | 
			
		||||
  }
 | 
			
		||||
  // Implementation copied from Arm Optimized Routines:
 | 
			
		||||
  // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
 | 
			
		||||
  Vectorized<float> exp_u20() const {
 | 
			
		||||
    return exp();
 | 
			
		||||
    // special case to handle special inputs that are too large or too small
 | 
			
		||||
    // i.e. where there's at least one element x, s.t. |x| >= 87.3...
 | 
			
		||||
    svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f);
 | 
			
		||||
    if (svptest_any(svptrue_b32(), is_special_case)) {
 | 
			
		||||
      return exp();
 | 
			
		||||
    }
 | 
			
		||||
    const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
 | 
			
		||||
    const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
 | 
			
		||||
    const svfloat32_t c1 = svdup_n_f32(0.5f);
 | 
			
		||||
    const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
 | 
			
		||||
 | 
			
		||||
    const float shift = 0x1.803f8p17f;
 | 
			
		||||
 | 
			
		||||
    /* n = round(x/(ln2/N)).  */
 | 
			
		||||
    svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift);
 | 
			
		||||
    svfloat32_t n = svsub_x(svptrue_b32(), z, shift);
 | 
			
		||||
 | 
			
		||||
    /* r = x - n*ln2/N.  */
 | 
			
		||||
    svfloat32_t r = values;
 | 
			
		||||
    r = svmls_x(svptrue_b32(), r, n, ln2_hi);
 | 
			
		||||
    r = svmls_x(svptrue_b32(), r, n, ln2_lo);
 | 
			
		||||
 | 
			
		||||
    /* scale = 2^(n/N).  */
 | 
			
		||||
    svfloat32_t scale = svexpa(svreinterpret_u32(z));
 | 
			
		||||
 | 
			
		||||
    /* poly(r) = exp(r) - 1 ~= r + 0.5 r^2.  */
 | 
			
		||||
    svfloat32_t r2 = svmul_x(svptrue_b32(), r, r);
 | 
			
		||||
    svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
 | 
			
		||||
    return svmla_x(svptrue_b32(), scale, scale, poly);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<float> fexp_u20() const {
 | 
			
		||||
    return exp();
 | 
			
		||||
    return exp_u20();
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
 | 
			
		||||
@ -453,9 +418,11 @@ class Vectorized<float> {
 | 
			
		||||
        ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
 | 
			
		||||
 | 
			
		||||
    // Step 2: Calculate exp(2 * x), where x is the clamped value.
 | 
			
		||||
    // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
 | 
			
		||||
    // the result.
 | 
			
		||||
    svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
 | 
			
		||||
    // svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of
 | 
			
		||||
    // the result (via Vectorized<float>, then auto-converts back to
 | 
			
		||||
    // svfloat32_t).
 | 
			
		||||
    svfloat32_t exp2x =
 | 
			
		||||
        Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20();
 | 
			
		||||
 | 
			
		||||
    // Step 3: Calculate the numerator of the tanh function, which is exp(2x)
 | 
			
		||||
    // - 1.
 | 
			
		||||
 | 
			
		||||
@ -6,9 +6,11 @@
 | 
			
		||||
#ifdef __aarch64__
 | 
			
		||||
#if !defined(CPU_CAPABILITY_SVE)
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
 | 
			
		||||
 | 
			
		||||
@ -354,9 +354,47 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
 | 
			
		||||
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
 | 
			
		||||
  Vectorized frac() const;
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
 | 
			
		||||
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  Vectorized<c10::BFloat16> neg() const {
 | 
			
		||||
    return -values;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<c10::BFloat16> reciprocal() const {
 | 
			
		||||
    return 1.0f / values;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<c10::BFloat16> operator==(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values == other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator!=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values != other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator<(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values < other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator<=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values <= other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator>(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values > other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator>=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values >= other.values;
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
 | 
			
		||||
@ -364,6 +402,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 | 
			
		||||
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 | 
			
		||||
@ -412,28 +451,52 @@ template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator+(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x + y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator-(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x - y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator*(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x * y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator/(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x / y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// frac. Implement this here so we can use subtraction
 | 
			
		||||
@ -544,12 +607,19 @@ Vectorized<c10::BFloat16> inline fmadd(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return x * y + z;
 | 
			
		||||
#else
 | 
			
		||||
  // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16!  Also,
 | 
			
		||||
  // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
 | 
			
		||||
  // elements, not the bottom and top half, so they don't seem
 | 
			
		||||
  // particularly useful here. Ideally we would include dot product in
 | 
			
		||||
  // the Vectorized interface...
 | 
			
		||||
  return a * b + c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -557,8 +627,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return (-x) * y + z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return -a * b + c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -566,8 +643,15 @@ Vectorized<c10::BFloat16> inline fmsub(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return x * y - z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return a * b - c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -575,8 +659,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return (-x) * y - z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return -a * b - c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
 | 
			
		||||
 | 
			
		||||
@ -5,6 +5,114 @@
 | 
			
		||||
namespace at::vec {
 | 
			
		||||
inline namespace CPU_CAPABILITY {
 | 
			
		||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
 | 
			
		||||
 | 
			
		||||
// Enable auto-vectorization for GCC-13+ and clang-17+
 | 
			
		||||
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
 | 
			
		||||
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
 | 
			
		||||
 | 
			
		||||
template <typename from_type, typename to_type>
 | 
			
		||||
inline void convertImpl(
 | 
			
		||||
    const from_type* __restrict src,
 | 
			
		||||
    to_type* __restrict dst,
 | 
			
		||||
    int64_t n) {
 | 
			
		||||
  uint64_t len = static_cast<uint64_t>(n);
 | 
			
		||||
  for (uint64_t i = 0; i < len; i++) {
 | 
			
		||||
    dst[i] = static_cast<to_type>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define CONVERT_TEMPLATE(from_type, to_type)                           \
 | 
			
		||||
  template <>                                                          \
 | 
			
		||||
  inline void convert(const from_type* src, to_type* dst, int64_t n) { \
 | 
			
		||||
    return convertImpl<from_type, to_type>(src, dst, n);               \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(float, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, float)
 | 
			
		||||
CONVERT_TEMPLATE(float, double)
 | 
			
		||||
CONVERT_TEMPLATE(double, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, float)
 | 
			
		||||
CONVERT_TEMPLATE(double, double)
 | 
			
		||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 | 
			
		||||
CONVERT_TEMPLATE(float16_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(float16_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(float16_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(float16_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(float16_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(float16_t, float16_t)
 | 
			
		||||
CONVERT_TEMPLATE(float16_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(float16_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, float16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, float16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, float16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, float16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, float16_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, float16_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, float16_t)
 | 
			
		||||
#endif
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, bfloat16_t)
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
template <typename src_t>
 | 
			
		||||
struct VecConvert<
 | 
			
		||||
    float,
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										586
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_double_neon.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										586
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_double_neon.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,586 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/intrinsics.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec_base.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
#include <cmath>
 | 
			
		||||
 | 
			
		||||
namespace at::vec {
 | 
			
		||||
// Note [CPU_CAPABILITY namespace]
 | 
			
		||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
// This header, and all of its subheaders, will be compiled with
 | 
			
		||||
// different architecture flags for each supported set of vector
 | 
			
		||||
// intrinsics. So we need to make sure they aren't inadvertently
 | 
			
		||||
// linked together. We do this by declaring objects in an `inline
 | 
			
		||||
// namespace` which changes the name mangling, but can still be
 | 
			
		||||
// accessed as `at::vec`.
 | 
			
		||||
inline namespace CPU_CAPABILITY {
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
class Vectorized<double> {
 | 
			
		||||
 private:
 | 
			
		||||
  float64x2_t values;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  using value_type = double;
 | 
			
		||||
  using size_type = int;
 | 
			
		||||
  static constexpr size_type size() {
 | 
			
		||||
    return 2;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized() {
 | 
			
		||||
    values = vdupq_n_f64(0.0);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized(float64x2_t v) : values(v) {}
 | 
			
		||||
  Vectorized(double val) {
 | 
			
		||||
    values = vdupq_n_f64(val);
 | 
			
		||||
  }
 | 
			
		||||
  template <
 | 
			
		||||
      typename... Args,
 | 
			
		||||
      typename = std::enable_if_t<(sizeof...(Args) == size())>>
 | 
			
		||||
  Vectorized(Args... vals) {
 | 
			
		||||
    __at_align__ double buffer[size()] = {vals...};
 | 
			
		||||
    values = vld1q_f64(buffer);
 | 
			
		||||
  }
 | 
			
		||||
  operator float64x2_t() const {
 | 
			
		||||
    return values;
 | 
			
		||||
  }
 | 
			
		||||
  template <int64_t mask>
 | 
			
		||||
  static Vectorized<double> blend(
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b) {
 | 
			
		||||
    // Build an array of flags: each bit of element is 1 if the corresponding
 | 
			
		||||
    // bit in 'mask' is set, 0 otherwise.
 | 
			
		||||
    uint64x2_t maskArray = {
 | 
			
		||||
        (mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0,
 | 
			
		||||
        (mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0};
 | 
			
		||||
    // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
    return vbslq_f64(maskArray, b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
  static Vectorized<double> blendv(
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b,
 | 
			
		||||
      const Vectorized<double>& mask_) {
 | 
			
		||||
    return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename step_t>
 | 
			
		||||
  static Vectorized<double> arange(
 | 
			
		||||
      double base = 0.,
 | 
			
		||||
      step_t step = static_cast<step_t>(1)) {
 | 
			
		||||
    return {base, base + static_cast<double>(step)};
 | 
			
		||||
  }
 | 
			
		||||
  static inline Vectorized<double> set(
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b,
 | 
			
		||||
      int64_t count = size()) {
 | 
			
		||||
    if (count == 0) {
 | 
			
		||||
      return a;
 | 
			
		||||
    } else if (count >= 2) {
 | 
			
		||||
      return b;
 | 
			
		||||
    } else {
 | 
			
		||||
      float64x2_t c = {b.values[0], a.values[1]};
 | 
			
		||||
      return c;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
 | 
			
		||||
    if (count == size()) {
 | 
			
		||||
      return vld1q_f64(reinterpret_cast<const double*>(ptr));
 | 
			
		||||
    } else if (count == 1) {
 | 
			
		||||
      float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr));
 | 
			
		||||
      float64x1_t z = {0.0};
 | 
			
		||||
      return vcombine_f64(x, z);
 | 
			
		||||
    } else {
 | 
			
		||||
      return vdupq_n_f64(0.0);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  void store(void* ptr, int64_t count = size()) const {
 | 
			
		||||
    if (count == size()) {
 | 
			
		||||
      vst1q_f64(reinterpret_cast<double*>(ptr), values);
 | 
			
		||||
    } else if (count == 1) {
 | 
			
		||||
      vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  const double& operator[](int idx) const = delete;
 | 
			
		||||
  double& operator[](int idx) = delete;
 | 
			
		||||
  int64_t zero_mask() const {
 | 
			
		||||
    // returns an integer mask where all zero elements are translated to 1-bit
 | 
			
		||||
    // and others are translated to 0-bit
 | 
			
		||||
    uint64x2_t cmpReg = vceqzq_f64(values);
 | 
			
		||||
    uint64x2_t mask = {1, 2};
 | 
			
		||||
    uint64x2_t res = vandq_u64(cmpReg, mask);
 | 
			
		||||
    return res[0] | res[1];
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> isnan() const {
 | 
			
		||||
    // NaN check
 | 
			
		||||
    return vreinterpretq_f64_u32(
 | 
			
		||||
        vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values))));
 | 
			
		||||
  }
 | 
			
		||||
  bool has_inf_nan() const {
 | 
			
		||||
    Vectorized<double> x = vsubq_f64(values, values);
 | 
			
		||||
    float64x2_t r = x.isnan();
 | 
			
		||||
    uint64x2_t u = vreinterpretq_u64_f64(r);
 | 
			
		||||
    return u[0] | u[1];
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> map(double (*f)(double)) const {
 | 
			
		||||
    float64x2_t result;
 | 
			
		||||
    result[0] = f(values[0]);
 | 
			
		||||
    result[1] = f(values[1]);
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> map2(
 | 
			
		||||
      const Vectorized<double>& second,
 | 
			
		||||
      double (*const f)(double, double)) const {
 | 
			
		||||
    float64x2_t result;
 | 
			
		||||
    result[0] = f(values[0], second.values[0]);
 | 
			
		||||
    result[1] = f(values[1], second.values[1]);
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> abs() const {
 | 
			
		||||
    return vabsq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> angle() const {
 | 
			
		||||
    auto zero = Vectorized<double>(0.0);
 | 
			
		||||
    auto pi = Vectorized<double>(c10::pi<double>);
 | 
			
		||||
    auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values)));
 | 
			
		||||
    return blendv(tmp, *this, isnan());
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> real() const {
 | 
			
		||||
    return *this;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> imag() const {
 | 
			
		||||
    return Vectorized<double>(0.0);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> conj() const {
 | 
			
		||||
    return *this;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> acos() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> acosh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> asin() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> asinh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> atan() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> atanh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_atan2d2_u10(values, b)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_b[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        b.store(tmp_b);
 | 
			
		||||
        for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
          tmp[i] = std::atan2(tmp[i], tmp_b[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} Vectorized<double> copysign(const Vectorized<double>& sign) const {
 | 
			
		||||
      USE_SLEEF(
 | 
			
		||||
          { return Vectorized<double>(Sleef_copysignd2(values, sign)); },
 | 
			
		||||
          {
 | 
			
		||||
            __at_align__ double tmp[size()];
 | 
			
		||||
            __at_align__ double tmp_sign[size()];
 | 
			
		||||
            store(tmp);
 | 
			
		||||
            sign.store(tmp_sign);
 | 
			
		||||
            for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
              tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
 | 
			
		||||
            }
 | 
			
		||||
            return loadu(tmp);
 | 
			
		||||
          })} Vectorized<double> erf() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> erfc() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> exp() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> exp2() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> expm1() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_fmodd2(values, q)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_q[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        q.store(tmp_q);
 | 
			
		||||
        for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
          tmp[i] = std::fmod(tmp[i], tmp_q[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} Vectorized<double> hypot(const Vectorized<double>& b) const {
 | 
			
		||||
      USE_SLEEF(
 | 
			
		||||
          { return Vectorized<double>(Sleef_hypotd2_u05(values, b)); },
 | 
			
		||||
          {
 | 
			
		||||
            __at_align__ double tmp[size()];
 | 
			
		||||
            __at_align__ double tmp_b[size()];
 | 
			
		||||
            store(tmp);
 | 
			
		||||
            b.store(tmp_b);
 | 
			
		||||
            for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
              tmp[i] = std::hypot(tmp[i], tmp_b[i]);
 | 
			
		||||
            }
 | 
			
		||||
            return loadu(tmp);
 | 
			
		||||
          })} Vectorized<double> i0() const {
 | 
			
		||||
    return map(calc_i0);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_nextafterd2(values, b)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_b[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        b.store(tmp_b);
 | 
			
		||||
        for (int64_t i = 0; i < size(); ++i) {
 | 
			
		||||
          tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} Vectorized<double> log() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_logd2_u10(values)), map(std::log));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> log2() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> log10() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> log1p() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> frac() const;
 | 
			
		||||
  Vectorized<double> sin() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> sinh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> cos() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> cosh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_powd2_u10(values, b)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_b[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        b.store(tmp_b);
 | 
			
		||||
        for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
          tmp[i] = std::pow(tmp[i], tmp_b[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} // Comparison using the _CMP_**_OQ predicate.
 | 
			
		||||
          //   `O`: get false if an operand is NaN
 | 
			
		||||
          //   `Q`: do not raise if an operand is NaN
 | 
			
		||||
  Vectorized<double> tan() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> tanh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> lgamma() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> erfinv() const {
 | 
			
		||||
    return map(calc_erfinv);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> exp_u20() const {
 | 
			
		||||
    return exp();
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> fexp_u20() const {
 | 
			
		||||
    return exp();
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> i0e() const {
 | 
			
		||||
    return map(calc_i0e);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> digamma() const {
 | 
			
		||||
    return map(calc_digamma);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> igamma(const Vectorized<double>& x) const {
 | 
			
		||||
    __at_align__ double tmp[size()];
 | 
			
		||||
    __at_align__ double tmp_x[size()];
 | 
			
		||||
    store(tmp);
 | 
			
		||||
    x.store(tmp_x);
 | 
			
		||||
    for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
 | 
			
		||||
    }
 | 
			
		||||
    return loadu(tmp);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> igammac(const Vectorized<double>& x) const {
 | 
			
		||||
    __at_align__ double tmp[size()];
 | 
			
		||||
    __at_align__ double tmp_x[size()];
 | 
			
		||||
    store(tmp);
 | 
			
		||||
    x.store(tmp_x);
 | 
			
		||||
    for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
 | 
			
		||||
    }
 | 
			
		||||
    return loadu(tmp);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> ceil() const {
 | 
			
		||||
    return vrndpq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> floor() const {
 | 
			
		||||
    return vrndmq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> neg() const {
 | 
			
		||||
    return vnegq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> round() const {
 | 
			
		||||
    return vrndiq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> trunc() const {
 | 
			
		||||
    return vrndq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> sqrt() const {
 | 
			
		||||
    return vsqrtq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> reciprocal() const {
 | 
			
		||||
    return vdivq_f64(vdupq_n_f64(1.0), values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> rsqrt() const {
 | 
			
		||||
    return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values));
 | 
			
		||||
  }
 | 
			
		||||
  double reduce_add() const {
 | 
			
		||||
    return vaddvq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  double reduce_max() const {
 | 
			
		||||
    return vmaxvq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> operator==(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vceqq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator!=(const Vectorized<double>& other) const {
 | 
			
		||||
    float64x2_t r0 = vreinterpretq_f64_u32(
 | 
			
		||||
        vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values))));
 | 
			
		||||
    return Vectorized<double>(r0);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator<(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vcltq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator<=(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vcleq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator>(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vcgtq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator>=(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vcgeq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> eq(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> ne(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> gt(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> ge(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> lt(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> le(const Vectorized<double>& other) const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator+(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vaddq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator-(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vsubq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator*(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vmulq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator/(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vdivq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// frac. Implement this here so we can use subtraction
 | 
			
		||||
Vectorized<double> inline Vectorized<double>::frac() const {
 | 
			
		||||
  return *this - this->trunc();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
 | 
			
		||||
// either input is a NaN.
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline maximum(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vmaxq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
 | 
			
		||||
// either input is a NaN.
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline minimum(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vminq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline clamp(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& min,
 | 
			
		||||
    const Vectorized<double>& max) {
 | 
			
		||||
  return vminq_f64(max, vmaxq_f64(min, a));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline clamp_max(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& max) {
 | 
			
		||||
  return vminq_f64(max, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline clamp_min(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& min) {
 | 
			
		||||
  return vmaxq_f64(min, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator&(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vreinterpretq_f64_u64(
 | 
			
		||||
      vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator|(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vreinterpretq_f64_u64(
 | 
			
		||||
      vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator^(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vreinterpretq_f64_u64(
 | 
			
		||||
      veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::eq(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this == other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::ne(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this != other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::gt(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this > other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::ge(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this >= other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::lt(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this < other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::le(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this <= other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline fmadd(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b,
 | 
			
		||||
    const Vectorized<double>& c) {
 | 
			
		||||
  return vfmaq_f64(c, a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline fnmadd(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b,
 | 
			
		||||
    const Vectorized<double>& c) {
 | 
			
		||||
  return vfmsq_f64(c, a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline fmsub(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b,
 | 
			
		||||
    const Vectorized<double>& c) {
 | 
			
		||||
  return vfmaq_f64(vnegq_f64(c), a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline fnmsub(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b,
 | 
			
		||||
    const Vectorized<double>& c) {
 | 
			
		||||
  return vfmsq_f64(vnegq_f64(c), a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace CPU_CAPABILITY
 | 
			
		||||
} // namespace at::vec
 | 
			
		||||
@ -307,11 +307,49 @@ class Vectorized<float> {
 | 
			
		||||
  DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
 | 
			
		||||
  DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
 | 
			
		||||
  DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
 | 
			
		||||
  // Implementation copied from Arm Optimized Routine
 | 
			
		||||
  // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
 | 
			
		||||
  Vectorized<float> exp_u20() const {
 | 
			
		||||
    return exp();
 | 
			
		||||
    // bail out to sleef if it's a special case:
 | 
			
		||||
    // i.e. there's an input s.t. |input| > 87.3....
 | 
			
		||||
    const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
 | 
			
		||||
    uint32x4_t cmp = vcagtq_f32(values, special_bound);
 | 
			
		||||
    if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) {
 | 
			
		||||
      return exp();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
 | 
			
		||||
    const float ln2_hi = 0x1.62e4p-1f;
 | 
			
		||||
    const float ln2_lo = 0x1.7f7d1cp-20f;
 | 
			
		||||
    const float c0 = 0x1.0e4020p-7f;
 | 
			
		||||
    const float c2 = 0x1.555e66p-3f;
 | 
			
		||||
    const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
 | 
			
		||||
 | 
			
		||||
    const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
 | 
			
		||||
    const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
 | 
			
		||||
    const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
 | 
			
		||||
    const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
 | 
			
		||||
 | 
			
		||||
    /* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
 | 
			
		||||
      x = ln2*n + r, with r in [-ln2/2, ln2/2].  */
 | 
			
		||||
 | 
			
		||||
    float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
 | 
			
		||||
    float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0);
 | 
			
		||||
    r = vfmsq_laneq_f32(r, n, ln2_c02, 1);
 | 
			
		||||
    uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
 | 
			
		||||
    float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias));
 | 
			
		||||
 | 
			
		||||
    float32x4_t r2 = vmulq_f32(r, r);
 | 
			
		||||
    float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2);
 | 
			
		||||
    float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3);
 | 
			
		||||
    q = vfmaq_f32(q, p, r2);
 | 
			
		||||
    p = vmulq_f32(c4, r);
 | 
			
		||||
    float32x4_t poly = vfmaq_f32(p, q, r2);
 | 
			
		||||
 | 
			
		||||
    return vfmaq_f32(scale, poly, scale);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<float> fexp_u20() const {
 | 
			
		||||
    return exp();
 | 
			
		||||
    return exp_u20();
 | 
			
		||||
  }
 | 
			
		||||
  DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
 | 
			
		||||
      fmod,
 | 
			
		||||
@ -540,42 +578,6 @@ inline Vectorized<float> Vectorized<float>::le(
 | 
			
		||||
  return (*this <= other) & Vectorized<float>(1.0f);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline void convert(const float* src, int32_t* dst, int64_t n) {
 | 
			
		||||
  int64_t i;
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (i = 0; i <= (n - Vectorized<float>::size());
 | 
			
		||||
       i += Vectorized<float>::size()) {
 | 
			
		||||
    vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
 | 
			
		||||
  }
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (; i < n; i++) {
 | 
			
		||||
    dst[i] = static_cast<int32_t>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline void convert(const int32_t* src, float* dst, int64_t n) {
 | 
			
		||||
  int64_t i;
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (i = 0; i <= (n - Vectorized<float>::size());
 | 
			
		||||
       i += Vectorized<float>::size()) {
 | 
			
		||||
    vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
 | 
			
		||||
  }
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (; i < n; i++) {
 | 
			
		||||
    dst[i] = static_cast<float>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<float> inline fmadd(
 | 
			
		||||
    const Vectorized<float>& a,
 | 
			
		||||
 | 
			
		||||
@ -569,46 +569,6 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
 | 
			
		||||
  return (*this <= other) & Vectorized<c10::Half>(1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// These are global functions, so the defaults in vec_base.h should
 | 
			
		||||
// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available.
 | 
			
		||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 | 
			
		||||
template <>
 | 
			
		||||
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
 | 
			
		||||
  int64_t i;
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (i = 0; i <= (n - Vectorized<c10::Half>::size());
 | 
			
		||||
       i += Vectorized<c10::Half>::size()) {
 | 
			
		||||
    vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
 | 
			
		||||
  }
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (; i < n; i++) {
 | 
			
		||||
    dst[i] = static_cast<int16_t>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
 | 
			
		||||
  int64_t i;
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (i = 0; i <= (n - Vectorized<c10::Half>::size());
 | 
			
		||||
       i += Vectorized<c10::Half>::size()) {
 | 
			
		||||
    vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
 | 
			
		||||
  }
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (; i < n; i++) {
 | 
			
		||||
    dst[i] = static_cast<float16_t>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::Half> inline fmadd(
 | 
			
		||||
    const Vectorized<c10::Half>& a,
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										378
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_uint_aarch64.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										378
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_uint_aarch64.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,378 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/intrinsics.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec_base.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
 | 
			
		||||
namespace at::vec {
 | 
			
		||||
// Note [CPU_CAPABILITY namespace]
 | 
			
		||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
// This header, and all of its subheaders, will be compiled with
 | 
			
		||||
// different architecture flags for each supported set of vector
 | 
			
		||||
// intrinsics. So we need to make sure they aren't inadvertently
 | 
			
		||||
// linked together. We do this by declaring objects in an `inline
 | 
			
		||||
// namespace` which changes the name mangling, but can still be
 | 
			
		||||
// accessed as `at::vec`.
 | 
			
		||||
inline namespace CPU_CAPABILITY {
 | 
			
		||||
 | 
			
		||||
#define VEC_UINT_NEON_TEMPLATE(vl, bit)                                       \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
 | 
			
		||||
                                                                              \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  class Vectorized<uint##bit##_t> {                                           \
 | 
			
		||||
    using neon_type = uint##bit##x##vl##_t;                                   \
 | 
			
		||||
                                                                              \
 | 
			
		||||
   private:                                                                   \
 | 
			
		||||
    neon_type values;                                                         \
 | 
			
		||||
                                                                              \
 | 
			
		||||
   public:                                                                    \
 | 
			
		||||
    using value_type = uint##bit##_t;                                         \
 | 
			
		||||
    using size_type = int;                                                    \
 | 
			
		||||
    static constexpr size_type size() {                                       \
 | 
			
		||||
      return vl;                                                              \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized() {                                                            \
 | 
			
		||||
      values = vdupq_n_u##bit(0);                                             \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized(neon_type v) : values(v) {}                                    \
 | 
			
		||||
    Vectorized(uint##bit##_t val);                                            \
 | 
			
		||||
    template <                                                                \
 | 
			
		||||
        typename... Args,                                                     \
 | 
			
		||||
        typename = std::enable_if_t<(sizeof...(Args) == size())>>             \
 | 
			
		||||
    Vectorized(Args... vals) {                                                \
 | 
			
		||||
      __at_align__ uint##bit##_t buffer[size()] = {vals...};                  \
 | 
			
		||||
      values = vld1q_u##bit(buffer);                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    operator neon_type() const {                                              \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    static Vectorized<uint##bit##_t> loadu(                                   \
 | 
			
		||||
        const void* ptr,                                                      \
 | 
			
		||||
        uint64_t count = size());                                             \
 | 
			
		||||
    void store(void* ptr, uint64_t count = size()) const;                     \
 | 
			
		||||
    template <uint64_t mask>                                                  \
 | 
			
		||||
    static Vectorized<uint##bit##_t> blend(                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& a,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& b);                                  \
 | 
			
		||||
    static Vectorized<uint##bit##_t> blendv(                                  \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& a,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& b,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& mask_) {                             \
 | 
			
		||||
      return vbslq_u##bit(mask_.values, b, a);                                \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    template <typename step_t>                                                \
 | 
			
		||||
    static Vectorized<uint##bit##_t> arange(                                  \
 | 
			
		||||
        value_type base = 0,                                                  \
 | 
			
		||||
        step_t step = static_cast<step_t>(1));                                \
 | 
			
		||||
    static Vectorized<uint##bit##_t> set(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& a,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& b,                                   \
 | 
			
		||||
        uint64_t count = size());                                             \
 | 
			
		||||
    const uint##bit##_t& operator[](uint idx) const = delete;                 \
 | 
			
		||||
    uint##bit##_t& operator[](uint idx) = delete;                             \
 | 
			
		||||
    Vectorized<uint##bit##_t> abs() const {                                   \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> real() const {                                  \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> imag() const {                                  \
 | 
			
		||||
      return vdupq_n_u##bit(0);                                               \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> conj() const {                                  \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> neg() const {                                   \
 | 
			
		||||
      return vreinterpretq_u##bit##_s##bit(                                   \
 | 
			
		||||
          vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values)));               \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    uint##bit##_t reduce_add() const {                                        \
 | 
			
		||||
      return vaddvq_u##bit(values);                                           \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    uint##bit##_t reduce_max() const;                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator==(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vceqq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator!=(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator<(                                      \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcltq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator<=(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcleq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator>(                                      \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcgtq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator>=(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcgeq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> eq(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> ne(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> gt(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> ge(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> lt(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> le(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
  };                                                                          \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator+(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vaddq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator-(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vsubq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator&(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vandq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator|(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vorrq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator^(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return veorq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this == other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this != other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this > other) & Vectorized<uint##bit##_t>(1);                    \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this >= other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this < other) & Vectorized<uint##bit##_t>(1);                    \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this <= other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
VEC_UINT_NEON_TEMPLATE(16, 8)
 | 
			
		||||
 | 
			
		||||
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
 | 
			
		||||
  return vmaxvq_u8(values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator*(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  return vmulq_u8(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
 | 
			
		||||
  return vmvnq_u8(a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
 | 
			
		||||
    const Vectorized<uint8_t>& other) const {
 | 
			
		||||
  return ~(*this == other);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline minimum(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  return vminq_u8(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline maximum(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  return vmaxq_u8(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <uint64_t mask>
 | 
			
		||||
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  // Build an array of flags: each bit of element is 1 if the corresponding bit
 | 
			
		||||
  // in 'mask' is set, 0 otherwise.
 | 
			
		||||
  uint8x16_t maskArray = {
 | 
			
		||||
      (mask & 1LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 2LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 4LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 8LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 16LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 32LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 64LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 128LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 256LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 512LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 1024LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 2048LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 4096LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 8192LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 16384LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 32768LL) ? 0xFF : 0};
 | 
			
		||||
  // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
  return vbslq_u8(maskArray, b.values, a.values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define VEC_UINT_NEON_OPS(vl, bit)                                             \
 | 
			
		||||
  inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) {            \
 | 
			
		||||
    values = vdupq_n_u##bit(val);                                              \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu(           \
 | 
			
		||||
      const void* ptr, uint64_t count) {                                       \
 | 
			
		||||
    if (count == size()) {                                                     \
 | 
			
		||||
      return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr));        \
 | 
			
		||||
    } else {                                                                   \
 | 
			
		||||
      __at_align__ uint##bit##_t tmp_values[size()];                           \
 | 
			
		||||
      for (const auto i : c10::irange(size())) {                               \
 | 
			
		||||
        tmp_values[i] = 0;                                                     \
 | 
			
		||||
      }                                                                        \
 | 
			
		||||
      std::memcpy(                                                             \
 | 
			
		||||
          tmp_values,                                                          \
 | 
			
		||||
          reinterpret_cast<const uint##bit##_t*>(ptr),                         \
 | 
			
		||||
          count * sizeof(uint##bit##_t));                                      \
 | 
			
		||||
      return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
 | 
			
		||||
    }                                                                          \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count)      \
 | 
			
		||||
      const {                                                                  \
 | 
			
		||||
    if (count == size()) {                                                     \
 | 
			
		||||
      vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values);             \
 | 
			
		||||
    } else {                                                                   \
 | 
			
		||||
      uint##bit##_t tmp_values[size()];                                        \
 | 
			
		||||
      vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values);      \
 | 
			
		||||
      std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t));             \
 | 
			
		||||
    }                                                                          \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
VEC_UINT_NEON_OPS(16, 8)
 | 
			
		||||
 | 
			
		||||
template <typename step_t>
 | 
			
		||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
 | 
			
		||||
    uint8_t base,
 | 
			
		||||
    step_t step) {
 | 
			
		||||
  const Vectorized<uint8_t> base_vec(base);
 | 
			
		||||
  const Vectorized<uint8_t> step_vec(step);
 | 
			
		||||
  const uint8x16_t step_sizes = {
 | 
			
		||||
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
 | 
			
		||||
  return vmlaq_u8(base_vec, step_sizes, step_vec);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator>>(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  uint8x16_t x = a;
 | 
			
		||||
  uint8x16_t bound = vdupq_n_u8(8);
 | 
			
		||||
  uint8x16_t z = vminq_u8(b, bound);
 | 
			
		||||
  return x >> z;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator<<(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  uint8x16_t bound = vdupq_n_u8(8);
 | 
			
		||||
  uint8x16_t z = vminq_u8(b, bound);
 | 
			
		||||
  return vshlq_u8(a, vreinterpretq_s8_u8(z));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b,
 | 
			
		||||
    uint64_t count) {
 | 
			
		||||
  if (count == 0) {
 | 
			
		||||
    return a;
 | 
			
		||||
  } else if (count >= 16) {
 | 
			
		||||
    return b;
 | 
			
		||||
  } else {
 | 
			
		||||
    // Build an array of flags: each bit of element is 1 if the corresponding
 | 
			
		||||
    // bit in 'mask' is set, 0 otherwise.
 | 
			
		||||
    uint8x16_t maskArray = {
 | 
			
		||||
        static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
 | 
			
		||||
        0};
 | 
			
		||||
 | 
			
		||||
    // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
    return vbslq_u8(maskArray, b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator/(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  uint8x16_t x = a;
 | 
			
		||||
  uint8x16_t y = b;
 | 
			
		||||
  return x / y;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline clamp(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& min,
 | 
			
		||||
    const Vectorized<uint8_t>& max) {
 | 
			
		||||
  return minimum(max, maximum(min, a));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline clamp_max(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& max) {
 | 
			
		||||
  return minimum(max, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline clamp_min(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& min) {
 | 
			
		||||
  return maximum(min, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace CPU_CAPABILITY
 | 
			
		||||
} // namespace at::vec
 | 
			
		||||
@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
 | 
			
		||||
 | 
			
		||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
 | 
			
		||||
    at::vec::Vectorized<uint8_t> src) {
 | 
			
		||||
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
 | 
			
		||||
  auto u8x8 = vget_low_u8(src);
 | 
			
		||||
  auto u16x8 = vmovl_u8(u8x8);
 | 
			
		||||
  auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
 | 
			
		||||
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
 | 
			
		||||
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
 | 
			
		||||
 | 
			
		||||
Vectorized<float> inline convert_int8_half_register_to_float(
 | 
			
		||||
    at::vec::Vectorized<uint8_t> src) {
 | 
			
		||||
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
 | 
			
		||||
  auto u8x8 = vget_low_u8(src);
 | 
			
		||||
  auto u16x8 = vmovl_u8(u8x8);
 | 
			
		||||
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -168,11 +168,9 @@ void CUDAGraph::instantiate() {
 | 
			
		||||
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
 | 
			
		||||
  // cudaGraphInstantiateWithFlags
 | 
			
		||||
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
 | 
			
		||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
 | 
			
		||||
  int version = 0;
 | 
			
		||||
  AT_CUDA_CHECK(cudaDriverGetVersion(&version));
 | 
			
		||||
  if (version < 11040) {
 | 
			
		||||
#endif
 | 
			
		||||
    // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
 | 
			
		||||
    // who prefer not to report error message through these arguments moving forward
 | 
			
		||||
    // (they prefer return value, or errors on api calls internal to the capture)
 | 
			
		||||
@ -183,13 +181,11 @@ void CUDAGraph::instantiate() {
 | 
			
		||||
#endif
 | 
			
		||||
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
 | 
			
		||||
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
 | 
			
		||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
 | 
			
		||||
  } else {
 | 
			
		||||
    AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
 | 
			
		||||
                                                graph_,
 | 
			
		||||
                                                cudaGraphInstantiateFlagAutoFreeOnLaunch));
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
  has_graph_exec_ = true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										192
									
								
								aten/src/ATen/cuda/CUDAGreenContext.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										192
									
								
								aten/src/ATen/cuda/CUDAGreenContext.cpp
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,192 @@
 | 
			
		||||
#include <ATen/cuda/CUDAGreenContext.h>
 | 
			
		||||
 | 
			
		||||
namespace at::cuda {
 | 
			
		||||
  GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    int driver_version;
 | 
			
		||||
    C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        driver_version >= 12080, "cuda driver too old to use green context!");
 | 
			
		||||
    CUcontext pctx = nullptr;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
 | 
			
		||||
    if (C10_UNLIKELY(!pctx)) {
 | 
			
		||||
      TORCH_WARN(
 | 
			
		||||
          "Attempted to create a green context but"
 | 
			
		||||
          " there was no primary context! Creating a primary context...");
 | 
			
		||||
 | 
			
		||||
      cudaFree(0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    CUdevice device;
 | 
			
		||||
    device_id_ = device_id;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
 | 
			
		||||
 | 
			
		||||
    // Get device resources
 | 
			
		||||
    CUdevResource device_resource;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
 | 
			
		||||
        device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
 | 
			
		||||
 | 
			
		||||
    // Split resources
 | 
			
		||||
    std::vector<CUdevResource> result(1);
 | 
			
		||||
    auto result_data = result.data();
 | 
			
		||||
    unsigned int nb_groups = 1;
 | 
			
		||||
    CUdevResource remaining;
 | 
			
		||||
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
 | 
			
		||||
            result_data,
 | 
			
		||||
            &nb_groups,
 | 
			
		||||
            &device_resource,
 | 
			
		||||
            &remaining,
 | 
			
		||||
            0, // default flags
 | 
			
		||||
            num_sms));
 | 
			
		||||
 | 
			
		||||
    TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
 | 
			
		||||
 | 
			
		||||
    // Generate resource descriptor
 | 
			
		||||
    CUdevResourceDesc desc;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
 | 
			
		||||
            &desc, result_data, 1));
 | 
			
		||||
 | 
			
		||||
    // Create green context
 | 
			
		||||
    // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
 | 
			
		||||
    // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
 | 
			
		||||
        &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
 | 
			
		||||
 | 
			
		||||
    // Convert to regular context
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
 | 
			
		||||
    TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<GreenContext> GreenContext::create(
 | 
			
		||||
      uint32_t num_sms,
 | 
			
		||||
      std::optional<uint32_t> device_id) {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    if (!device_id.has_value()) {
 | 
			
		||||
      device_id = at::cuda::current_device();
 | 
			
		||||
    }
 | 
			
		||||
    return std::make_unique<GreenContext>(device_id.value(), num_sms);
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Implement move operations
 | 
			
		||||
  GreenContext::GreenContext(GreenContext&& other) noexcept{
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    device_id_ = std::exchange(other.device_id_, -1);
 | 
			
		||||
    green_ctx_ = std::exchange(other.green_ctx_, nullptr);
 | 
			
		||||
    context_ = std::exchange(other.context_, nullptr);
 | 
			
		||||
    parent_stream_ = std::exchange(other.parent_stream_, nullptr);
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    if (this != &other) {
 | 
			
		||||
      // Clean up current resources
 | 
			
		||||
      if (green_ctx_) {
 | 
			
		||||
        CUcontext current = nullptr;
 | 
			
		||||
        C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
            c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
 | 
			
		||||
        if (current == context_) {
 | 
			
		||||
          TORCH_CHECK(
 | 
			
		||||
              false,
 | 
			
		||||
              "attempting to overwrite current green ctx "
 | 
			
		||||
              "when it is active!");
 | 
			
		||||
        }
 | 
			
		||||
        C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Take ownership of other's resources
 | 
			
		||||
      device_id_ = std::exchange(other.device_id_, -1);
 | 
			
		||||
      green_ctx_ = std::exchange(other.green_ctx_, nullptr);
 | 
			
		||||
      context_ = std::exchange(other.context_, nullptr);
 | 
			
		||||
      parent_stream_ = std::exchange(other.parent_stream_, nullptr);
 | 
			
		||||
    }
 | 
			
		||||
    return *this;
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  GreenContext::~GreenContext() noexcept{
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Get the underlying CUDA context
 | 
			
		||||
  CUcontext GreenContext::getContext() const {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    return context_;
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Get the underlying green context
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
  CUgreenCtx GreenContext::getGreenContext() const {
 | 
			
		||||
    return green_ctx_;
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Make this context current
 | 
			
		||||
  void GreenContext::setContext() {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    auto current_stream = c10::cuda::getCurrentCUDAStream();
 | 
			
		||||
    parent_stream_ = current_stream.stream();
 | 
			
		||||
 | 
			
		||||
    at::cuda::CUDAEvent ev;
 | 
			
		||||
    ev.record(current_stream);
 | 
			
		||||
 | 
			
		||||
    CUcontext current = nullptr;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
 | 
			
		||||
    if (!current) {
 | 
			
		||||
      C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
          c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
 | 
			
		||||
    } else {
 | 
			
		||||
      C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
          c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
 | 
			
		||||
    }
 | 
			
		||||
    // currently hardcodes the new green context to use the default stream
 | 
			
		||||
    // TODO(eqy): consider creating a new stream if e.g., it allows interop
 | 
			
		||||
    // with CUDA Graph captures etc.
 | 
			
		||||
    auto default_stream = c10::cuda::getDefaultCUDAStream();
 | 
			
		||||
    ev.block(default_stream);
 | 
			
		||||
    c10::cuda::setCurrentCUDAStream(default_stream);
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void GreenContext::popContext() {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    // see above note about stream being hardcoded to the default stream
 | 
			
		||||
    at::cuda::CUDAEvent ev;
 | 
			
		||||
    ev.record(c10::cuda::getCurrentCUDAStream());
 | 
			
		||||
    CUcontext popped;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(
 | 
			
		||||
        popped == context_, "expected popped context to be the current ctx");
 | 
			
		||||
    ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
} // namespace at::cuda
 | 
			
		||||
							
								
								
									
										53
									
								
								aten/src/ATen/cuda/CUDAGreenContext.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								aten/src/ATen/cuda/CUDAGreenContext.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,53 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
#include <ATen/cuda/CUDAEvent.h>
 | 
			
		||||
 | 
			
		||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
 | 
			
		||||
#include <c10/cuda/driver_api.h>
 | 
			
		||||
#include <cuda.h>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <stdexcept>
 | 
			
		||||
#include <vector>
 | 
			
		||||
#define CUDA_HAS_GREEN_CONTEXT 1
 | 
			
		||||
#else
 | 
			
		||||
#define CUDA_HAS_GREEN_CONTEXT 0
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace at::cuda {
 | 
			
		||||
 | 
			
		||||
class TORCH_CUDA_CPP_API GreenContext {
 | 
			
		||||
 public:
 | 
			
		||||
  GreenContext(uint32_t device_id, uint32_t num_sms);
 | 
			
		||||
 | 
			
		||||
  static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
 | 
			
		||||
 | 
			
		||||
  // Delete copy constructor and assignment
 | 
			
		||||
  GreenContext(const GreenContext&) = delete;
 | 
			
		||||
  GreenContext& operator=(const GreenContext&) = delete;
 | 
			
		||||
 | 
			
		||||
  // Implement move operations
 | 
			
		||||
  GreenContext(GreenContext&& other) noexcept;
 | 
			
		||||
  GreenContext& operator=(GreenContext&& other) noexcept;
 | 
			
		||||
  ~GreenContext() noexcept;
 | 
			
		||||
 | 
			
		||||
  // Get the underlying CUDA context
 | 
			
		||||
  CUcontext getContext() const;
 | 
			
		||||
 | 
			
		||||
  // Get the underlying green context
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
  CUgreenCtx getGreenContext() const;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Make this context current
 | 
			
		||||
  void setContext();
 | 
			
		||||
 | 
			
		||||
  void popContext();
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
  int32_t device_id_ = -1;
 | 
			
		||||
  CUgreenCtx green_ctx_ = nullptr;
 | 
			
		||||
  CUcontext context_ = nullptr;
 | 
			
		||||
  cudaStream_t parent_stream_ = nullptr;
 | 
			
		||||
#endif
 | 
			
		||||
};
 | 
			
		||||
} // namespace at::cuda
 | 
			
		||||
							
								
								
									
										270
									
								
								aten/src/ATen/cuda/CUDAScaledBlas.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										270
									
								
								aten/src/ATen/cuda/CUDAScaledBlas.cpp
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,270 @@
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <c10/util/typeid.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/SmallVector.h>
 | 
			
		||||
#include <c10/core/Scalar.h>
 | 
			
		||||
#include <c10/core/ScalarType.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 | 
			
		||||
#include <ATen/core/Tensor.h>
 | 
			
		||||
#include <ATen/core/NamedTensor.h>
 | 
			
		||||
#include <ATen/Dispatch.h>
 | 
			
		||||
#include <ATen/ExpandUtils.h>
 | 
			
		||||
#include <ATen/OpMathType.h>
 | 
			
		||||
#include <ATen/TensorUtils.h>
 | 
			
		||||
#include <ATen/cuda/CUDABlas.h>
 | 
			
		||||
#include <ATen/cuda/tunable/Tunable.h>
 | 
			
		||||
#include <ATen/cuda/tunable/TunableGemm.h>
 | 
			
		||||
#include <ATen/native/Resize.h>
 | 
			
		||||
#include <c10/util/MaybeOwned.h>
 | 
			
		||||
#include <ATen/native/GroupedMMUtils.h>
 | 
			
		||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
 | 
			
		||||
#include <ATen/native/cuda/ScaledGroupMM.h>
 | 
			
		||||
#include <ATen/native/cuda/GroupMM.h>
 | 
			
		||||
#include <ATen/ceil_div.h>
 | 
			
		||||
 | 
			
		||||
#ifdef USE_FBGEMM_GENAI
 | 
			
		||||
#include <fbgemm_gpu/torch_ops.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
#include <ATen/NativeFunctions.h>
 | 
			
		||||
#else
 | 
			
		||||
#include <ATen/ops/_addmm_activation_native.h>
 | 
			
		||||
#include <ATen/ops/_efficientzerotensor.h>
 | 
			
		||||
#include <ATen/ops/_scaled_mm_native.h>
 | 
			
		||||
#include <ATen/ops/_unsafe_view_native.h>
 | 
			
		||||
#include <ATen/ops/abs.h>
 | 
			
		||||
#include <ATen/ops/addmm_native.h>
 | 
			
		||||
#include <ATen/ops/addmv_native.h>
 | 
			
		||||
#include <ATen/ops/baddbmm_native.h>
 | 
			
		||||
#include <ATen/ops/bmm_native.h>
 | 
			
		||||
#include <ATen/ops/copy_native.h>
 | 
			
		||||
#include <ATen/ops/dot_native.h>
 | 
			
		||||
#include <ATen/ops/empty.h>
 | 
			
		||||
#include <ATen/ops/empty_strided.h>
 | 
			
		||||
#include <ATen/ops/gelu.h>
 | 
			
		||||
#include <ATen/ops/max.h>
 | 
			
		||||
#include <ATen/ops/mm_native.h>
 | 
			
		||||
#include <ATen/ops/mul.h>
 | 
			
		||||
#include <ATen/ops/relu.h>
 | 
			
		||||
#include <ATen/ops/ones.h>
 | 
			
		||||
#include <ATen/ops/scalar_tensor_native.h>
 | 
			
		||||
#include <ATen/ops/vdot_native.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
using at::blas::ScalingType;
 | 
			
		||||
using at::blas::SwizzleType;
 | 
			
		||||
 | 
			
		||||
namespace at::cuda::scaled {
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp8,
 | 
			
		||||
 * Each needs a single scale, {Tensorwise (float)}
 | 
			
		||||
 */
 | 
			
		||||
bool check_tensorwise_recipe(c10::ScalarType type_a,
 | 
			
		||||
                             std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                             ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                             c10::ScalarType type_b,
 | 
			
		||||
                             std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                             ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp8
 | 
			
		||||
  if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scale each, {Tensorwise, float}
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  // Need {Blockwise_1x32, e8m0} for A & B
 | 
			
		||||
  if (recipe_a[0] != ScalingType::TensorWise) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::TensorWise) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp8,
 | 
			
		||||
 * Each needs scales, {Rowwise (float)}
 | 
			
		||||
 */
 | 
			
		||||
bool check_rowwise_recipe(c10::ScalarType type_a,
 | 
			
		||||
                             std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                             ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                             c10::ScalarType type_b,
 | 
			
		||||
                             std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                             ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp8
 | 
			
		||||
  if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scale each, {Tensorwise, float}
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {RowWise, dp32} for A & B
 | 
			
		||||
  if (recipe_a[0] != ScalingType::RowWise) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::RowWise) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Two-level scaling, canonical NVFP4
 | 
			
		||||
 * Both inputs must be fp4
 | 
			
		||||
 * A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
 | 
			
		||||
 */
 | 
			
		||||
bool check_nvfp4_recipe(c10::ScalarType type_a,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                        c10::ScalarType type_b,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp4
 | 
			
		||||
  if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 2 scales, 2 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
 | 
			
		||||
  if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Single-level scaling, what PyT currently understands
 | 
			
		||||
 * Both inputs must be fp4
 | 
			
		||||
 * A, B need 1 scale, {Blockwise_1x16 (e4m3)}
 | 
			
		||||
 */
 | 
			
		||||
bool check_nvfp4_recipe_single_scale
 | 
			
		||||
                       (c10::ScalarType type_a,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                        c10::ScalarType type_b,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp4
 | 
			
		||||
  if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 2 scales, 2 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
 | 
			
		||||
  if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp8
 | 
			
		||||
 * A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float)
 | 
			
		||||
 */
 | 
			
		||||
bool check_deepseek_recipe(ScalingType expected_recipe_a,
 | 
			
		||||
                           ScalingType expected_recipe_b,
 | 
			
		||||
                           c10::ScalarType type_a,
 | 
			
		||||
                           std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                           ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                           c10::ScalarType type_b,
 | 
			
		||||
                           std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                           ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp8
 | 
			
		||||
  if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scales, 1 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
 | 
			
		||||
  if (recipe_a[0] != expected_recipe_a) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
  if (recipe_b[0] != expected_recipe_b) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp8
 | 
			
		||||
 * A, B must have 1 scale each, {Blockwise_1x32, e8m0}
 | 
			
		||||
 */
 | 
			
		||||
bool check_mxfp8_recipe(c10::ScalarType type_a,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                        c10::ScalarType type_b,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp8
 | 
			
		||||
  if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scales, 1 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x32, e8m0} for A & B
 | 
			
		||||
  if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp4
 | 
			
		||||
 * A, B must have 1 scale each, {Blockwise_1x32, e8m0}
 | 
			
		||||
 */
 | 
			
		||||
bool check_mxfp4_recipe(c10::ScalarType type_a,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                        c10::ScalarType type_b,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp4
 | 
			
		||||
  if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scales, 1 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x32, e8m0} for A & B
 | 
			
		||||
  if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace at::native::cuda::blas::scaled
 | 
			
		||||
							
								
								
									
										174
									
								
								aten/src/ATen/cuda/CUDAScaledBlas.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										174
									
								
								aten/src/ATen/cuda/CUDAScaledBlas.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,174 @@
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <c10/util/typeid.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/SmallVector.h>
 | 
			
		||||
#include <c10/core/Scalar.h>
 | 
			
		||||
#include <c10/core/ScalarType.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 | 
			
		||||
#include <ATen/core/Tensor.h>
 | 
			
		||||
#include <ATen/core/NamedTensor.h>
 | 
			
		||||
#include <ATen/Dispatch.h>
 | 
			
		||||
#include <ATen/ExpandUtils.h>
 | 
			
		||||
#include <ATen/OpMathType.h>
 | 
			
		||||
#include <ATen/TensorUtils.h>
 | 
			
		||||
#include <ATen/cuda/CUDABlas.h>
 | 
			
		||||
#include <ATen/cuda/tunable/Tunable.h>
 | 
			
		||||
#include <ATen/cuda/tunable/TunableGemm.h>
 | 
			
		||||
#include <ATen/native/Resize.h>
 | 
			
		||||
#include <c10/util/MaybeOwned.h>
 | 
			
		||||
#include <ATen/native/GroupedMMUtils.h>
 | 
			
		||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
 | 
			
		||||
#include <ATen/native/cuda/ScaledGroupMM.h>
 | 
			
		||||
#include <ATen/native/cuda/GroupMM.h>
 | 
			
		||||
#include <ATen/ceil_div.h>
 | 
			
		||||
 | 
			
		||||
#ifdef USE_FBGEMM_GENAI
 | 
			
		||||
#include <fbgemm_gpu/torch_ops.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
#include <ATen/NativeFunctions.h>
 | 
			
		||||
#else
 | 
			
		||||
#include <ATen/ops/_addmm_activation_native.h>
 | 
			
		||||
#include <ATen/ops/_efficientzerotensor.h>
 | 
			
		||||
#include <ATen/ops/_scaled_mm_native.h>
 | 
			
		||||
#include <ATen/ops/_unsafe_view_native.h>
 | 
			
		||||
#include <ATen/ops/abs.h>
 | 
			
		||||
#include <ATen/ops/addmm_native.h>
 | 
			
		||||
#include <ATen/ops/addmv_native.h>
 | 
			
		||||
#include <ATen/ops/baddbmm_native.h>
 | 
			
		||||
#include <ATen/ops/bmm_native.h>
 | 
			
		||||
#include <ATen/ops/copy_native.h>
 | 
			
		||||
#include <ATen/ops/dot_native.h>
 | 
			
		||||
#include <ATen/ops/empty.h>
 | 
			
		||||
#include <ATen/ops/empty_strided.h>
 | 
			
		||||
#include <ATen/ops/gelu.h>
 | 
			
		||||
#include <ATen/ops/max.h>
 | 
			
		||||
#include <ATen/ops/mm_native.h>
 | 
			
		||||
#include <ATen/ops/mul.h>
 | 
			
		||||
#include <ATen/ops/relu.h>
 | 
			
		||||
#include <ATen/ops/ones.h>
 | 
			
		||||
#include <ATen/ops/scalar_tensor_native.h>
 | 
			
		||||
#include <ATen/ops/vdot_native.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
using at::blas::ScalingType;
 | 
			
		||||
using at::blas::SwizzleType;
 | 
			
		||||
 | 
			
		||||
namespace at::cuda::scaled {
 | 
			
		||||
 | 
			
		||||
static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    static const std::vector<std::string> archs = {
 | 
			
		||||
        "gfx942",
 | 
			
		||||
#if ROCM_VERSION >= 60300
 | 
			
		||||
        "gfx1200", "gfx1201",
 | 
			
		||||
#endif
 | 
			
		||||
#if ROCM_VERSION >= 60500
 | 
			
		||||
        "gfx950"
 | 
			
		||||
#endif
 | 
			
		||||
    };
 | 
			
		||||
    return at::detail::getCUDAHooks().isGPUArch(archs);
 | 
			
		||||
#else
 | 
			
		||||
    auto dprops = at::cuda::getCurrentDeviceProperties();
 | 
			
		||||
 | 
			
		||||
    if (sm90_only || sm100_only) {
 | 
			
		||||
      return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
 | 
			
		||||
    } else {
 | 
			
		||||
      return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
static bool _scaled_mm_is_fnuz() {
 | 
			
		||||
    return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
/**
 | 
			
		||||
 * Track concrete implementations available
 | 
			
		||||
 */
 | 
			
		||||
enum class ScaledGemmImplementation {
 | 
			
		||||
  NONE = 0,
 | 
			
		||||
  TENSORWISE_TENSORWISE = 1,
 | 
			
		||||
  ROWWISE_ROWWISE = 2,
 | 
			
		||||
  BLOCK_128x128_1x128 = 3,
 | 
			
		||||
  BLOCK_1x128_128x128 = 4,
 | 
			
		||||
  BLOCK_1x128_1x128 = 5,
 | 
			
		||||
  MXFP8_MXFP8 = 6,
 | 
			
		||||
  NVFP4_NVFP4 = 7,
 | 
			
		||||
  NVFP4_NVFP4_SINGLE_SCALE = 8,
 | 
			
		||||
  MXFP4_MXFP4 = 9,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Convert passed int (enum) from python back into a
 | 
			
		||||
 * strictly-typed enum
 | 
			
		||||
 */
 | 
			
		||||
template <class EnumType, class ArrayType>
 | 
			
		||||
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
 | 
			
		||||
  std::vector<EnumType> converted;
 | 
			
		||||
  converted.reserve(v.size());
 | 
			
		||||
 | 
			
		||||
  for (auto vi : v) {
 | 
			
		||||
    converted.push_back(static_cast<EnumType>(vi));
 | 
			
		||||
  }
 | 
			
		||||
  return converted;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool check_tensorwise_recipe(c10::ScalarType,
 | 
			
		||||
                             std::vector<ScalingType>&,
 | 
			
		||||
                             ArrayRef<Tensor>&,
 | 
			
		||||
                             c10::ScalarType,
 | 
			
		||||
                             std::vector<ScalingType>&,
 | 
			
		||||
                             ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
bool check_rowwise_recipe(c10::ScalarType,
 | 
			
		||||
                             std::vector<ScalingType>&,
 | 
			
		||||
                             ArrayRef<Tensor>&,
 | 
			
		||||
                             c10::ScalarType,
 | 
			
		||||
                             std::vector<ScalingType>&,
 | 
			
		||||
                             ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_nvfp4_recipe(c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&,
 | 
			
		||||
                        c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_nvfp4_recipe_single_scale
 | 
			
		||||
                       (c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&,
 | 
			
		||||
                        c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_deepseek_recipe(ScalingType,
 | 
			
		||||
                           ScalingType,
 | 
			
		||||
                           c10::ScalarType,
 | 
			
		||||
                           std::vector<ScalingType>&,
 | 
			
		||||
                           ArrayRef<Tensor>&,
 | 
			
		||||
                           c10::ScalarType,
 | 
			
		||||
                           std::vector<ScalingType>&,
 | 
			
		||||
                           ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_mxfp8_recipe(c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&,
 | 
			
		||||
                        c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_mxfp4_recipe(c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&,
 | 
			
		||||
                        c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
} // namespace at::native::cuda::blas::scaled
 | 
			
		||||
@ -70,11 +70,7 @@
 | 
			
		||||
#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
 | 
			
		||||
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
namespace at_cuda_detail {
 | 
			
		||||
#endif
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
 | 
			
		||||
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
 | 
			
		||||
 | 
			
		||||
@ -96,10 +92,6 @@ template <>
 | 
			
		||||
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
 | 
			
		||||
       ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
 | 
			
		||||
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
} // namespace at_cuda_detail
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
@ -121,7 +113,7 @@ struct cuda_type<c10::Half> {
 | 
			
		||||
  using type = __half;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
struct cuda_type<c10::BFloat16> {
 | 
			
		||||
@ -203,36 +195,6 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera
 | 
			
		||||
  *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
 | 
			
		||||
template<typename ValueT, typename InputIteratorT>
 | 
			
		||||
struct chained_iterator {
 | 
			
		||||
  using iterator_category = std::random_access_iterator_tag;
 | 
			
		||||
  using difference_type   = std::ptrdiff_t;
 | 
			
		||||
  using value_type        = ValueT;
 | 
			
		||||
  using pointer           = ValueT*;
 | 
			
		||||
  using reference         = ValueT&;
 | 
			
		||||
 | 
			
		||||
  InputIteratorT iter;
 | 
			
		||||
  ValueT *first;
 | 
			
		||||
  difference_type offset = 0;
 | 
			
		||||
 | 
			
		||||
  __device__ ValueT operator[](difference_type i) {
 | 
			
		||||
    i +=  offset;
 | 
			
		||||
    if (i == 0) {
 | 
			
		||||
      return *first;
 | 
			
		||||
    } else {
 | 
			
		||||
      return ValueT(iter[i - 1]);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  __device__ chained_iterator operator+(difference_type i) {
 | 
			
		||||
    return chained_iterator{iter, first, i};
 | 
			
		||||
  }
 | 
			
		||||
  __device__ ValueT operator*() {
 | 
			
		||||
    return (*this)[0];
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
 | 
			
		||||
// so split at int_max/2
 | 
			
		||||
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
 | 
			
		||||
@ -277,25 +239,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 | 
			
		||||
        first_elem_ptr,
 | 
			
		||||
        scan_op);
 | 
			
		||||
    C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
 | 
			
		||||
    using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
 | 
			
		||||
    using tuple = typename ArgIndexInputIterator::value_type;
 | 
			
		||||
    auto input_iter_transform = [=] __device__ (const tuple &x)->input_t  {
 | 
			
		||||
      if (x.key == 0) {
 | 
			
		||||
        return *first_elem_ptr;
 | 
			
		||||
      } else {
 | 
			
		||||
        return x.value;
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
    auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)(
 | 
			
		||||
      ArgIndexInputIterator(input + i), input_iter_transform);
 | 
			
		||||
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
 | 
			
		||||
        input_,
 | 
			
		||||
        output + i,
 | 
			
		||||
        scan_op,
 | 
			
		||||
        size_cub,
 | 
			
		||||
        at::cuda::getCurrentCUDAStream());
 | 
			
		||||
#else
 | 
			
		||||
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
 | 
			
		||||
        input + i + 1,
 | 
			
		||||
        output + i,
 | 
			
		||||
@ -303,7 +246,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 | 
			
		||||
        ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
 | 
			
		||||
        size_cub,
 | 
			
		||||
        at::cuda::getCurrentCUDAStream());
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
@ -555,16 +497,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 | 
			
		||||
        first_elem_ptr,
 | 
			
		||||
        scan_op);
 | 
			
		||||
    C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
 | 
			
		||||
    auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
 | 
			
		||||
      input + i, first_elem_ptr};
 | 
			
		||||
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
 | 
			
		||||
        input_,
 | 
			
		||||
        output + i,
 | 
			
		||||
        scan_op,
 | 
			
		||||
        size_cub,
 | 
			
		||||
        at::cuda::getCurrentCUDAStream());
 | 
			
		||||
#else
 | 
			
		||||
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
 | 
			
		||||
        input + i,
 | 
			
		||||
        output + i,
 | 
			
		||||
@ -572,7 +504,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 | 
			
		||||
        ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
 | 
			
		||||
        size_cub,
 | 
			
		||||
        at::cuda::getCurrentCUDAStream());
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -10,14 +10,6 @@
 | 
			
		||||
#define CUB_VERSION 200001
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
 | 
			
		||||
// https://github.com/NVIDIA/cub/pull/306
 | 
			
		||||
#if CUB_VERSION >= 101300
 | 
			
		||||
#define CUB_SUPPORTS_NV_BFLOAT16() true
 | 
			
		||||
#else
 | 
			
		||||
#define CUB_SUPPORTS_NV_BFLOAT16() false
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
 | 
			
		||||
// https://github.com/NVIDIA/cub/pull/326
 | 
			
		||||
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
 | 
			
		||||
@ -28,14 +20,6 @@
 | 
			
		||||
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// cub support for cub::FutureValue is added to cub 1.15 in:
 | 
			
		||||
// https://github.com/NVIDIA/cub/pull/305
 | 
			
		||||
#if CUB_VERSION >= 101500
 | 
			
		||||
#define CUB_SUPPORTS_FUTURE_VALUE() true
 | 
			
		||||
#else
 | 
			
		||||
#define CUB_SUPPORTS_FUTURE_VALUE() false
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// There were many bc-breaking changes in major version release of CCCL v3.0.0
 | 
			
		||||
// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
 | 
			
		||||
#if CUB_VERSION >= 200800
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										23
									
								
								aten/src/ATen/detail/XLAHooksInterface.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								aten/src/ATen/detail/XLAHooksInterface.cpp
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,23 @@
 | 
			
		||||
#include <ATen/detail/XLAHooksInterface.h>
 | 
			
		||||
 | 
			
		||||
namespace at {
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
const XLAHooksInterface& getXLAHooks() {
 | 
			
		||||
  auto create_impl = [] {
 | 
			
		||||
    // Create XLA hooks using the registry
 | 
			
		||||
    auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{});
 | 
			
		||||
    if (hooks) {
 | 
			
		||||
      return hooks;
 | 
			
		||||
    }
 | 
			
		||||
    // If hooks creation fails, fall back to default implementation
 | 
			
		||||
    return std::make_unique<XLAHooksInterface>();
 | 
			
		||||
  };
 | 
			
		||||
  static auto hooks = create_impl();
 | 
			
		||||
  return *hooks;
 | 
			
		||||
}
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs)
 | 
			
		||||
 | 
			
		||||
} // namespace at
 | 
			
		||||
							
								
								
									
										79
									
								
								aten/src/ATen/detail/XLAHooksInterface.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								aten/src/ATen/detail/XLAHooksInterface.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,79 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <c10/core/Device.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/Registry.h>
 | 
			
		||||
 | 
			
		||||
#include <ATen/detail/AcceleratorHooksInterface.h>
 | 
			
		||||
 | 
			
		||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
 | 
			
		||||
 | 
			
		||||
namespace at {
 | 
			
		||||
 | 
			
		||||
constexpr const char* XLA_HELP =
 | 
			
		||||
  "This error has occurred because you are trying "
 | 
			
		||||
  "to use some XLA functionality, but the XLA library has not been "
 | 
			
		||||
  "loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
 | 
			
		||||
 | 
			
		||||
struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
 | 
			
		||||
  ~XLAHooksInterface() override = default;
 | 
			
		||||
 | 
			
		||||
  void init() const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual bool hasXLA() const {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual std::string showConfig() const {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        false,
 | 
			
		||||
        "Cannot query detailed XLA version without torch_xla library. ",
 | 
			
		||||
        XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const Generator& getDefaultGenerator(
 | 
			
		||||
      [[maybe_unused]] DeviceIndex device_index = -1) const override {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Generator getNewGenerator(
 | 
			
		||||
      [[maybe_unused]] DeviceIndex device_index = -1) const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual DeviceIndex getCurrentDevice() const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Device getDeviceFromPtr(void* /*data*/) const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Allocator* getPinnedMemoryAllocator() const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool isPinnedPtr(const void* data) const override {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool hasPrimaryContext(DeviceIndex device_index) const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct TORCH_API XLAHooksArgs {};
 | 
			
		||||
 | 
			
		||||
TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
 | 
			
		||||
#define REGISTER_XLA_HOOKS(clsname) \
 | 
			
		||||
  C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
TORCH_API const XLAHooksInterface& getXLAHooks();
 | 
			
		||||
} // namespace detail
 | 
			
		||||
} // namespace at
 | 
			
		||||
C10_DIAGNOSTIC_POP()
 | 
			
		||||
@ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
 | 
			
		||||
              "pixel_shuffle expects a positive upscale_factor, but got ",
 | 
			
		||||
              upscale_factor);
 | 
			
		||||
  int64_t c = self.size(-3);
 | 
			
		||||
  TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
 | 
			
		||||
        "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
 | 
			
		||||
  int64_t upscale_factor_squared = upscale_factor * upscale_factor;
 | 
			
		||||
  TORCH_CHECK(c % upscale_factor_squared == 0,
 | 
			
		||||
              "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
 | 
			
		||||
 | 
			
		||||
@ -259,11 +259,20 @@ inline void winograd_f2k3_input_transform_inplace__rvv(
 | 
			
		||||
  const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
 | 
			
		||||
  const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
 | 
			
		||||
  const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);
 | 
			
		||||
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
 | 
			
		||||
  /* GCC 14.2 (RISC-V RVV) ICE workaround:
 | 
			
		||||
   * Avoid single-statement read-modify-write on MEM_REF like:
 | 
			
		||||
   *   *input_tile_val =
 | 
			
		||||
   *     __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
 | 
			
		||||
   * This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin)
 | 
			
		||||
   * with -march=rv64gcv. Use a temporary then write back.
 | 
			
		||||
   * Do NOT refactor into the single-statement form. Clang is unaffected.
 | 
			
		||||
   */
 | 
			
		||||
  vfloat32m1x4_t tmp_input_tile_val = *input_tile_val;
 | 
			
		||||
  tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0);
 | 
			
		||||
  tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1);
 | 
			
		||||
  tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2);
 | 
			
		||||
  tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3);
 | 
			
		||||
  *input_tile_val = tmp_input_tile_val;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline void winograd_f2k3_output_transform_inplace__rvv(
 | 
			
		||||
@ -277,9 +286,15 @@ inline void winograd_f2k3_output_transform_inplace__rvv(
 | 
			
		||||
  const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
 | 
			
		||||
  const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
 | 
			
		||||
  const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);
 | 
			
		||||
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
 | 
			
		||||
  /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
 | 
			
		||||
   * Keep the temporary + write-back pattern to avoid ICE.
 | 
			
		||||
   * Do NOT rewrite into:
 | 
			
		||||
   *   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
 | 
			
		||||
   */
 | 
			
		||||
  vfloat32m1x4_t tmp_output_tile_val = *input_tile_val;
 | 
			
		||||
  tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0);
 | 
			
		||||
  tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1);
 | 
			
		||||
  *input_tile_val = tmp_output_tile_val;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline vfloat32m1_t
 | 
			
		||||
@ -300,11 +315,17 @@ inline void winograd_f2k3_kernel_transform__rvv(
 | 
			
		||||
  const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
 | 
			
		||||
  const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
 | 
			
		||||
  vfloat32m1_t half_g0_plus_g2 =  __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);
 | 
			
		||||
 | 
			
		||||
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
 | 
			
		||||
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
 | 
			
		||||
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
 | 
			
		||||
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
 | 
			
		||||
  /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
 | 
			
		||||
   * Keep the temporary + write-back pattern to avoid ICE.
 | 
			
		||||
   * Do NOT rewrite into:
 | 
			
		||||
   *   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val);
 | 
			
		||||
   */
 | 
			
		||||
  vfloat32m1x4_t tmp_transform = *transform;
 | 
			
		||||
  tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0);
 | 
			
		||||
  tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
 | 
			
		||||
  tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
 | 
			
		||||
  tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2);
 | 
			
		||||
  *transform = tmp_transform;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {
 | 
			
		||||
 | 
			
		||||
@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
 | 
			
		||||
  } else if (dtype == ScalarType::Half) {
 | 
			
		||||
    [&]() {
 | 
			
		||||
      using scalar_t =
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
 | 
			
		||||
          c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
 | 
			
		||||
      const auto exp = exp_scalar.to<scalar_t>();
 | 
			
		||||
      using Vec = Vectorized<scalar_t>;
 | 
			
		||||
      cpu_kernel_vec(iter,
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
 | 
			
		||||
      out_calc_t output_offset_calculator,
 | 
			
		||||
      loader_t loader,
 | 
			
		||||
      storer_t storer) {
 | 
			
		||||
    if (ret_t == rt_binary_specializations[arg_index][0] &&
 | 
			
		||||
        arg0_t == rt_binary_specializations[arg_index][1] &&
 | 
			
		||||
        arg1_t == rt_binary_specializations[arg_index][2])
 | 
			
		||||
    constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
 | 
			
		||||
    constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
 | 
			
		||||
    constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
 | 
			
		||||
    if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
 | 
			
		||||
      using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
 | 
			
		||||
      using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
 | 
			
		||||
      using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
 | 
			
		||||
      launch_vectorized_templated_kernel<
 | 
			
		||||
          func_t,
 | 
			
		||||
          array_t,
 | 
			
		||||
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
 | 
			
		||||
          out_calc_t,
 | 
			
		||||
          loader_t,
 | 
			
		||||
          storer_t,
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<
 | 
			
		||||
                   rt_binary_specializations[arg_index][0]>::t),
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<
 | 
			
		||||
                   rt_binary_specializations[arg_index][1]>::t),
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<
 | 
			
		||||
                   rt_binary_specializations[arg_index][2]>::t)>(
 | 
			
		||||
          cret_t,
 | 
			
		||||
          carg0_t,
 | 
			
		||||
          carg1_t>(
 | 
			
		||||
          numel,
 | 
			
		||||
          f,
 | 
			
		||||
          data,
 | 
			
		||||
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
 | 
			
		||||
          output_offset_calculator,
 | 
			
		||||
          loader,
 | 
			
		||||
          storer);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										574
									
								
								aten/src/ATen/native/cuda/GroupedBlas.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										574
									
								
								aten/src/ATen/native/cuda/GroupedBlas.cpp
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,574 @@
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <c10/util/typeid.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/SmallVector.h>
 | 
			
		||||
#include <c10/core/Scalar.h>
 | 
			
		||||
#include <c10/core/ScalarType.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 | 
			
		||||
#include <ATen/core/Tensor.h>
 | 
			
		||||
#include <ATen/core/NamedTensor.h>
 | 
			
		||||
#include <ATen/Dispatch.h>
 | 
			
		||||
#include <ATen/ExpandUtils.h>
 | 
			
		||||
#include <ATen/OpMathType.h>
 | 
			
		||||
#include <ATen/TensorUtils.h>
 | 
			
		||||
#include <ATen/cuda/CUDABlas.h>
 | 
			
		||||
#include <ATen/cuda/CUDAScaledBlas.h>
 | 
			
		||||
#include <ATen/cuda/tunable/Tunable.h>
 | 
			
		||||
#include <ATen/cuda/tunable/TunableGemm.h>
 | 
			
		||||
#include <ATen/native/Resize.h>
 | 
			
		||||
#include <c10/util/MaybeOwned.h>
 | 
			
		||||
#include <ATen/native/GroupedMMUtils.h>
 | 
			
		||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
 | 
			
		||||
#include <ATen/native/cuda/ScaledGroupMM.h>
 | 
			
		||||
#include <ATen/native/cuda/GroupMM.h>
 | 
			
		||||
#include <ATen/ceil_div.h>
 | 
			
		||||
 | 
			
		||||
#ifdef USE_FBGEMM_GENAI
 | 
			
		||||
#include <fbgemm_gpu/torch_ops.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
#include <ATen/NativeFunctions.h>
 | 
			
		||||
#else
 | 
			
		||||
#include <ATen/ops/_addmm_activation_native.h>
 | 
			
		||||
#include <ATen/ops/_efficientzerotensor.h>
 | 
			
		||||
#include <ATen/ops/_scaled_mm_native.h>
 | 
			
		||||
#include <ATen/ops/_unsafe_view_native.h>
 | 
			
		||||
#include <ATen/ops/abs.h>
 | 
			
		||||
#include <ATen/ops/addmm_native.h>
 | 
			
		||||
#include <ATen/ops/addmv_native.h>
 | 
			
		||||
#include <ATen/ops/baddbmm_native.h>
 | 
			
		||||
#include <ATen/ops/bmm_native.h>
 | 
			
		||||
#include <ATen/ops/copy_native.h>
 | 
			
		||||
#include <ATen/ops/dot_native.h>
 | 
			
		||||
#include <ATen/ops/empty.h>
 | 
			
		||||
#include <ATen/ops/empty_strided.h>
 | 
			
		||||
#include <ATen/ops/gelu.h>
 | 
			
		||||
#include <ATen/ops/max.h>
 | 
			
		||||
#include <ATen/ops/mm_native.h>
 | 
			
		||||
#include <ATen/ops/mul.h>
 | 
			
		||||
#include <ATen/ops/relu.h>
 | 
			
		||||
#include <ATen/ops/ones.h>
 | 
			
		||||
#include <ATen/ops/scalar_tensor_native.h>
 | 
			
		||||
#include <ATen/ops/vdot_native.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
using at::blas::ScalingType;
 | 
			
		||||
using at::blas::SwizzleType;
 | 
			
		||||
 | 
			
		||||
namespace scaled_blas = at::cuda::scaled;
 | 
			
		||||
using scaled_blas::ScaledGemmImplementation;
 | 
			
		||||
using scaled_blas::convert_int_to_enum;
 | 
			
		||||
using scaled_blas::_scaled_mm_allowed_device;
 | 
			
		||||
 | 
			
		||||
namespace at::native {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
// 2d-2d and 2d-3d
 | 
			
		||||
// scaling=MXFP8
 | 
			
		||||
// CUDA-only
 | 
			
		||||
Tensor&
 | 
			
		||||
_mx8_mx8_bf16_grouped_mm_fbgemm(
 | 
			
		||||
        const Tensor& mat_a,
 | 
			
		||||
        const Tensor& mat_b,
 | 
			
		||||
        const Tensor& scale_a,
 | 
			
		||||
        const SwizzleType& swizzle_a,
 | 
			
		||||
        const Tensor& scale_b,
 | 
			
		||||
        const SwizzleType& swizzle_b,
 | 
			
		||||
        const std::optional<at::Tensor>& offs,
 | 
			
		||||
        Tensor& out) {
 | 
			
		||||
    const bool a_is_2d = mat_a.dim() == 2;
 | 
			
		||||
    const bool b_is_2d = mat_b.dim() == 2;
 | 
			
		||||
    bool b_is_3d = mat_b.dim() == 3;
 | 
			
		||||
    bool is_2d_2d = a_is_2d && b_is_2d;
 | 
			
		||||
    bool is_2d_3d = a_is_2d && b_is_3d;
 | 
			
		||||
    TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
 | 
			
		||||
    TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
 | 
			
		||||
    TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
 | 
			
		||||
    // MXFP8 expects float8_e8m0fnu scales.
 | 
			
		||||
    TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
 | 
			
		||||
        "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
 | 
			
		||||
        "For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
 | 
			
		||||
        "For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
 | 
			
		||||
    fbgemm_gpu::mx8mx8bf16_grouped_mm(
 | 
			
		||||
        mat_a,
 | 
			
		||||
        mat_b,
 | 
			
		||||
        scale_a,
 | 
			
		||||
        scale_b,
 | 
			
		||||
        offs.value(),
 | 
			
		||||
        out);
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI");
 | 
			
		||||
#endif
 | 
			
		||||
    return out;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 2d-2d and 2d-3d cases
 | 
			
		||||
// scaling=rowwise
 | 
			
		||||
// CUDA-only
 | 
			
		||||
Tensor&
 | 
			
		||||
_f8_f8_bf16_rowwise_grouped_mm_cuda(
 | 
			
		||||
          const Tensor& mat_a,
 | 
			
		||||
          const Tensor& mat_b,
 | 
			
		||||
          const Tensor& scale_a,
 | 
			
		||||
          const Tensor& scale_b,
 | 
			
		||||
          const std::optional<Tensor>& offs,
 | 
			
		||||
          const std::optional<Tensor>& bias,
 | 
			
		||||
          const bool use_fast_accum,
 | 
			
		||||
          Tensor& out) {
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
 | 
			
		||||
 | 
			
		||||
  at::cuda::detail::f8f8bf16_grouped_mm(
 | 
			
		||||
      mat_a,
 | 
			
		||||
      mat_b,
 | 
			
		||||
      scale_a,
 | 
			
		||||
      scale_b,
 | 
			
		||||
      offs,
 | 
			
		||||
      bias,
 | 
			
		||||
      use_fast_accum,
 | 
			
		||||
      out);
 | 
			
		||||
    return out;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 2d-2d and 2d-3d cases
 | 
			
		||||
// scaling=rowwise
 | 
			
		||||
// only being called for rocm
 | 
			
		||||
Tensor&
 | 
			
		||||
_f8_f8_bf16_rowwise_grouped_mm_rocm(
 | 
			
		||||
      const Tensor& mat_a,
 | 
			
		||||
      const Tensor& mat_b,
 | 
			
		||||
      const Tensor& scale_a,
 | 
			
		||||
      const Tensor& scale_b,
 | 
			
		||||
      const std::optional<Tensor>& offs,
 | 
			
		||||
      Tensor& out) {
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
 | 
			
		||||
 | 
			
		||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM)
 | 
			
		||||
  fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
 | 
			
		||||
      mat_a,
 | 
			
		||||
      // FBGEMM expects B matrix shape to be (.., N, K)
 | 
			
		||||
      mat_b.transpose(-2, -1),
 | 
			
		||||
      scale_a,
 | 
			
		||||
      scale_b,
 | 
			
		||||
      offs,
 | 
			
		||||
      out);
 | 
			
		||||
#else
 | 
			
		||||
  TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
 | 
			
		||||
#endif
 | 
			
		||||
  return out;
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda
 | 
			
		||||
Tensor&
 | 
			
		||||
_f8_f8_bf16_rowwise_grouped_mm(
 | 
			
		||||
      const Tensor& mat_a,
 | 
			
		||||
      const Tensor& mat_b,
 | 
			
		||||
      const Tensor& scale_a,
 | 
			
		||||
      const Tensor& scale_b,
 | 
			
		||||
      const std::optional<Tensor>& offs,
 | 
			
		||||
      const std::optional<Tensor>& bias,
 | 
			
		||||
      bool use_fast_accum,
 | 
			
		||||
      Tensor& out) {
 | 
			
		||||
  // FP8 per-tensor and per-row scaling expect fp32 scales.
 | 
			
		||||
  TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
 | 
			
		||||
      "For grouped FP8 rowwise, both scales must be float32 tensors");
 | 
			
		||||
#ifndef USE_ROCM
 | 
			
		||||
  return _f8_f8_bf16_rowwise_grouped_mm_cuda(
 | 
			
		||||
      mat_a,
 | 
			
		||||
      mat_b,
 | 
			
		||||
      scale_a,
 | 
			
		||||
      scale_b,
 | 
			
		||||
      offs,
 | 
			
		||||
      bias,
 | 
			
		||||
      use_fast_accum,
 | 
			
		||||
      out);
 | 
			
		||||
#else
 | 
			
		||||
  // NOTE: ignore use_fast_accum
 | 
			
		||||
  TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias")
 | 
			
		||||
  return _f8_f8_bf16_rowwise_grouped_mm_rocm(
 | 
			
		||||
      mat_a,
 | 
			
		||||
      mat_b,
 | 
			
		||||
      scale_a,
 | 
			
		||||
      scale_b,
 | 
			
		||||
      offs,
 | 
			
		||||
      out);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
 | 
			
		||||
  // Checks scales for 2d or 3d target tensors (`mat`).
 | 
			
		||||
  if (mat.dim() == 2) {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        scale.dim() == 1,
 | 
			
		||||
        "scale must be a 1D tensor, but got ",
 | 
			
		||||
        scale.dim(),
 | 
			
		||||
        "D, arg ",
 | 
			
		||||
        arg_idx);
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        scale.size(0) == mat.size(dim) * scale_multiplier,
 | 
			
		||||
        "scale must have the same length as mat for arg ",
 | 
			
		||||
        arg_idx);
 | 
			
		||||
  } else {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        scale.dim() == 2,
 | 
			
		||||
        "scale must be a 2D tensor, but got ",
 | 
			
		||||
        scale.dim(),
 | 
			
		||||
        "D for arg ",
 | 
			
		||||
        arg_idx);
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        scale.stride(1) == 1,
 | 
			
		||||
        "scale must be contiguous in the last dimension for arg ",
 | 
			
		||||
        arg_idx);
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        scale.size(0) == mat.size(0),
 | 
			
		||||
        "scale must have the same batch dimension as mat for arg ",
 | 
			
		||||
        arg_idx);
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        scale.size(1) == mat.size(1 + dim),
 | 
			
		||||
        "scale must have the same first dimension as mat for arg ",
 | 
			
		||||
        arg_idx);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
 | 
			
		||||
  // Checks scales for 2d or 3d target tensors (`mat`).
 | 
			
		||||
  if (mat.dim() == 2) {
 | 
			
		||||
    // For MXFP8, 2d tensors have variable size groups represented as subtensors,
 | 
			
		||||
    // that are converted to blocked padded format individually,
 | 
			
		||||
    // so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
      scale.dim() == mat.dim(),
 | 
			
		||||
      "for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
 | 
			
		||||
 | 
			
		||||
    // LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
 | 
			
		||||
    // RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
 | 
			
		||||
    //   * weight is transposed prior to the call, scale stays non-transposed.
 | 
			
		||||
    bool LHS = arg_idx == 0;
 | 
			
		||||
    int scale_dim_to_check = 0;
 | 
			
		||||
    int mat_dim_to_check = LHS ? 0 : 1;
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
 | 
			
		||||
        "for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
 | 
			
		||||
        "must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
 | 
			
		||||
  } else {
 | 
			
		||||
    // For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
 | 
			
		||||
    // so we can check the exact expected scale sizes here without a d2h sync.
 | 
			
		||||
    auto round_up = [](auto x, auto y) {
 | 
			
		||||
        return ((x + y - 1) / y) * y;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // TODO: this is for 3d tensor in 2d-3d case specifically.
 | 
			
		||||
    // We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
 | 
			
		||||
    int64_t G = mat.size(0);
 | 
			
		||||
    int64_t K = mat.size(1);
 | 
			
		||||
    int64_t N = mat.size(2);
 | 
			
		||||
    int64_t blocked_scale_K = round_up(K/32, 4);
 | 
			
		||||
    int64_t blocked_scale_N = round_up(N, 128);
 | 
			
		||||
 | 
			
		||||
    // fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
      scale.dim() == mat.dim() - 1,
 | 
			
		||||
      "for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
 | 
			
		||||
    );
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
      scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
 | 
			
		||||
      "for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
 | 
			
		||||
    );
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
 | 
			
		||||
  bool using_fp8_rowwise = scale.scalar_type() == kFloat;
 | 
			
		||||
  bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
 | 
			
		||||
  if (using_fp8_rowwise) {
 | 
			
		||||
    _check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
 | 
			
		||||
  } else if (using_mxfp8) {
 | 
			
		||||
    _check_scales_mxfp8(mat, scale, dim, arg_idx);
 | 
			
		||||
  } else {
 | 
			
		||||
    TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
Tensor
 | 
			
		||||
_scaled_grouped_mm_cuda(
 | 
			
		||||
        const Tensor& mat_a,
 | 
			
		||||
        const Tensor& mat_b,
 | 
			
		||||
        const Tensor& scale_a,
 | 
			
		||||
        const Tensor& scale_b,
 | 
			
		||||
        const std::optional<at::Tensor>& offs,
 | 
			
		||||
        const std::optional<at::Tensor>& bias,
 | 
			
		||||
        const std::optional<at::Tensor>& scale_result,
 | 
			
		||||
        std::optional<c10::ScalarType> out_dtype,
 | 
			
		||||
        bool use_fast_accum) {
 | 
			
		||||
  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
 | 
			
		||||
  TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
 | 
			
		||||
  TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
 | 
			
		||||
  const bool a_is_2d = mat_a.dim() == 2;
 | 
			
		||||
  const bool b_is_2d = mat_b.dim() == 2;
 | 
			
		||||
 | 
			
		||||
  // NOTE(slayton): For sub-1B formats want contraction_dim argument?
 | 
			
		||||
  if (!a_is_2d || !b_is_2d) {
 | 
			
		||||
    TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK_VALUE(
 | 
			
		||||
    mat_a.size(-1) % 16 == 0,
 | 
			
		||||
    "Expected trailing dimension of mat_a to be divisible by 16 ",
 | 
			
		||||
    "but got mat1 shape: (",
 | 
			
		||||
    mat_a.sizes(),
 | 
			
		||||
    ").");
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
 | 
			
		||||
    "Expected mat_b shape to be divisible by 16 ",
 | 
			
		||||
    "but got mat_b shape: (",
 | 
			
		||||
    mat_b.sizes(),
 | 
			
		||||
    ").");
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
 | 
			
		||||
  TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet");
 | 
			
		||||
  TORCH_CHECK_VALUE(offs.has_value() ==  (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
 | 
			
		||||
 | 
			
		||||
  // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
 | 
			
		||||
  //       for rowwise, no offsets implies 3d-3d and is handled by lower-level
 | 
			
		||||
  //       routines
 | 
			
		||||
  if (offs.has_value()) {
 | 
			
		||||
    TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
 | 
			
		||||
    TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
 | 
			
		||||
  }
 | 
			
		||||
  // FP8 per-tensor and per-row scaling expect fp32 scales.
 | 
			
		||||
  // MXFP8 expects float8_e8m0fnu scales.
 | 
			
		||||
  TORCH_CHECK_VALUE(
 | 
			
		||||
      (scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
 | 
			
		||||
      (scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
 | 
			
		||||
      "For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
 | 
			
		||||
 | 
			
		||||
  const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
 | 
			
		||||
  check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
 | 
			
		||||
  check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
 | 
			
		||||
 | 
			
		||||
  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
 | 
			
		||||
  TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
 | 
			
		||||
 | 
			
		||||
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
 | 
			
		||||
 | 
			
		||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
 | 
			
		||||
  // MXFP8 grouped GEMM dispatching
 | 
			
		||||
  bool is_mx8mx8bf16 = (
 | 
			
		||||
    mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
 | 
			
		||||
    scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
 | 
			
		||||
  );
 | 
			
		||||
#else
 | 
			
		||||
  bool is_mx8mx8bf16 = false;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  if (is_mx8mx8bf16) {
 | 
			
		||||
    // Note: Passing implied SwizzleType here, correctness of scale previously checked
 | 
			
		||||
    //       in `check_scale` call
 | 
			
		||||
    return _mx8_mx8_bf16_grouped_mm_fbgemm(
 | 
			
		||||
        mat_a,
 | 
			
		||||
        mat_b,
 | 
			
		||||
        scale_a,
 | 
			
		||||
        SwizzleType::SWIZZLE_32_4_4,
 | 
			
		||||
        scale_b,
 | 
			
		||||
        SwizzleType::SWIZZLE_32_4_4,
 | 
			
		||||
        offs.value(),
 | 
			
		||||
        out);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // If we're not MXFP8, then we're row-wise scaling.
 | 
			
		||||
  return _f8_f8_bf16_rowwise_grouped_mm(
 | 
			
		||||
      mat_a,
 | 
			
		||||
      mat_b,
 | 
			
		||||
      scale_a,
 | 
			
		||||
      scale_b,
 | 
			
		||||
      offs,
 | 
			
		||||
      bias,
 | 
			
		||||
      use_fast_accum,
 | 
			
		||||
      out);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
 | 
			
		||||
 | 
			
		||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
 | 
			
		||||
  { "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
 | 
			
		||||
  { "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
 | 
			
		||||
 | 
			
		||||
} // anonymous namespace
 | 
			
		||||
 | 
			
		||||
Tensor
 | 
			
		||||
_scaled_grouped_mm_cuda_v2(
 | 
			
		||||
          const Tensor& mat_a, const Tensor& mat_b,
 | 
			
		||||
          ArrayRef<Tensor> scale_a,
 | 
			
		||||
          IntArrayRef scale_recipe_a,
 | 
			
		||||
          IntArrayRef swizzle_a,
 | 
			
		||||
          ArrayRef<Tensor> scale_b,
 | 
			
		||||
          IntArrayRef scale_recipe_b,
 | 
			
		||||
          IntArrayRef swizzle_b,
 | 
			
		||||
          const std::optional<Tensor>& offs,
 | 
			
		||||
          const std::optional<Tensor>& bias,
 | 
			
		||||
          const std::optional<c10::ScalarType> out_dtype,
 | 
			
		||||
          IntArrayRef contraction_dim,
 | 
			
		||||
          bool use_fast_accum) {
 | 
			
		||||
  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
 | 
			
		||||
  TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
 | 
			
		||||
  TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
 | 
			
		||||
  const bool a_is_2d = mat_a.dim() == 2;
 | 
			
		||||
  const bool b_is_2d = mat_b.dim() == 2;
 | 
			
		||||
 | 
			
		||||
  // NOTE(slayton): For sub-1B formats want contraction_dim argument?
 | 
			
		||||
  if (!a_is_2d || !b_is_2d) {
 | 
			
		||||
    if (contraction_dim.size() > 0) {
 | 
			
		||||
      const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]);
 | 
			
		||||
      TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
 | 
			
		||||
          "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
 | 
			
		||||
          mat_b.size(dim_b));
 | 
			
		||||
      // Note: only (-1, -2) is currently supported
 | 
			
		||||
      TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
 | 
			
		||||
    } else {
 | 
			
		||||
      TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK_VALUE(
 | 
			
		||||
    mat_a.size(-1) % 16 == 0,
 | 
			
		||||
    "Expected trailing dimension of mat_a to be divisible by 16 ",
 | 
			
		||||
    "but got mat1 shape: (",
 | 
			
		||||
    mat_a.sizes(),
 | 
			
		||||
    ").");
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
 | 
			
		||||
    "Expected mat_b shape to be divisible by 16 ",
 | 
			
		||||
    "but got mat_b shape: (",
 | 
			
		||||
    mat_b.sizes(),
 | 
			
		||||
    ").");
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
 | 
			
		||||
  TORCH_CHECK_VALUE(offs.has_value() ==  (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
 | 
			
		||||
 | 
			
		||||
  // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
 | 
			
		||||
  //       for rowwise, no offsets implies 3d-3d and is handled by lower-level
 | 
			
		||||
  //       routines
 | 
			
		||||
  if (offs.has_value()) {
 | 
			
		||||
    TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
 | 
			
		||||
    TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
 | 
			
		||||
  TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
 | 
			
		||||
 | 
			
		||||
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
 | 
			
		||||
 | 
			
		||||
  // Conversion of implicitly-defined enums to explicit
 | 
			
		||||
  auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
 | 
			
		||||
  auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
 | 
			
		||||
  auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
 | 
			
		||||
  auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
 | 
			
		||||
 | 
			
		||||
  // at this point we can start working out what we want to be doing
 | 
			
		||||
  // Try to do as few steps as possible.
 | 
			
		||||
  // NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
 | 
			
		||||
  // Do this via a list of defined (name, acceptance, concrete_impl) tuples.
 | 
			
		||||
  ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
 | 
			
		||||
  for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
 | 
			
		||||
    const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
 | 
			
		||||
    bool ok = accept_fn(mat_a.scalar_type(),
 | 
			
		||||
                        scale_recipe_a_enum,
 | 
			
		||||
                        scale_a,
 | 
			
		||||
                        mat_b.scalar_type(),
 | 
			
		||||
                        scale_recipe_b_enum,
 | 
			
		||||
                        scale_b);
 | 
			
		||||
    if (ok) {
 | 
			
		||||
      gemm_impl = scaled_gemm_impl;
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
 | 
			
		||||
      "No gemm implementation was found");
 | 
			
		||||
 | 
			
		||||
  switch (gemm_impl) {
 | 
			
		||||
    case ScaledGemmImplementation::ROWWISE_ROWWISE: {
 | 
			
		||||
      const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
 | 
			
		||||
      _check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
 | 
			
		||||
      _check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
 | 
			
		||||
      return _f8_f8_bf16_rowwise_grouped_mm(
 | 
			
		||||
          mat_a,
 | 
			
		||||
          mat_b,
 | 
			
		||||
          scale_a[0],
 | 
			
		||||
          scale_b[0],
 | 
			
		||||
          offs,
 | 
			
		||||
          bias,
 | 
			
		||||
          use_fast_accum,
 | 
			
		||||
          out);
 | 
			
		||||
    }
 | 
			
		||||
    case ScaledGemmImplementation::MXFP8_MXFP8: {
 | 
			
		||||
      _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
 | 
			
		||||
      _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
 | 
			
		||||
      return _mx8_mx8_bf16_grouped_mm_fbgemm(
 | 
			
		||||
          mat_a,
 | 
			
		||||
          mat_b,
 | 
			
		||||
          scale_a[0],
 | 
			
		||||
          swizzle_a_enum[0],
 | 
			
		||||
          scale_b[0],
 | 
			
		||||
          swizzle_b_enum[0],
 | 
			
		||||
          offs.value(),
 | 
			
		||||
          out);
 | 
			
		||||
    }
 | 
			
		||||
    default:
 | 
			
		||||
      TORCH_CHECK_NOT_IMPLEMENTED(false,
 | 
			
		||||
          "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
 | 
			
		||||
const std::optional<at::Tensor>& offs,
 | 
			
		||||
const std::optional<at::Tensor>& bias,
 | 
			
		||||
std::optional<c10::ScalarType> out_dtype) {
 | 
			
		||||
  _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
 | 
			
		||||
  bool a_b_and_out_are_bf16 = (
 | 
			
		||||
    mat_a.dtype() == at::kBFloat16 &&
 | 
			
		||||
    mat_b.dtype() == at::kBFloat16 &&
 | 
			
		||||
    out_dtype.value_or(at::kBFloat16) == at::kBFloat16
 | 
			
		||||
  );
 | 
			
		||||
#ifndef USE_ROCM
 | 
			
		||||
  bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
 | 
			
		||||
#else
 | 
			
		||||
  // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
 | 
			
		||||
  // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
 | 
			
		||||
  bool use_fast_path = false;
 | 
			
		||||
#endif
 | 
			
		||||
  const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
 | 
			
		||||
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
 | 
			
		||||
  if (use_fast_path) {
 | 
			
		||||
    // fast path, no d2h sync needed
 | 
			
		||||
    at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
 | 
			
		||||
  } else {
 | 
			
		||||
    _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
 | 
			
		||||
  }
 | 
			
		||||
  return out;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
@ -6,7 +6,7 @@
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// ROCm 6.3 is planned to have these functions, but until then here they are.
 | 
			
		||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60201
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
#include <device_functions.h>
 | 
			
		||||
#include <hip/hip_fp16.h>
 | 
			
		||||
#include <hip/hip_bf16.h>
 | 
			
		||||
@ -115,9 +115,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
 | 
			
		||||
    index_t index,
 | 
			
		||||
    const index_t numel,
 | 
			
		||||
    scalar_t value) {
 | 
			
		||||
#if (                      \
 | 
			
		||||
    (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
 | 
			
		||||
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
 | 
			
		||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))
 | 
			
		||||
  gpuAtomicAddNoReturn(
 | 
			
		||||
      reinterpret_cast<at::Half*>(tensor) + index,
 | 
			
		||||
      static_cast<at::Half>(value));
 | 
			
		||||
@ -160,9 +158,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
 | 
			
		||||
    index_t index,
 | 
			
		||||
    const index_t numel,
 | 
			
		||||
    scalar_t value) {
 | 
			
		||||
#if (                      \
 | 
			
		||||
    (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
 | 
			
		||||
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
 | 
			
		||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
 | 
			
		||||
  gpuAtomicAddNoReturn(
 | 
			
		||||
      reinterpret_cast<at::BFloat16*>(tensor) + index,
 | 
			
		||||
      static_cast<at::BFloat16>(value));
 | 
			
		||||
 | 
			
		||||
@ -1,18 +1,17 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <ATen/OpMathType.h>
 | 
			
		||||
#include <ATen/cuda/detail/OffsetCalculator.cuh>
 | 
			
		||||
#include <ATen/detail/FunctionTraits.h>
 | 
			
		||||
#include <ATen/native/TensorIterator.h>
 | 
			
		||||
#include <ATen/native/TensorIteratorDynamicCasting.h>
 | 
			
		||||
#include <ATen/cuda/detail/OffsetCalculator.cuh>
 | 
			
		||||
#include <ATen/OpMathType.h>
 | 
			
		||||
#include <ATen/native/cuda/thread_constants.h>
 | 
			
		||||
 | 
			
		||||
#include <thrust/tuple.h>
 | 
			
		||||
 | 
			
		||||
#include <ATen/native/cuda/MemoryAccess.cuh>
 | 
			
		||||
 | 
			
		||||
#include <tuple>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
namespace at::native {
 | 
			
		||||
 | 
			
		||||
template<int N>
 | 
			
		||||
@ -62,7 +61,11 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
 | 
			
		||||
  #pragma unroll
 | 
			
		||||
  for (int i = 0; i < elems_per_thread; i++) {
 | 
			
		||||
    if (policy.check_inbounds(i)) {
 | 
			
		||||
#if defined(__HIP__)
 | 
			
		||||
      results[i] = c10::guts::apply(f, args[i]);
 | 
			
		||||
#else
 | 
			
		||||
      results[i] = std::apply(f, args[i]);
 | 
			
		||||
#endif
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -23,7 +23,7 @@ namespace at::native {
 | 
			
		||||
 | 
			
		||||
// The maximum number of threads in a block
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
constexpr int MAX_BLOCK_SIZE = 256;
 | 
			
		||||
constexpr int MAX_BLOCK_SIZE = 1024;
 | 
			
		||||
#else
 | 
			
		||||
constexpr int MAX_BLOCK_SIZE = 512;
 | 
			
		||||
#endif
 | 
			
		||||
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
 | 
			
		||||
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
 | 
			
		||||
static int getNumThreads(int nElem) {
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
 | 
			
		||||
  int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
 | 
			
		||||
#else
 | 
			
		||||
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
 | 
			
		||||
#endif
 | 
			
		||||
@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
 | 
			
		||||
  // first the reductions each thread does separately
 | 
			
		||||
  scalar_t sum = static_cast<scalar_t>(0);
 | 
			
		||||
  for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
    constexpr int UNRL = 4; // load deserilize factor
 | 
			
		||||
    scalar_t tmp[UNRL];
 | 
			
		||||
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
 | 
			
		||||
#pragma unroll
 | 
			
		||||
      for (int u = 0; u < UNRL; u++)
 | 
			
		||||
        tmp[u] = op(batch, plane, std::min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
 | 
			
		||||
#pragma unroll
 | 
			
		||||
      for (int u = 0; u < UNRL; u++)
 | 
			
		||||
        if (x+u*blockDim.x < tensor.size(2))
 | 
			
		||||
          sum += tmp[u];
 | 
			
		||||
    }
 | 
			
		||||
#else
 | 
			
		||||
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
 | 
			
		||||
      sum += op(batch, plane, x);
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
  __shared__ scalar_t shared[C10_WARP_SIZE];
 | 
			
		||||
  SumReduceOp<scalar_t> reduce_op;
 | 
			
		||||
@ -292,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel(
 | 
			
		||||
  stat_accscalar_t var_n = 0;
 | 
			
		||||
  int n = 0;
 | 
			
		||||
  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
    constexpr int UNRL = 4;
 | 
			
		||||
    stat_accscalar_t v_[UNRL];
 | 
			
		||||
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
 | 
			
		||||
      for (int u = 0; u < UNRL; u++)
 | 
			
		||||
        v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
 | 
			
		||||
      for (int u = 0; u < UNRL; u++) {
 | 
			
		||||
        if (x+u*blockDim.x < input.size(2)) {
 | 
			
		||||
          stat_accscalar_t d1 = v_[u] - avg;
 | 
			
		||||
          n++;
 | 
			
		||||
          avg += d1 / n;
 | 
			
		||||
          var_n += d1 * (v_[u] - avg);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
#else
 | 
			
		||||
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
 | 
			
		||||
      stat_accscalar_t v = input[batch][plane][x];
 | 
			
		||||
      stat_accscalar_t d1 = v - avg;
 | 
			
		||||
@ -299,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel(
 | 
			
		||||
      avg += d1 / n;
 | 
			
		||||
      var_n += d1 * (v - avg);
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // first warpSum to get one value per thread to
 | 
			
		||||
 | 
			
		||||
@ -92,6 +92,16 @@ inline thrust::pair<int64_t, int64_t>  get_index_mapping2d(
 | 
			
		||||
    output_offset + output_y * output_dim_x + output_x);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
 | 
			
		||||
  const int64_t two = (len - 1) * 2;
 | 
			
		||||
  if (two <= 0) {
 | 
			
		||||
    return 0;
 | 
			
		||||
  }
 | 
			
		||||
  int64_t m = x % two;
 | 
			
		||||
  if (m < 0) m += two;
 | 
			
		||||
  return (m < len) ? m : (two - m);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<typename scalar_t>
 | 
			
		||||
__global__ void reflection_pad1d_out_kernel(
 | 
			
		||||
    const scalar_t * input, scalar_t * output,
 | 
			
		||||
@ -106,6 +116,28 @@ __global__ void reflection_pad1d_out_kernel(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
__global__ void reflection_pad1d_flat(
 | 
			
		||||
    const scalar_t* __restrict__ input,
 | 
			
		||||
    scalar_t* __restrict__ output,
 | 
			
		||||
    int64_t input_w, int64_t pad_l, int64_t pad_r,
 | 
			
		||||
    int64_t out_w, int64_t plane_count) {
 | 
			
		||||
 | 
			
		||||
  const int64_t bx = blockDim.x;
 | 
			
		||||
  const int64_t tx = threadIdx.x;
 | 
			
		||||
 | 
			
		||||
  const int64_t total = plane_count * out_w;
 | 
			
		||||
  const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
 | 
			
		||||
  int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
 | 
			
		||||
 | 
			
		||||
  for (; linear < total; linear += grid_stride) {
 | 
			
		||||
    const int64_t plane = linear / out_w;
 | 
			
		||||
    const int64_t x = linear - plane * out_w;
 | 
			
		||||
    const int64_t j = reflect_index(x - pad_l, input_w);
 | 
			
		||||
    output[plane * out_w + x] = input[plane * input_w + j];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
__global__ void reflection_pad1d_backward_out_kernel(
 | 
			
		||||
    scalar_t * grad_input, const scalar_t * grad_output,
 | 
			
		||||
@ -710,25 +742,44 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
 | 
			
		||||
  int64_t input_w = input_.size(dim_w);
 | 
			
		||||
  int64_t output_w = input_w + pad_l + pad_r;
 | 
			
		||||
 | 
			
		||||
  dim3 block_size(output_w > 256 ? 256 : output_w);
 | 
			
		||||
  dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
 | 
			
		||||
 | 
			
		||||
  Tensor input = input_.contiguous();
 | 
			
		||||
 | 
			
		||||
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
 | 
			
		||||
      kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
 | 
			
		||||
        reflection_pad1d_out_kernel<<<
 | 
			
		||||
            grid_size,
 | 
			
		||||
            block_size,
 | 
			
		||||
            0,
 | 
			
		||||
            at::cuda::getCurrentCUDAStream()>>>(
 | 
			
		||||
            input.const_data_ptr<scalar_t>(),
 | 
			
		||||
            output.mutable_data_ptr<scalar_t>(),
 | 
			
		||||
            input_w,
 | 
			
		||||
            pad_l,
 | 
			
		||||
            pad_r);
 | 
			
		||||
        C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
      });
 | 
			
		||||
  const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
 | 
			
		||||
  const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
 | 
			
		||||
  const int max_x = prop->maxGridSize[0];
 | 
			
		||||
  const int max_y = prop->maxGridSize[1];
 | 
			
		||||
  const int max_z = prop->maxGridSize[2];
 | 
			
		||||
 | 
			
		||||
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
 | 
			
		||||
    auto stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
 | 
			
		||||
    const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
 | 
			
		||||
 | 
			
		||||
    const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
 | 
			
		||||
 | 
			
		||||
    if (fits3d) {
 | 
			
		||||
      dim3 block(block_x, 1, 1);
 | 
			
		||||
      dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
 | 
			
		||||
      reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
 | 
			
		||||
          input.const_data_ptr<scalar_t>(),
 | 
			
		||||
          output.mutable_data_ptr<scalar_t>(),
 | 
			
		||||
          input_w, pad_l, pad_r);
 | 
			
		||||
    } else {
 | 
			
		||||
      dim3 block(block_x, 1, 1);
 | 
			
		||||
      const int64_t plane_count = nplane * nbatch;
 | 
			
		||||
      const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
 | 
			
		||||
      const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
 | 
			
		||||
      dim3 grid(grid_x, 1, 1);
 | 
			
		||||
 | 
			
		||||
      reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
 | 
			
		||||
          input.const_data_ptr<scalar_t>(),
 | 
			
		||||
          output.mutable_data_ptr<scalar_t>(),
 | 
			
		||||
          input_w, pad_l, pad_r, output_w, plane_count);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
 | 
			
		||||
 | 
			
		||||
@ -43,6 +43,12 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
 | 
			
		||||
  TORCH_CHECK(k >= 1 && k <= slicesize,
 | 
			
		||||
              "kthvalue(): selected number k out of range for dimension ", dim);
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      slicesize <= std::numeric_limits<int32_t>::max(),
 | 
			
		||||
      "kthvalue(): dimension ", dim, " is too large (", slicesize,
 | 
			
		||||
      "). The current CUDA implementation supports dimension sizes up to ",
 | 
			
		||||
      std::numeric_limits<int32_t>::max());
 | 
			
		||||
 | 
			
		||||
  at::assert_no_overlap(self, values);
 | 
			
		||||
 | 
			
		||||
  _reduction_with_indices_allocate_or_resize_output(
 | 
			
		||||
@ -163,10 +169,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
 | 
			
		||||
    bool keepdim,
 | 
			
		||||
    Tensor& values,
 | 
			
		||||
    Tensor& indices) {
 | 
			
		||||
  // See note [Writing Nondeterministic Operations]
 | 
			
		||||
  // If there are duplicate elements of the kth value, the procedure for choosing which
 | 
			
		||||
  // of the duplicates to use for the indices output is nondeterministic.
 | 
			
		||||
  at::globalContext().alertNotDeterministic("kthvalue CUDA");
 | 
			
		||||
  auto result = [&]() {
 | 
			
		||||
    NoNamesGuard guard;
 | 
			
		||||
    // `kthvalue_out_impl_cuda` expects contiguous in input `self`.
 | 
			
		||||
 | 
			
		||||
@ -65,25 +65,34 @@ __global__ void gatherKthValue(
 | 
			
		||||
      &kValue);
 | 
			
		||||
 | 
			
		||||
  // Find the index of the k-th highest element
 | 
			
		||||
  index_t kValueIndex = 0;
 | 
			
		||||
  bool foundKValue = false;
 | 
			
		||||
  __shared__ int32_t minIndexFound;
 | 
			
		||||
 | 
			
		||||
  if (threadIdx.x == 0) {
 | 
			
		||||
      minIndexFound = static_cast<int32_t>(inputSliceSize);
 | 
			
		||||
  }
 | 
			
		||||
  __syncthreads();
 | 
			
		||||
 | 
			
		||||
  for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
 | 
			
		||||
    bool inRange = (i < inputSliceSize);
 | 
			
		||||
    scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
 | 
			
		||||
                         : static_cast<scalar_t>(0);
 | 
			
		||||
    bool isKValue = inRange &&
 | 
			
		||||
        ((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
 | 
			
		||||
    if (isKValue) {
 | 
			
		||||
      kValueIndex = i;
 | 
			
		||||
      foundKValue = true;
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
      // Early exit based on best-so-far
 | 
			
		||||
      if (i >= minIndexFound) {
 | 
			
		||||
          break;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
 | 
			
		||||
      bool isKValue =
 | 
			
		||||
          ((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
 | 
			
		||||
 | 
			
		||||
      if (isKValue) {
 | 
			
		||||
          atomicMin(&minIndexFound, static_cast<int32_t>(i));
 | 
			
		||||
          break;
 | 
			
		||||
      }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (foundKValue) {
 | 
			
		||||
    kthValueSliceStart[0] = kValue;
 | 
			
		||||
    indicesSliceStart[0] = kValueIndex;
 | 
			
		||||
  __syncthreads();
 | 
			
		||||
 | 
			
		||||
  if (threadIdx.x == 0) {
 | 
			
		||||
      indicesSliceStart[0] = static_cast<index_t>(minIndexFound);
 | 
			
		||||
      kthValueSliceStart[0] = kValue;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
// Helper function to compute output pixel range that can contribute to input pixel
 | 
			
		||||
template <typename accscalar_t>
 | 
			
		||||
__device__ __forceinline__ void compute_output_range(
 | 
			
		||||
    int input_pos,
 | 
			
		||||
    accscalar_t scale,
 | 
			
		||||
    int output_size,
 | 
			
		||||
    bool align_corners,
 | 
			
		||||
    int& min_output,
 | 
			
		||||
    int& max_output) {
 | 
			
		||||
  accscalar_t lo, hi;
 | 
			
		||||
  if (align_corners) {
 | 
			
		||||
      lo = static_cast<accscalar_t>(input_pos - 1) / scale;
 | 
			
		||||
      hi = static_cast<accscalar_t>(input_pos + 1) / scale;
 | 
			
		||||
  } else {
 | 
			
		||||
      lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
 | 
			
		||||
      hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
 | 
			
		||||
  }
 | 
			
		||||
  min_output = max(0, static_cast<int>(std::ceil(lo)));
 | 
			
		||||
  max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
 | 
			
		||||
template <typename scalar_t, typename accscalar_t>
 | 
			
		||||
C10_LAUNCH_BOUNDS_1(1024)
 | 
			
		||||
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
 | 
			
		||||
    const bool align_corners,
 | 
			
		||||
    scalar_t* __restrict__ idata,
 | 
			
		||||
    const scalar_t* __restrict__ odata) {
 | 
			
		||||
  const size_t o_numel = nc * width2 * height2;
 | 
			
		||||
  // In C++, integer multiplication, like in standard arithmetic, is generally commutative.
 | 
			
		||||
  const size_t i_numel = nc * width1 * height1;
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
 | 
			
		||||
       index += blockDim.x * gridDim.x) {
 | 
			
		||||
    // Decode input pixel coordinates
 | 
			
		||||
    size_t index_temp = index;
 | 
			
		||||
    const int w1 = index_temp % width1;
 | 
			
		||||
    index_temp /= width1;
 | 
			
		||||
    const int h1 = index_temp % height1;
 | 
			
		||||
    const size_t nc_idx = index_temp / height1;
 | 
			
		||||
 | 
			
		||||
    accscalar_t grad_sum = 0;
 | 
			
		||||
 | 
			
		||||
    // Find range of output pixels that could interpolate from this input pixel
 | 
			
		||||
    int h2_min, h2_max, w2_min, w2_max;
 | 
			
		||||
    compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
 | 
			
		||||
    compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
 | 
			
		||||
 | 
			
		||||
    // Iterate over potential output pixels
 | 
			
		||||
    for (int h2 = h2_min; h2 <= h2_max; h2++) {
 | 
			
		||||
      for (int w2 = w2_min; w2 <= w2_max; w2++) {
 | 
			
		||||
        // Compute source coordinates for this output pixel
 | 
			
		||||
        const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
 | 
			
		||||
            rheight, h2, align_corners, /*cubic=*/false);
 | 
			
		||||
        const int h1_base = (int)h1r;
 | 
			
		||||
        const int h1p = (h1_base < height1 - 1) ? 1 : 0;
 | 
			
		||||
        const accscalar_t h1lambda = h1r - h1_base;
 | 
			
		||||
        const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
 | 
			
		||||
 | 
			
		||||
        const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
 | 
			
		||||
            rwidth, w2, align_corners, /*cubic=*/false);
 | 
			
		||||
        const int w1_base = (int)w1r;
 | 
			
		||||
        const int w1p = (w1_base < width1 - 1) ? 1 : 0;
 | 
			
		||||
        const accscalar_t w1lambda = w1r - w1_base;
 | 
			
		||||
        const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
 | 
			
		||||
 | 
			
		||||
        // Check if our input pixel participates in this interpolation and accumulate all weights
 | 
			
		||||
        // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
 | 
			
		||||
        // to the same pixel, so we need to accumulate weights from all matching positions
 | 
			
		||||
        accscalar_t weight = 0;
 | 
			
		||||
 | 
			
		||||
        // Check all four interpolation positions and accumulate weights
 | 
			
		||||
        if (h1 == h1_base && w1 == w1_base) {
 | 
			
		||||
          weight += h0lambda * w0lambda;  // top-left
 | 
			
		||||
        }
 | 
			
		||||
        if (h1 == h1_base && w1 == w1_base + w1p) {
 | 
			
		||||
          weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0)
 | 
			
		||||
        }
 | 
			
		||||
        if (h1 == h1_base + h1p && w1 == w1_base) {
 | 
			
		||||
          weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0)
 | 
			
		||||
        }
 | 
			
		||||
        if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
 | 
			
		||||
          weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions)
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (weight > 0) {
 | 
			
		||||
          const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
 | 
			
		||||
          grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Write accumulated gradient (no atomics needed)
 | 
			
		||||
    idata[index] = static_cast<scalar_t>(grad_sum);
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  const size_t o_numel = nc * width2 * height2;
 | 
			
		||||
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
 | 
			
		||||
       index += blockDim.x * gridDim.x) {
 | 
			
		||||
    size_t index_temp = index;
 | 
			
		||||
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
 | 
			
		||||
        static_cast<scalar_t>(h1lambda * w1lambda * d2val),
 | 
			
		||||
        true);
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t, typename accscalar_t>
 | 
			
		||||
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
  // threads are not covering the whole input tensor.
 | 
			
		||||
  grad_input.zero_();
 | 
			
		||||
 | 
			
		||||
  const size_t num_kernels = nbatch * channels * output_height * output_width;
 | 
			
		||||
  const int num_threads = std::min(
 | 
			
		||||
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
 | 
			
		||||
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  constexpr bool use_input = true;
 | 
			
		||||
#else
 | 
			
		||||
  constexpr bool use_input = false;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  AT_DISPATCH_FLOATING_TYPES_AND2(
 | 
			
		||||
      at::ScalarType::Half, at::ScalarType::BFloat16,
 | 
			
		||||
      grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
 | 
			
		||||
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
 | 
			
		||||
          input_width, output_width, align_corners, scales_w);
 | 
			
		||||
 | 
			
		||||
      const size_t num_kernels = nbatch * channels * output_height * output_width;
 | 
			
		||||
 | 
			
		||||
      upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
 | 
			
		||||
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
 | 
			
		||||
              input_height,
 | 
			
		||||
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
 | 
			
		||||
          input_width, output_width, align_corners, scales_w);
 | 
			
		||||
 | 
			
		||||
      const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
 | 
			
		||||
 | 
			
		||||
      upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
 | 
			
		||||
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
 | 
			
		||||
             num_threads,
 | 
			
		||||
 | 
			
		||||
@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
 | 
			
		||||
  using opmath_t = at::opmath_type<scalar_t>;
 | 
			
		||||
 | 
			
		||||
  C10_DEVICE __forceinline__ void operator()(
 | 
			
		||||
      int chunk_size,
 | 
			
		||||
      int64_t chunk_size,
 | 
			
		||||
      FusedOptimizerTensorListMetadata<3>& tl,
 | 
			
		||||
      const float* lr_ptr,
 | 
			
		||||
      const double& lr,
 | 
			
		||||
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
 | 
			
		||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
 | 
			
		||||
#include <cuda_bf16.h>
 | 
			
		||||
#include <cuda_fp16.h>
 | 
			
		||||
#include <cuda_runtime.h>
 | 
			
		||||
@ -133,7 +133,7 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
 | 
			
		||||
#define CDNA2_OR_LATER 0
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
 | 
			
		||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
 | 
			
		||||
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
// TODO: Support RDNA
 | 
			
		||||
@ -1161,7 +1161,7 @@ at::Tensor _weight_int4pack_mm_cuda(
 | 
			
		||||
  auto C_final = at::empty(
 | 
			
		||||
      {m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device()));
 | 
			
		||||
 | 
			
		||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
 | 
			
		||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
 | 
			
		||||
  auto stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
 | 
			
		||||
  do {                                                               \
 | 
			
		||||
@ -1327,7 +1327,7 @@ at::Tensor _convert_weight_to_int4pack_cuda(
 | 
			
		||||
      {nTilesTensor, kSuperTiles, 32, innerKTiles / 2},
 | 
			
		||||
      at::TensorOptions().dtype(at::kInt).device(in.device()));
 | 
			
		||||
 | 
			
		||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
 | 
			
		||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
 | 
			
		||||
  auto stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
  dim3 grid(kSuperTiles, nTiles);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@
 | 
			
		||||
#include <ATen/WrapDimUtilsMulti.h>
 | 
			
		||||
#include <ATen/native/Resize.h>
 | 
			
		||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
 | 
			
		||||
#include <ATen/native/xpu/Blas.h>
 | 
			
		||||
#include <torch/library.h>
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
 | 
			
		||||
@ -50,9 +51,13 @@ Tensor& addmm_out(
 | 
			
		||||
      mat1.dtype(),
 | 
			
		||||
      " != ",
 | 
			
		||||
      mat2.dtype())
 | 
			
		||||
 | 
			
		||||
  // complex case
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      !mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
 | 
			
		||||
  if (self.is_complex()) {
 | 
			
		||||
    at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
 | 
			
		||||
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
 | 
			
		||||
  result.resize_(result_shape);
 | 
			
		||||
@ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
 | 
			
		||||
  if (self.is_complex()) {
 | 
			
		||||
    at::native::mm_complex_out_xpu(self, mat2, result);
 | 
			
		||||
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
 | 
			
		||||
  return result;
 | 
			
		||||
@ -208,9 +216,12 @@ Tensor& baddbmm_out(
 | 
			
		||||
      input.sizes());
 | 
			
		||||
 | 
			
		||||
  // complex case
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      !batch1.is_complex(),
 | 
			
		||||
      "Complex datatype matmul is not supported in oneDNN");
 | 
			
		||||
  if (input.is_complex()) {
 | 
			
		||||
    at::native::baddbmm_complex_out_xpu(
 | 
			
		||||
        input, batch1, batch2, beta, alpha, result);
 | 
			
		||||
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // general case
 | 
			
		||||
  onednn::Attr attr;
 | 
			
		||||
@ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
 | 
			
		||||
  // complex case
 | 
			
		||||
  if (self.is_complex()) {
 | 
			
		||||
    at::native::bmm_complex_out_xpu(self, batch2, result);
 | 
			
		||||
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -222,6 +222,13 @@ struct nextafter_functor {
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct hypot_functor {
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  inline T operator()(const T a, const T b) {
 | 
			
		||||
    return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b));
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Complex binary functors
 | 
			
		||||
struct polar_functor {
 | 
			
		||||
  template <typename U>
 | 
			
		||||
@ -362,6 +369,7 @@ struct igammac_functor {
 | 
			
		||||
  REGISTER_OPMATH_BINARY_OP(NAME, half, half);   \
 | 
			
		||||
  REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
 | 
			
		||||
 | 
			
		||||
REGISTER_FLOAT_BINARY_OP(hypot);
 | 
			
		||||
REGISTER_FLOAT_BINARY_OP(copysign);
 | 
			
		||||
REGISTER_INT2FLOAT_BINARY_OP(copysign);
 | 
			
		||||
REGISTER_FLOAT_BINARY_OP(fmax);
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										16
									
								
								aten/src/ATen/native/mps/kernels/LinearAlgebra.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								aten/src/ATen/native/mps/kernels/LinearAlgebra.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,16 @@
 | 
			
		||||
#pragma onces
 | 
			
		||||
#include <c10/metal/common.h>
 | 
			
		||||
 | 
			
		||||
template <unsigned N = c10::metal::max_ndim>
 | 
			
		||||
struct OrgqrParams {
 | 
			
		||||
  int32_t num_batch_dims;
 | 
			
		||||
 | 
			
		||||
  uint32_t m;
 | 
			
		||||
  uint32_t n;
 | 
			
		||||
  uint32_t k;
 | 
			
		||||
 | 
			
		||||
  ::c10::metal::array<uint32_t, N> A_strides;
 | 
			
		||||
  ::c10::metal::array<uint32_t, N> tau_strides;
 | 
			
		||||
  ::c10::metal::array<uint32_t, N> H_strides;
 | 
			
		||||
  ::c10::metal::array<uint32_t, N> H_sizes;
 | 
			
		||||
};
 | 
			
		||||
@ -1,3 +1,4 @@
 | 
			
		||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
 | 
			
		||||
#include <c10/metal/utils.h>
 | 
			
		||||
#include <metal_array>
 | 
			
		||||
#include <metal_simdgroup>
 | 
			
		||||
@ -640,6 +641,164 @@ kernel void applyPivots(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
static T bool_to_float(bool b) {
 | 
			
		||||
  return static_cast<T>(b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
half2 bool_to_float(bool b) {
 | 
			
		||||
  return half2(b ? 1 : 0, 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
float2 bool_to_float(bool b) {
 | 
			
		||||
  return float2(b ? 1 : 0, 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
static T calc_H_irc(
 | 
			
		||||
    device T* A,
 | 
			
		||||
    uint32_t A_stride_r,
 | 
			
		||||
    uint32_t A_stride_c,
 | 
			
		||||
    constant T* tau,
 | 
			
		||||
    uint32_t tau_stride,
 | 
			
		||||
    uint32_t r,
 | 
			
		||||
    uint32_t c,
 | 
			
		||||
    uint32_t i) {
 | 
			
		||||
  T I_val = bool_to_float<T>(r == c);
 | 
			
		||||
  T tau_val = tau[i * tau_stride];
 | 
			
		||||
 | 
			
		||||
  T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]);
 | 
			
		||||
  T A_ri = A[r * A_stride_r + i * A_stride_c];
 | 
			
		||||
 | 
			
		||||
  T c_eq_i = bool_to_float<T>(c == i);
 | 
			
		||||
  T r_eq_i = bool_to_float<T>(r == i);
 | 
			
		||||
 | 
			
		||||
  T A_ci_ = (c > i) ? A_ci : c_eq_i;
 | 
			
		||||
  T A_ri_ = (r > i) ? A_ri : r_eq_i;
 | 
			
		||||
 | 
			
		||||
  return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the
 | 
			
		||||
// result of matrix multiplying A and B together. A and B must be size m-by-m
 | 
			
		||||
// and have the same strides. The formula for this operation, written in Python
 | 
			
		||||
// syntax, is:
 | 
			
		||||
//   (A @ B)[r, c] = A[r, :].dot(B[:, c])
 | 
			
		||||
template <typename T>
 | 
			
		||||
static T calc_matmul_rc(
 | 
			
		||||
    device T* A,
 | 
			
		||||
    device T* B,
 | 
			
		||||
    uint32_t stride_r,
 | 
			
		||||
    uint32_t stride_c,
 | 
			
		||||
    uint32_t m,
 | 
			
		||||
    uint32_t r,
 | 
			
		||||
    uint32_t c) {
 | 
			
		||||
  T AB_rc = 0;
 | 
			
		||||
  auto A_row_offset = r * stride_r;
 | 
			
		||||
  auto B_col_offset = c * stride_c;
 | 
			
		||||
 | 
			
		||||
  uint32_t A_col_offset = 0;
 | 
			
		||||
  uint32_t B_row_offset = 0;
 | 
			
		||||
 | 
			
		||||
  for (uint32_t j = 0; j < m;
 | 
			
		||||
       j++, A_col_offset += stride_c, B_row_offset += stride_r) {
 | 
			
		||||
    AB_rc += c10::metal::mul(
 | 
			
		||||
        A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]);
 | 
			
		||||
  }
 | 
			
		||||
  return AB_rc;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
kernel void orgqr(
 | 
			
		||||
    device T* A [[buffer(0)]],
 | 
			
		||||
    constant T* tau [[buffer(1)]],
 | 
			
		||||
    device T* H [[buffer(2)]],
 | 
			
		||||
    device T* H_prod [[buffer(3)]],
 | 
			
		||||
    constant OrgqrParams<>& params [[buffer(4)]],
 | 
			
		||||
    uint tid [[thread_position_in_grid]]) {
 | 
			
		||||
  constant auto& A_strides = params.A_strides;
 | 
			
		||||
  constant auto& tau_strides = params.tau_strides;
 | 
			
		||||
  constant auto& H_strides = params.H_strides;
 | 
			
		||||
  constant auto& H_sizes = params.H_sizes;
 | 
			
		||||
 | 
			
		||||
  auto num_batch_dims = params.num_batch_dims;
 | 
			
		||||
  auto m = params.m;
 | 
			
		||||
  auto n = params.n;
 | 
			
		||||
  auto k = params.k;
 | 
			
		||||
 | 
			
		||||
  auto m2 = m * m;
 | 
			
		||||
  auto batch_idx = tid / m2;
 | 
			
		||||
 | 
			
		||||
  // Find the matrices for this thread's batch index
 | 
			
		||||
  uint32_t A_offset = 0;
 | 
			
		||||
  uint32_t tau_offset = 0;
 | 
			
		||||
  uint32_t H_offset = 0;
 | 
			
		||||
 | 
			
		||||
  for (auto dim = num_batch_dims - 1; dim >= 0; dim--) {
 | 
			
		||||
    auto dim_size = H_sizes[dim];
 | 
			
		||||
    auto dim_idx = batch_idx % dim_size;
 | 
			
		||||
 | 
			
		||||
    A_offset += dim_idx * A_strides[dim];
 | 
			
		||||
    tau_offset += dim_idx * tau_strides[dim];
 | 
			
		||||
    H_offset += dim_idx * H_strides[dim];
 | 
			
		||||
 | 
			
		||||
    batch_idx /= dim_size;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  A += A_offset;
 | 
			
		||||
  tau += tau_offset;
 | 
			
		||||
  H += H_offset;
 | 
			
		||||
  H_prod += H_offset;
 | 
			
		||||
 | 
			
		||||
  auto matrix_idx = tid % m2;
 | 
			
		||||
  auto r = matrix_idx / m;
 | 
			
		||||
  auto c = matrix_idx % m;
 | 
			
		||||
  auto A_stride_r = A_strides[num_batch_dims];
 | 
			
		||||
  auto A_stride_c = A_strides[num_batch_dims + 1];
 | 
			
		||||
  auto tau_stride = tau_strides[num_batch_dims];
 | 
			
		||||
  auto H_stride_r = H_strides[num_batch_dims];
 | 
			
		||||
  auto H_stride_c = H_strides[num_batch_dims + 1];
 | 
			
		||||
 | 
			
		||||
  // Find the element of H and H_prod that this thread will calculate
 | 
			
		||||
  device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c);
 | 
			
		||||
  device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c);
 | 
			
		||||
 | 
			
		||||
  for (uint32_t i = 0; i < k; i++) {
 | 
			
		||||
    // Calculate and write H_i
 | 
			
		||||
 | 
			
		||||
    T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i);
 | 
			
		||||
 | 
			
		||||
    // Calculate element [r, c] of prod(H_0, ..., H_i)
 | 
			
		||||
    if (i == 0) {
 | 
			
		||||
      *H_prod_elem_ptr = H_irc;
 | 
			
		||||
    } else {
 | 
			
		||||
      *H_elem_ptr = H_irc;
 | 
			
		||||
 | 
			
		||||
      // Need this sync because the below matmul requires all threads to finish
 | 
			
		||||
      // writing their entries to `H_prod` and `H`.
 | 
			
		||||
      threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
      T H_prod_0_to_i_rc =
 | 
			
		||||
          calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);
 | 
			
		||||
 | 
			
		||||
      // Need this sync because the above matmul uses the current values in
 | 
			
		||||
      // `H_prod`, and we don't want to overwrite those until all threads are
 | 
			
		||||
      // finished using them.
 | 
			
		||||
      threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
      *H_prod_elem_ptr = H_prod_0_to_i_rc;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c);
 | 
			
		||||
 | 
			
		||||
  if (c < n) {
 | 
			
		||||
    *A_elem_ptr = *H_prod_elem_ptr;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define INSTANTIATE_MM_OPS(DTYPE)                                           \
 | 
			
		||||
  template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>(       \
 | 
			
		||||
      constant DTYPE * mat1Data [[buffer(0)]],                              \
 | 
			
		||||
@ -679,3 +838,19 @@ INSTANTIATE_MM_OPS(int);
 | 
			
		||||
INSTANTIATE_MM_OPS(short);
 | 
			
		||||
INSTANTIATE_MM_OPS(char);
 | 
			
		||||
INSTANTIATE_MM_OPS(uchar);
 | 
			
		||||
 | 
			
		||||
#define REGISTER_ORGQR(T)                            \
 | 
			
		||||
  template [[host_name("orgqr_" #T)]]                \
 | 
			
		||||
  kernel void orgqr<T>(                              \
 | 
			
		||||
      device T * A [[buffer(0)]],                    \
 | 
			
		||||
      constant T * tau [[buffer(1)]],                \
 | 
			
		||||
      device T * H [[buffer(2)]],                    \
 | 
			
		||||
      device T * H_prod [[buffer(3)]],               \
 | 
			
		||||
      constant OrgqrParams<> & params [[buffer(4)]], \
 | 
			
		||||
      uint tid [[thread_position_in_grid]]);
 | 
			
		||||
 | 
			
		||||
REGISTER_ORGQR(float);
 | 
			
		||||
REGISTER_ORGQR(half);
 | 
			
		||||
REGISTER_ORGQR(bfloat);
 | 
			
		||||
REGISTER_ORGQR(float2);
 | 
			
		||||
REGISTER_ORGQR(half2);
 | 
			
		||||
 | 
			
		||||
@ -5,6 +5,21 @@
 | 
			
		||||
using namespace metal;
 | 
			
		||||
using namespace c10::metal;
 | 
			
		||||
 | 
			
		||||
struct angle_functor {
 | 
			
		||||
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
 | 
			
		||||
  inline T operator()(const T x) {
 | 
			
		||||
    return T(atan2(x.y, x.x), 0);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
 | 
			
		||||
  inline T operator()(const T x) {
 | 
			
		||||
    return T(isnan(x) ? x : x < 0 ? M_PI_F : 0.0);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
 | 
			
		||||
  inline float operator()(const T x) {
 | 
			
		||||
    return x < 0 ? M_PI_F : 0.0;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Implement exp wrapper for both real and complex types
 | 
			
		||||
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
 | 
			
		||||
inline T exp_(const T x) {
 | 
			
		||||
@ -545,6 +560,7 @@ REGISTER_UNARY_OP(abs, float, float);
 | 
			
		||||
REGISTER_UNARY_OP(abs, half, half);
 | 
			
		||||
 | 
			
		||||
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
 | 
			
		||||
  REGISTER_UNARY_OP(angle, DTYPE1, DTYPE0);        \
 | 
			
		||||
  REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0);          \
 | 
			
		||||
  REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0);         \
 | 
			
		||||
  REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0);       \
 | 
			
		||||
@ -583,6 +599,7 @@ INSTANTIATE_UNARY_KERNELS2(float, int);
 | 
			
		||||
INSTANTIATE_UNARY_KERNELS2(float, long);
 | 
			
		||||
 | 
			
		||||
#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE)     \
 | 
			
		||||
  REGISTER_UNARY_OP(angle, DTYPE##2, DTYPE##2);   \
 | 
			
		||||
  REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2);     \
 | 
			
		||||
  REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2);     \
 | 
			
		||||
  REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2);   \
 | 
			
		||||
 | 
			
		||||
@ -92,13 +92,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          // upcasting to float32 if needed to improve precision when multiplying by the scale factor
 | 
			
		||||
          if ([maskedMM dataType] != MPSDataTypeFloat32) {
 | 
			
		||||
            maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
 | 
			
		||||
          }
 | 
			
		||||
          maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
 | 
			
		||||
          maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
 | 
			
		||||
          if ([maskedMM dataType] != qTensor.dataType) {
 | 
			
		||||
            maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          if (is_causal) {
 | 
			
		||||
            auto causalMask = [mpsGraph constantWithScalar:1.0f
 | 
			
		||||
@ -112,7 +107,9 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
 | 
			
		||||
                                                      name:nil];
 | 
			
		||||
          } else if (attn_mask) {
 | 
			
		||||
            graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
 | 
			
		||||
            maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
 | 
			
		||||
            maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
 | 
			
		||||
                                           secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
 | 
			
		||||
                                                      name:nil];
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          // Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
 | 
			
		||||
@ -133,8 +130,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
 | 
			
		||||
          graph->qTensor = qTensor;
 | 
			
		||||
          graph->kTensor = kTensor;
 | 
			
		||||
          graph->vTensor = vTensor;
 | 
			
		||||
          graph->outputTensor = output;
 | 
			
		||||
          graph->attnTensor = sm;
 | 
			
		||||
          graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
 | 
			
		||||
          graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
 | 
			
		||||
        });
 | 
			
		||||
    auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
 | 
			
		||||
    auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
 | 
			
		||||
 | 
			
		||||
@ -202,6 +202,10 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) {
 | 
			
		||||
  lib.exec_binary_kernel(iter, "igammac");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void hypot_mps_kernel(TensorIteratorBase& iter) {
 | 
			
		||||
  lib.exec_binary_kernel(iter, "hypot");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
 | 
			
		||||
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
 | 
			
		||||
REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel)
 | 
			
		||||
@ -229,4 +233,5 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
 | 
			
		||||
REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
 | 
			
		||||
REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel)
 | 
			
		||||
REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel)
 | 
			
		||||
REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel)
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,6 @@
 | 
			
		||||
#include <ATen/ops/eq_native.h>
 | 
			
		||||
#include <ATen/ops/ge_native.h>
 | 
			
		||||
#include <ATen/ops/gt_native.h>
 | 
			
		||||
#include <ATen/ops/hypot_native.h>
 | 
			
		||||
#include <ATen/ops/le_native.h>
 | 
			
		||||
#include <ATen/ops/logaddexp2_native.h>
 | 
			
		||||
#include <ATen/ops/logaddexp_native.h>
 | 
			
		||||
@ -278,22 +277,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
 | 
			
		||||
  mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
 | 
			
		||||
    MPSGraph* mpsGraph = cachedGraph->graph();
 | 
			
		||||
    MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
 | 
			
		||||
    MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
 | 
			
		||||
                                                                                     secondaryTensor:twoTensor
 | 
			
		||||
                                                                                                name:nil]
 | 
			
		||||
                                                    secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor
 | 
			
		||||
                                                                                     secondaryTensor:twoTensor
 | 
			
		||||
                                                                                                name:nil]
 | 
			
		||||
                                                               name:nil];
 | 
			
		||||
    return [mpsGraph squareRootWithTensor:sumTensor name:nil];
 | 
			
		||||
  };
 | 
			
		||||
  mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
 | 
			
		||||
  mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
 | 
			
		||||
    MPSGraph* mpsGraph = cachedGraph->graph();
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,9 @@
 | 
			
		||||
#include <ATen/native/Resize.h>
 | 
			
		||||
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
 | 
			
		||||
#include <ATen/native/mps/OperationUtils.h>
 | 
			
		||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
 | 
			
		||||
 | 
			
		||||
#include <fmt/format.h>
 | 
			
		||||
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
@ -28,6 +31,7 @@
 | 
			
		||||
#include <ATen/ops/linalg_solve_triangular_native.h>
 | 
			
		||||
#include <ATen/ops/lu_unpack_native.h>
 | 
			
		||||
#include <ATen/ops/mm_native.h>
 | 
			
		||||
#include <ATen/ops/orgqr_native.h>
 | 
			
		||||
#include <ATen/ops/slice.h>
 | 
			
		||||
#include <ATen/ops/stack.h>
 | 
			
		||||
#include <ATen/ops/triangular_solve_native.h>
 | 
			
		||||
@ -338,6 +342,8 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
 | 
			
		||||
          ". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  map_mps_decomposition_error_code_to_blas(info);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void linalg_solve_out_mps_impl(const Tensor& A,
 | 
			
		||||
@ -1233,6 +1239,69 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) {
 | 
			
		||||
  if (self.numel() == 0) {
 | 
			
		||||
    return self;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto m = self.size(-2);
 | 
			
		||||
  auto n = self.size(-1);
 | 
			
		||||
  auto k = tau.size(-1);
 | 
			
		||||
 | 
			
		||||
  if (tau.numel() == 0) {
 | 
			
		||||
    auto I = eye(m, self.scalar_type(), std::nullopt, self.device());
 | 
			
		||||
    return self.copy_(I.slice(-1, 0, n));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto num_batch_dims = self.dim() - 2;
 | 
			
		||||
  auto batch_sizes = self.sizes().slice(0, num_batch_dims);
 | 
			
		||||
 | 
			
		||||
  std::vector<int64_t> H_sizes(num_batch_dims + 2);
 | 
			
		||||
  for (auto dim : c10::irange(num_batch_dims)) {
 | 
			
		||||
    H_sizes[dim] = self.size(dim);
 | 
			
		||||
  }
 | 
			
		||||
  H_sizes[num_batch_dims] = m;
 | 
			
		||||
  H_sizes[num_batch_dims + 1] = m;
 | 
			
		||||
 | 
			
		||||
  auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous));
 | 
			
		||||
  auto H_prod = at::empty_like(H);
 | 
			
		||||
 | 
			
		||||
  OrgqrParams params;
 | 
			
		||||
 | 
			
		||||
  params.num_batch_dims = num_batch_dims;
 | 
			
		||||
  params.m = m;
 | 
			
		||||
  params.n = n;
 | 
			
		||||
  params.k = k;
 | 
			
		||||
 | 
			
		||||
  for (const auto dim : c10::irange(self.dim())) {
 | 
			
		||||
    params.A_strides[dim] = self.stride(dim);
 | 
			
		||||
 | 
			
		||||
    if (dim < tau.dim()) {
 | 
			
		||||
      params.tau_strides[dim] = tau.stride(dim);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    params.H_strides[dim] = H.stride(dim);
 | 
			
		||||
    params.H_sizes[dim] = H.size(dim);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto num_threads = H.numel();
 | 
			
		||||
  MPSStream* stream = getCurrentMPSStream();
 | 
			
		||||
 | 
			
		||||
  dispatch_sync_with_rethrow(stream->queue(), ^() {
 | 
			
		||||
    @autoreleasepool {
 | 
			
		||||
      id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder();
 | 
			
		||||
      auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self)));
 | 
			
		||||
      getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau});
 | 
			
		||||
      [compute_encoder setComputePipelineState:pipeline_state];
 | 
			
		||||
      mtl_setArgs(compute_encoder, self, tau, H, H_prod, params);
 | 
			
		||||
      mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads);
 | 
			
		||||
      getMPSProfiler().endProfileKernel(pipeline_state);
 | 
			
		||||
    }
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return self;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace mps
 | 
			
		||||
 | 
			
		||||
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
 | 
			
		||||
@ -1448,20 +1517,6 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps)
 | 
			
		||||
  mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
 | 
			
		||||
  Tensor info = at::empty({}, A.options().dtype(kInt));
 | 
			
		||||
  mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
 | 
			
		||||
  return std::tie(LU, pivots);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
 | 
			
		||||
  Tensor LU = at::empty({0}, A.options());
 | 
			
		||||
  Tensor pivots = at::empty({0}, A.options().dtype(kInt));
 | 
			
		||||
  Tensor info = at::empty({}, A.options().dtype(kInt));
 | 
			
		||||
  mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
 | 
			
		||||
  return std::make_tuple(std::move(LU), std::move(pivots));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TORCH_IMPL_FUNC(lu_unpack_out_mps)
 | 
			
		||||
(const Tensor& LU_data,
 | 
			
		||||
 const Tensor& LU_pivots,
 | 
			
		||||
@ -1483,4 +1538,6 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
 | 
			
		||||
REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl);
 | 
			
		||||
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
 | 
			
		||||
@ -34,6 +34,7 @@ REGISTER_UNARY_TI_DISPATCH(sinc);
 | 
			
		||||
REGISTER_UNARY_TI_DISPATCH(sinh);
 | 
			
		||||
REGISTER_UNARY_TI_DISPATCH(cosh);
 | 
			
		||||
REGISTER_UNARY_TI_DISPATCH(tanh);
 | 
			
		||||
REGISTER_UNARY_TI_DISPATCH(angle);
 | 
			
		||||
REGISTER_UNARY_TI_DISPATCH(abs);
 | 
			
		||||
REGISTER_UNARY_TI_DISPATCH(sin);
 | 
			
		||||
REGISTER_UNARY_TI_DISPATCH(cos);
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,6 @@
 | 
			
		||||
#include <ATen/ops/_copy_from_and_resize.h>
 | 
			
		||||
#include <ATen/ops/acos_native.h>
 | 
			
		||||
#include <ATen/ops/acosh_native.h>
 | 
			
		||||
#include <ATen/ops/angle_native.h>
 | 
			
		||||
#include <ATen/ops/asin_native.h>
 | 
			
		||||
#include <ATen/ops/asinh_native.h>
 | 
			
		||||
#include <ATen/ops/atan_native.h>
 | 
			
		||||
@ -204,23 +203,6 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
 | 
			
		||||
  return output;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
 | 
			
		||||
  mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
 | 
			
		||||
    auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
 | 
			
		||||
    auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
 | 
			
		||||
    return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
 | 
			
		||||
  });
 | 
			
		||||
  return output;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Tensor angle_mps(const Tensor& self) {
 | 
			
		||||
  const auto float_type = c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)
 | 
			
		||||
      ? c10::typeMetaToScalarType(c10::get_default_dtype())
 | 
			
		||||
      : c10::toRealValueType(self.scalar_type());
 | 
			
		||||
  Tensor result = at::empty({0}, self.options().dtype(float_type));
 | 
			
		||||
  return angle_out_mps(self, result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) {
 | 
			
		||||
  TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types");
 | 
			
		||||
  mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
 | 
			
		||||
 | 
			
		||||
@ -403,16 +403,14 @@
 | 
			
		||||
  device_check: NoCheck   # TensorIterator
 | 
			
		||||
  variants: function, method
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CPU, CUDA: angle
 | 
			
		||||
    MPS: angle_mps
 | 
			
		||||
    CPU, CUDA, MPS: angle
 | 
			
		||||
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr
 | 
			
		||||
  tags: pointwise
 | 
			
		||||
 | 
			
		||||
- func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
 | 
			
		||||
  device_check: NoCheck   # TensorIterator
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CPU, CUDA: angle_out
 | 
			
		||||
    MPS: angle_out_mps
 | 
			
		||||
    CPU, CUDA, MPS: angle_out
 | 
			
		||||
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr_out
 | 
			
		||||
  tags: pointwise
 | 
			
		||||
 | 
			
		||||
@ -10042,8 +10040,7 @@
 | 
			
		||||
  structured: True
 | 
			
		||||
  structured_inherits: TensorIteratorBase
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CPU, CUDA: hypot_out
 | 
			
		||||
    MPS: hypot_out_mps
 | 
			
		||||
    CPU, CUDA, MPS: hypot_out
 | 
			
		||||
  tags: pointwise
 | 
			
		||||
 | 
			
		||||
- func: hypot(Tensor self, Tensor other) -> Tensor
 | 
			
		||||
@ -14157,16 +14154,10 @@
 | 
			
		||||
- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
 | 
			
		||||
  python_module: linalg
 | 
			
		||||
  variants: function
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CompositeImplicitAutograd: linalg_lu_factor
 | 
			
		||||
    MPS: linalg_lu_factor_mps
 | 
			
		||||
 | 
			
		||||
- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
 | 
			
		||||
  python_module: linalg
 | 
			
		||||
  variants: function
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CompositeImplicitAutograd: linalg_lu_factor_out
 | 
			
		||||
    MPS: linalg_lu_factor_out_mps
 | 
			
		||||
 | 
			
		||||
- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
 | 
			
		||||
  python_module: linalg
 | 
			
		||||
@ -14368,12 +14359,12 @@
 | 
			
		||||
  python_module: linalg
 | 
			
		||||
  variants: function
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CPU, CUDA: linalg_householder_product
 | 
			
		||||
    CPU, CUDA, MPS: linalg_householder_product
 | 
			
		||||
 | 
			
		||||
- func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
 | 
			
		||||
  python_module: linalg
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CPU, CUDA: linalg_householder_product_out
 | 
			
		||||
    CPU, CUDA, MPS: linalg_householder_product_out
 | 
			
		||||
 | 
			
		||||
- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
 | 
			
		||||
  python_module: linalg
 | 
			
		||||
 | 
			
		||||
@ -575,24 +575,9 @@ void spmm(
 | 
			
		||||
  cusparseOperation_t opB = transpose_B ? CUSPARSE_OPERATION_TRANSPOSE
 | 
			
		||||
                                        : CUSPARSE_OPERATION_NON_TRANSPOSE;
 | 
			
		||||
 | 
			
		||||
  // CUDA < 11.0 doesn't support 64-bit indices and doesn't raise an error about this
 | 
			
		||||
  // silently returning incorrect results
 | 
			
		||||
#if defined(USE_ROCM) && (ROCM_VERSION < 60300)
 | 
			
		||||
  auto mat1_32 = at::native::_sparse_csr_tensor_unsafe(
 | 
			
		||||
      mat1.crow_indices().to(kInt),
 | 
			
		||||
      mat1.col_indices().to(kInt),
 | 
			
		||||
      mat1.values(),
 | 
			
		||||
      mat1.sizes(),
 | 
			
		||||
      mat1.scalar_type(),
 | 
			
		||||
      mat1.layout(),
 | 
			
		||||
      mat1.device());
 | 
			
		||||
  auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1_32);
 | 
			
		||||
  auto algorithm = CUSPARSE_MM_ALG_DEFAULT;
 | 
			
		||||
#else // defined(USE_ROCM) && (ROCM_VERSION < 60300)
 | 
			
		||||
  // TODO: update this to support COO sparse layout
 | 
			
		||||
  auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1);
 | 
			
		||||
  auto algorithm = CUSPARSE_SPMM_CSR_ALG2;
 | 
			
		||||
#endif // defined(USE_ROCM) && (ROCM_VERSION < 60300)
 | 
			
		||||
 | 
			
		||||
  auto descB = at::cuda::sparse::CuSparseConstDnMatDescriptor(
 | 
			
		||||
      transpose_B ? mat2_->mT() : *mat2_);
 | 
			
		||||
 | 
			
		||||
@ -40,15 +40,7 @@
 | 
			
		||||
#include <thrust/iterator/discard_iterator.h>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
 | 
			
		||||
#define IS_CUSPARSE11_AVAILABLE() 1
 | 
			
		||||
#else
 | 
			
		||||
#define IS_CUSPARSE11_AVAILABLE() 0
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#if IS_CUSPARSE11_AVAILABLE()
 | 
			
		||||
#include <library_types.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace at::native {
 | 
			
		||||
 | 
			
		||||
@ -103,17 +95,9 @@ struct csrMatrixRef {
 | 
			
		||||
  int nnz_{0};
 | 
			
		||||
  std::vector<int> size_{};
 | 
			
		||||
 | 
			
		||||
  #if IS_CUSPARSE11_AVAILABLE()
 | 
			
		||||
    cusparseSpMatDescr_t description_{0};
 | 
			
		||||
  #else
 | 
			
		||||
    cusparseMatDescr_t description_{0};
 | 
			
		||||
  #endif
 | 
			
		||||
  cusparseSpMatDescr_t description_{0};
 | 
			
		||||
 | 
			
		||||
  csrMatrixRef() {
 | 
			
		||||
    #if !IS_CUSPARSE11_AVAILABLE()
 | 
			
		||||
      create_general_description_(description_);
 | 
			
		||||
    #endif
 | 
			
		||||
  }
 | 
			
		||||
  csrMatrixRef() = default;
 | 
			
		||||
 | 
			
		||||
  csrMatrixRef(
 | 
			
		||||
      int* csr_indices,
 | 
			
		||||
@ -126,7 +110,6 @@ struct csrMatrixRef {
 | 
			
		||||
        csr_values_{csr_values},
 | 
			
		||||
        nnz_{nnz},
 | 
			
		||||
        size_{size} {
 | 
			
		||||
    #if IS_CUSPARSE11_AVAILABLE()
 | 
			
		||||
      cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
 | 
			
		||||
      TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
 | 
			
		||||
        &description_,
 | 
			
		||||
@ -140,17 +123,10 @@ struct csrMatrixRef {
 | 
			
		||||
        CUSPARSE_INDEX_32I,
 | 
			
		||||
        CUSPARSE_INDEX_BASE_ZERO,
 | 
			
		||||
        cuda_data_type));
 | 
			
		||||
    #else
 | 
			
		||||
      create_general_description_(description_);
 | 
			
		||||
    #endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ~csrMatrixRef() {
 | 
			
		||||
    #if IS_CUSPARSE11_AVAILABLE()
 | 
			
		||||
      cusparseDestroySpMat(description_);
 | 
			
		||||
    #else
 | 
			
		||||
      cusparseDestroyMatDescr(description_);
 | 
			
		||||
    #endif
 | 
			
		||||
    cusparseDestroySpMat(description_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  int size(int index) const {
 | 
			
		||||
@ -196,8 +172,6 @@ struct csrOutput {
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#if IS_CUSPARSE11_AVAILABLE()
 | 
			
		||||
 | 
			
		||||
// RAII guard helps to support cuSparse 11 API for `A @ B` operation
 | 
			
		||||
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
 | 
			
		||||
template <class scalar_t>
 | 
			
		||||
@ -396,284 +370,6 @@ template struct CusparseMatrixMultiplyOp<float>;
 | 
			
		||||
 | 
			
		||||
template struct CusparseMatrixMultiplyOp<double>;
 | 
			
		||||
 | 
			
		||||
#else // if not IS_CUSPARSE11_AVAILABLE()
 | 
			
		||||
 | 
			
		||||
using DcsrMatrixRef = csrMatrixRef<double>;
 | 
			
		||||
using ScsrMatrixRef = csrMatrixRef<float>;
 | 
			
		||||
 | 
			
		||||
// RAII guard helps to support cuSparse 10 API for `A @ B` operation
 | 
			
		||||
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
 | 
			
		||||
template <class scalar_t>
 | 
			
		||||
struct CusparseMatrixMultiplyOp {
 | 
			
		||||
  csrOutput operator()(
 | 
			
		||||
      const csrMatrixRef<scalar_t>& lhs,
 | 
			
		||||
      const csrMatrixRef<scalar_t>& rhs,
 | 
			
		||||
      Tensor &output_values,
 | 
			
		||||
      Tensor &output_indices)
 | 
			
		||||
  {
 | 
			
		||||
    static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double.");
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Specializacion for `A @ B` operation for double values with cuSparse
 | 
			
		||||
template<> struct CusparseMatrixMultiplyOp<double> {
 | 
			
		||||
  csrgemm2Info_t gemm2Info_;
 | 
			
		||||
 | 
			
		||||
  CusparseMatrixMultiplyOp() {
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
 | 
			
		||||
  }
 | 
			
		||||
  ~CusparseMatrixMultiplyOp() {
 | 
			
		||||
    cusparseDestroyCsrgemm2Info(gemm2Info_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  csrOutput operator ()(
 | 
			
		||||
      const DcsrMatrixRef& lhs,
 | 
			
		||||
      const DcsrMatrixRef& rhs,
 | 
			
		||||
      Tensor &output_values,
 | 
			
		||||
      Tensor &output_indices) {
 | 
			
		||||
    double alpha = 1.0;
 | 
			
		||||
    DcsrMatrixRef empty;
 | 
			
		||||
    return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  csrOutput Dgemm2(
 | 
			
		||||
      const DcsrMatrixRef& A,
 | 
			
		||||
      const DcsrMatrixRef& B,
 | 
			
		||||
      const DcsrMatrixRef& C,
 | 
			
		||||
      const double* alpha,
 | 
			
		||||
      const double* beta,
 | 
			
		||||
      Tensor &output_values,
 | 
			
		||||
      Tensor &output_indices) {
 | 
			
		||||
    void* buffer_{nullptr};
 | 
			
		||||
    cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
 | 
			
		||||
 | 
			
		||||
    csrOutput out({A.size(0), B.size(1)});
 | 
			
		||||
    int innerSize = confirm_mult_size(A.size_, B.size_);
 | 
			
		||||
    out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
 | 
			
		||||
 | 
			
		||||
    // Compute needed buffer size
 | 
			
		||||
    size_t new_bubber_sz;
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt(
 | 
			
		||||
        cusparseHandle_,
 | 
			
		||||
        out.size(0),
 | 
			
		||||
        out.size(1),
 | 
			
		||||
        innerSize,
 | 
			
		||||
        alpha,
 | 
			
		||||
        A.description_,
 | 
			
		||||
        A.nnz_,
 | 
			
		||||
        A.csr_pointers_,
 | 
			
		||||
        A.csr_indices_,
 | 
			
		||||
        B.description_,
 | 
			
		||||
        B.nnz_,
 | 
			
		||||
        B.csr_pointers_,
 | 
			
		||||
        B.csr_indices_,
 | 
			
		||||
        beta,
 | 
			
		||||
        C.description_,
 | 
			
		||||
        C.nnz_,
 | 
			
		||||
        C.csr_pointers_,
 | 
			
		||||
        C.csr_indices_,
 | 
			
		||||
        gemm2Info_,
 | 
			
		||||
        &new_bubber_sz));
 | 
			
		||||
 | 
			
		||||
    // (Re)allocate buffer if needed
 | 
			
		||||
    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
 | 
			
		||||
    at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
 | 
			
		||||
    buffer_ = data_ptr.get();
 | 
			
		||||
 | 
			
		||||
    // Find the resulting non-zero pattern.
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
 | 
			
		||||
        cusparseHandle_,
 | 
			
		||||
        out.size(0),
 | 
			
		||||
        out.size(1),
 | 
			
		||||
        innerSize,
 | 
			
		||||
        A.description_,
 | 
			
		||||
        A.nnz_,
 | 
			
		||||
        A.csr_pointers_,
 | 
			
		||||
        A.csr_indices_,
 | 
			
		||||
        B.description_,
 | 
			
		||||
        B.nnz_,
 | 
			
		||||
        B.csr_pointers_,
 | 
			
		||||
        B.csr_indices_,
 | 
			
		||||
        C.description_,
 | 
			
		||||
        C.nnz_,
 | 
			
		||||
        C.csr_pointers_,
 | 
			
		||||
        C.csr_indices_,
 | 
			
		||||
        out.description_,
 | 
			
		||||
        out.csr_pointers_.data_ptr<int>(),
 | 
			
		||||
        &out.nnz_,
 | 
			
		||||
        gemm2Info_,
 | 
			
		||||
        buffer_));
 | 
			
		||||
 | 
			
		||||
    out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
 | 
			
		||||
    out.csr_values_ = at::empty({out.nnz_}, output_values.options());
 | 
			
		||||
 | 
			
		||||
    // Perform the gemm2 operation for doubles
 | 
			
		||||
    // out = alpha ∗ A ∗ B + beta ∗ C
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2(
 | 
			
		||||
        cusparseHandle_,
 | 
			
		||||
        out.size(0),
 | 
			
		||||
        out.size(1),
 | 
			
		||||
        innerSize,
 | 
			
		||||
        alpha,
 | 
			
		||||
        A.description_,
 | 
			
		||||
        A.nnz_,
 | 
			
		||||
        A.csr_values_,
 | 
			
		||||
        A.csr_pointers_,
 | 
			
		||||
        A.csr_indices_,
 | 
			
		||||
        B.description_,
 | 
			
		||||
        B.nnz_,
 | 
			
		||||
        B.csr_values_,
 | 
			
		||||
        B.csr_pointers_,
 | 
			
		||||
        B.csr_indices_,
 | 
			
		||||
        beta,
 | 
			
		||||
        C.description_,
 | 
			
		||||
        C.nnz_,
 | 
			
		||||
        C.csr_values_,
 | 
			
		||||
        C.csr_pointers_,
 | 
			
		||||
        C.csr_indices_,
 | 
			
		||||
        out.description_,
 | 
			
		||||
        out.csr_values_.data_ptr<double>(),
 | 
			
		||||
        out.csr_pointers_.data_ptr<int>(),
 | 
			
		||||
        out.csr_indices_.data_ptr<int>(),
 | 
			
		||||
        gemm2Info_,
 | 
			
		||||
        buffer_));
 | 
			
		||||
    return out;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Specializacion for `A @ B` operation for float values with cuSparse
 | 
			
		||||
template<> struct CusparseMatrixMultiplyOp<float> {
 | 
			
		||||
  csrgemm2Info_t gemm2Info_;
 | 
			
		||||
 | 
			
		||||
  CusparseMatrixMultiplyOp() {
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
 | 
			
		||||
 | 
			
		||||
  }
 | 
			
		||||
  ~CusparseMatrixMultiplyOp() {
 | 
			
		||||
    cusparseDestroyCsrgemm2Info(gemm2Info_);
 | 
			
		||||
  }
 | 
			
		||||
  csrOutput operator()(
 | 
			
		||||
      const ScsrMatrixRef& lhs,
 | 
			
		||||
      const ScsrMatrixRef& rhs,
 | 
			
		||||
      Tensor &output_values,
 | 
			
		||||
      Tensor &output_indices) {
 | 
			
		||||
    float alpha = 1.0;
 | 
			
		||||
    ScsrMatrixRef empty;
 | 
			
		||||
    return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  csrOutput Sgemm2(
 | 
			
		||||
      const ScsrMatrixRef& A,
 | 
			
		||||
      const ScsrMatrixRef& B,
 | 
			
		||||
      const ScsrMatrixRef& C,
 | 
			
		||||
      const float* alpha,
 | 
			
		||||
      const float* beta,
 | 
			
		||||
      Tensor &output_values,
 | 
			
		||||
      Tensor &output_indices) {
 | 
			
		||||
    void* buffer_{nullptr};
 | 
			
		||||
    cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
 | 
			
		||||
 | 
			
		||||
    csrOutput out({A.size(0), B.size(1)});
 | 
			
		||||
 | 
			
		||||
    int innerSize = confirm_mult_size(A.size_, B.size_);
 | 
			
		||||
 | 
			
		||||
    out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
 | 
			
		||||
 | 
			
		||||
    // Compute needed buffer size
 | 
			
		||||
    size_t new_bubber_sz;
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt(
 | 
			
		||||
        cusparseHandle_,
 | 
			
		||||
        out.size(0),
 | 
			
		||||
        out.size(1),
 | 
			
		||||
        innerSize,
 | 
			
		||||
        alpha,
 | 
			
		||||
        A.description_,
 | 
			
		||||
        A.nnz_,
 | 
			
		||||
        A.csr_pointers_,
 | 
			
		||||
        A.csr_indices_,
 | 
			
		||||
        B.description_,
 | 
			
		||||
        B.nnz_,
 | 
			
		||||
        B.csr_pointers_,
 | 
			
		||||
        B.csr_indices_,
 | 
			
		||||
        beta,
 | 
			
		||||
        C.description_,
 | 
			
		||||
        C.nnz_,
 | 
			
		||||
        C.csr_pointers_,
 | 
			
		||||
        C.csr_indices_,
 | 
			
		||||
        gemm2Info_,
 | 
			
		||||
        &new_bubber_sz));
 | 
			
		||||
 | 
			
		||||
    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
 | 
			
		||||
    at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
 | 
			
		||||
    buffer_ = data_ptr.get();
 | 
			
		||||
 | 
			
		||||
    // Find the resulting non-zero pattern.
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
 | 
			
		||||
        cusparseHandle_,
 | 
			
		||||
        out.size(0),
 | 
			
		||||
        out.size(1),
 | 
			
		||||
        innerSize,
 | 
			
		||||
        A.description_,
 | 
			
		||||
        A.nnz_,
 | 
			
		||||
        A.csr_pointers_,
 | 
			
		||||
        A.csr_indices_,
 | 
			
		||||
        B.description_,
 | 
			
		||||
        B.nnz_,
 | 
			
		||||
        B.csr_pointers_,
 | 
			
		||||
        B.csr_indices_,
 | 
			
		||||
        C.description_,
 | 
			
		||||
        C.nnz_,
 | 
			
		||||
        C.csr_pointers_,
 | 
			
		||||
        C.csr_indices_,
 | 
			
		||||
        out.description_,
 | 
			
		||||
        out.csr_pointers_.data_ptr<int>(),
 | 
			
		||||
        &out.nnz_,
 | 
			
		||||
        gemm2Info_,
 | 
			
		||||
        buffer_));
 | 
			
		||||
 | 
			
		||||
    out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
 | 
			
		||||
    out.csr_values_ = at::empty({out.nnz_}, output_values.options());
 | 
			
		||||
 | 
			
		||||
    // Perform the gemm2 operation for doubles
 | 
			
		||||
    // out = alpha ∗ A ∗ B + beta ∗ C
 | 
			
		||||
    TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2(
 | 
			
		||||
        cusparseHandle_,
 | 
			
		||||
        out.size(0),
 | 
			
		||||
        out.size(1),
 | 
			
		||||
        innerSize,
 | 
			
		||||
        alpha,
 | 
			
		||||
        A.description_,
 | 
			
		||||
        A.nnz_,
 | 
			
		||||
        A.csr_values_,
 | 
			
		||||
        A.csr_pointers_,
 | 
			
		||||
        A.csr_indices_,
 | 
			
		||||
        B.description_,
 | 
			
		||||
        B.nnz_,
 | 
			
		||||
        B.csr_values_,
 | 
			
		||||
        B.csr_pointers_,
 | 
			
		||||
        B.csr_indices_,
 | 
			
		||||
        beta,
 | 
			
		||||
        C.description_,
 | 
			
		||||
        C.nnz_,
 | 
			
		||||
        C.csr_values_,
 | 
			
		||||
        C.csr_pointers_,
 | 
			
		||||
        C.csr_indices_,
 | 
			
		||||
        out.description_,
 | 
			
		||||
        out.csr_values_.data_ptr<float>(),
 | 
			
		||||
        out.csr_pointers_.data_ptr<int>(),
 | 
			
		||||
        out.csr_indices_.data_ptr<int>(),
 | 
			
		||||
        gemm2Info_,
 | 
			
		||||
        buffer_));
 | 
			
		||||
    return out;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#endif // IS_CUSPARSE11_AVAILABLE()
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
void sparse_sparse_matmul_cuda_kernel(
 | 
			
		||||
    Tensor& result,
 | 
			
		||||
@ -815,19 +511,15 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
 | 
			
		||||
  auto output = at::native::empty_like(mat1_);
 | 
			
		||||
  output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
 | 
			
		||||
 | 
			
		||||
#if IS_CUSPARSE11_AVAILABLE() && !defined(USE_ROCM)
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
 | 
			
		||||
      sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
 | 
			
		||||
  });
 | 
			
		||||
#elif IS_CUSPARSE11_AVAILABLE() && defined(USE_ROCM)
 | 
			
		||||
#else
 | 
			
		||||
  // ROCm does not support half and bfloat16 types for sparse_matmul
 | 
			
		||||
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
 | 
			
		||||
      sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
 | 
			
		||||
  });
 | 
			
		||||
#else
 | 
			
		||||
  AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
 | 
			
		||||
    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
 | 
			
		||||
  });
 | 
			
		||||
#endif
 | 
			
		||||
  return output;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -33,7 +33,7 @@ using namespace mps;
 | 
			
		||||
#ifndef PYTORCH_JIT_COMPILE_SHADERS
 | 
			
		||||
static auto& lib = MetalShaderLibrary::getBundledLibrary();
 | 
			
		||||
#else
 | 
			
		||||
#include <ATen/native/mps/Mul_metallib.h>
 | 
			
		||||
#include <ATen/native/mps/SparseTensorMath_metallib.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
static Tensor& s_addmm_out_sparse_dense_mps(
 | 
			
		||||
@ -369,12 +369,7 @@ static SparseTensor& mul_out_dense_sparse_mps(
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (scalar_like) {
 | 
			
		||||
    auto scalar = dense;
 | 
			
		||||
    if (dense.numel() == 1 && dense.dim() > 0) {
 | 
			
		||||
      scalar = dense.view({});
 | 
			
		||||
    }
 | 
			
		||||
    scalar = scalar.to(values.options());
 | 
			
		||||
    auto out_vals = values.mul(scalar);
 | 
			
		||||
    auto out_vals = values.mul(dense.to(values.options()));
 | 
			
		||||
    if (out.scalar_type() != commonDtype) {
 | 
			
		||||
      out_vals = out_vals.to(out.scalar_type());
 | 
			
		||||
    }
 | 
			
		||||
@ -508,14 +503,14 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
 | 
			
		||||
  const auto device = r_.device();
 | 
			
		||||
  auto stream = getCurrentMPSStream();
 | 
			
		||||
 | 
			
		||||
  auto lhs_indices = lhs._indices();
 | 
			
		||||
  auto rhs_indices = rhs._indices();
 | 
			
		||||
  auto lhs_values  = lhs._values().to(commonDtype);
 | 
			
		||||
  auto rhs_values  = rhs._values().to(commonDtype);
 | 
			
		||||
  auto lhs_indices = lhs._indices().contiguous();
 | 
			
		||||
  auto rhs_indices = rhs._indices().contiguous();
 | 
			
		||||
  auto lhs_values  = lhs._values().to(commonDtype).contiguous();
 | 
			
		||||
  auto rhs_values  = rhs._values().to(commonDtype).contiguous();
 | 
			
		||||
 | 
			
		||||
  // Flatten sparse indices to keys
 | 
			
		||||
  auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes());
 | 
			
		||||
  auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes());
 | 
			
		||||
  auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i));
 | 
			
		||||
  auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i));
 | 
			
		||||
 | 
			
		||||
  // Intersect sorted keys (search the shorter in the longer)
 | 
			
		||||
  const bool A_is_lhs = (lhs_nnz <= rhs_nnz);
 | 
			
		||||
@ -546,35 +541,54 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
 | 
			
		||||
  auto out_indices = at::empty({ndim_i, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong));
 | 
			
		||||
  auto lhs_match = outA_idx.narrow(0, 0, M);
 | 
			
		||||
  auto rhs_match = outB_idx.narrow(0, 0, M);
 | 
			
		||||
  auto out_val_sizes = lhs_values.sizes().vec();
 | 
			
		||||
  out_val_sizes[0] = static_cast<int64_t>(M);
 | 
			
		||||
  auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec();
 | 
			
		||||
  int64_t cols64 = 1;
 | 
			
		||||
  for (auto s : dense_sizes_vec) cols64 *= s;
 | 
			
		||||
  const uint32_t cols = static_cast<uint32_t>(std::max<int64_t>(cols64, 1));
 | 
			
		||||
 | 
			
		||||
  auto to2d = [&](Tensor t, int64_t nnz) -> Tensor {
 | 
			
		||||
    const int64_t t_cols = t.numel() / nnz;
 | 
			
		||||
    if (t_cols == cols64) {
 | 
			
		||||
      return t.view({nnz, cols64});
 | 
			
		||||
    }
 | 
			
		||||
    return t.view({nnz, 1}).expand({nnz, cols64}).contiguous();
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // make both sides 2d [nnz, cols] buffers so the kernel can index it
 | 
			
		||||
  auto lhs_vals2d = to2d(lhs_values, lhs_nnz);
 | 
			
		||||
  auto rhs_vals2d = to2d(rhs_values, rhs_nnz);
 | 
			
		||||
 | 
			
		||||
  std::vector<int64_t> out_val_sizes;
 | 
			
		||||
  out_val_sizes.reserve(1 + dense_sizes_vec.size());
 | 
			
		||||
  out_val_sizes.push_back(static_cast<int64_t>(M));
 | 
			
		||||
  out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end());
 | 
			
		||||
  auto out_values = at::empty(out_val_sizes, lhs_values.options());
 | 
			
		||||
 | 
			
		||||
  const uint32_t cols = static_cast<uint32_t>(
 | 
			
		||||
      lhs_values.numel() / std::max<int64_t>(1, lhs_nnz));
 | 
			
		||||
  if (M > 0) {
 | 
			
		||||
    dispatch_sync_with_rethrow(stream->queue(), ^() {
 | 
			
		||||
      @autoreleasepool {
 | 
			
		||||
        auto pso = lib.getPipelineStateForFunc(
 | 
			
		||||
            "fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
 | 
			
		||||
        auto enc = stream->commandEncoder();
 | 
			
		||||
        [enc setComputePipelineState:pso];
 | 
			
		||||
 | 
			
		||||
  dispatch_sync_with_rethrow(stream->queue(), ^() {
 | 
			
		||||
    @autoreleasepool {
 | 
			
		||||
      auto pso = lib.getPipelineStateForFunc(
 | 
			
		||||
          "fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
 | 
			
		||||
      auto enc = stream->commandEncoder();
 | 
			
		||||
      [enc setComputePipelineState:pso];
 | 
			
		||||
        const uint32_t tew = pso.threadExecutionWidth;
 | 
			
		||||
        const uint32_t gridW = std::max<uint32_t>(cols, 1u);
 | 
			
		||||
        const uint32_t tgW = std::min(gridW, tew);
 | 
			
		||||
        MTLSize grid = MTLSizeMake(gridW, 1, M);
 | 
			
		||||
        MTLSize tgs  = MTLSizeMake(tgW, 1, 1);
 | 
			
		||||
 | 
			
		||||
      const uint32_t tew  = pso.threadExecutionWidth;
 | 
			
		||||
      uint32_t tgW = std::min(cols, tew);
 | 
			
		||||
      MTLSize grid = MTLSizeMake(cols, 1, M);
 | 
			
		||||
      MTLSize tgs  = MTLSizeMake(tgW, 1, 1);
 | 
			
		||||
 | 
			
		||||
      mtl_setArgs(enc,
 | 
			
		||||
                  lhs_values, rhs_values,
 | 
			
		||||
                  lhs_match, rhs_match,
 | 
			
		||||
                  lhs_indices, out_indices,
 | 
			
		||||
                  out_values,
 | 
			
		||||
                  std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
 | 
			
		||||
                  std::array<uint32_t, 2>{M, cols});
 | 
			
		||||
      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
 | 
			
		||||
    }
 | 
			
		||||
  });
 | 
			
		||||
        mtl_setArgs(enc,
 | 
			
		||||
                    lhs_vals2d, rhs_vals2d,
 | 
			
		||||
                    lhs_match, rhs_match,
 | 
			
		||||
                    lhs_indices, out_indices,
 | 
			
		||||
                    out_values,
 | 
			
		||||
                    std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
 | 
			
		||||
                    std::array<uint32_t, 2>{M, cols});
 | 
			
		||||
        [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
 | 
			
		||||
      }
 | 
			
		||||
    });
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (r_.scalar_type() != commonDtype) {
 | 
			
		||||
    out_values = out_values.to(r_.scalar_type());
 | 
			
		||||
 | 
			
		||||
@ -62,7 +62,6 @@ kernel void build_row_ptr_from_sorted_rows_by_batch(
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
kernel void spmm_bmm_coo_rows_grouped(
 | 
			
		||||
    device const long*   rows      [[buffer(0)]],
 | 
			
		||||
    device const long*   cols      [[buffer(1)]],
 | 
			
		||||
    device const T*      vals      [[buffer(2)]],
 | 
			
		||||
    device const T*      dense     [[buffer(3)]],
 | 
			
		||||
@ -73,7 +72,6 @@ kernel void spmm_bmm_coo_rows_grouped(
 | 
			
		||||
    uint3                ltid      [[thread_position_in_threadgroup]],
 | 
			
		||||
    uint3                tptg      [[threads_per_threadgroup]])
 | 
			
		||||
{
 | 
			
		||||
  const uint B = dims.x;
 | 
			
		||||
  const uint I = dims.y;
 | 
			
		||||
  const uint J = dims.z;
 | 
			
		||||
  const uint K = dims.w;
 | 
			
		||||
@ -197,9 +195,9 @@ kernel void fused_gather_mul_kernel(
 | 
			
		||||
    const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
 | 
			
		||||
    const ulong offO = (ulong)k  * (ulong)view_cols + (ulong)col;
 | 
			
		||||
 | 
			
		||||
    const float a = (float)lhs_vals[offL];
 | 
			
		||||
    const float b = (float)rhs_vals[offR];
 | 
			
		||||
    out_vals[offO] = (T)(a * b);
 | 
			
		||||
    const auto a = static_cast<accum_t<T>>(lhs_vals[offL]);
 | 
			
		||||
    const auto b = static_cast<accum_t<T>>(rhs_vals[offR]);
 | 
			
		||||
    out_vals[offO] = static_cast<T>(mul(a, b));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // One thread per match copies the indices column
 | 
			
		||||
@ -321,7 +319,6 @@ INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);
 | 
			
		||||
#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE)                         \
 | 
			
		||||
  template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void    \
 | 
			
		||||
  spmm_bmm_coo_rows_grouped<DTYPE>(                                          \
 | 
			
		||||
      device const long*   rows      [[buffer(0)]],                          \
 | 
			
		||||
      device const long*   cols      [[buffer(1)]],                          \
 | 
			
		||||
      device const DTYPE*  vals      [[buffer(2)]],                          \
 | 
			
		||||
      device const DTYPE*  dense     [[buffer(3)]],                          \
 | 
			
		||||
@ -22,6 +22,7 @@
 | 
			
		||||
#else
 | 
			
		||||
#include <ATen/ops/empty.h>
 | 
			
		||||
#include <ATen/ops/empty_like.h>
 | 
			
		||||
#include <ATen/ops/zeros_like.h>
 | 
			
		||||
#include <ATen/ops/reshape.h>
 | 
			
		||||
#include <ATen/ops/scalar_tensor.h>
 | 
			
		||||
#include <ATen/ops/sum.h>
 | 
			
		||||
@ -42,7 +43,6 @@ C10_DIAGNOSTIC_POP()
 | 
			
		||||
#include <static_switch.h>
 | 
			
		||||
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
 | 
			
		||||
namespace FLASH_NAMESPACE {
 | 
			
		||||
@ -417,6 +417,26 @@ mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head
 | 
			
		||||
    const int head_size_og = sizes[3];
 | 
			
		||||
    const int seqlen_k = k.size(1);
 | 
			
		||||
    const int num_heads_k = k.size(2);
 | 
			
		||||
 | 
			
		||||
    if (batch_size == 0) {
 | 
			
		||||
        auto opts = q.options();
 | 
			
		||||
        at::Tensor out = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
 | 
			
		||||
        at::Tensor q_padded = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
 | 
			
		||||
        at::Tensor k_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
 | 
			
		||||
        at::Tensor v_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
 | 
			
		||||
        at::Tensor softmax_lse = at::empty({0, num_heads, seqlen_q}, opts.dtype(at::kFloat));
 | 
			
		||||
        at::Tensor rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
 | 
			
		||||
        at::Tensor _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
 | 
			
		||||
        at::Tensor p = at::empty({0}, opts);
 | 
			
		||||
        if (return_softmax) {
 | 
			
		||||
            auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
 | 
			
		||||
            const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
 | 
			
		||||
            const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 | 
			
		||||
            p = at::empty({0, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
 | 
			
		||||
        }
 | 
			
		||||
        return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), _unused, std::move(p)};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
 | 
			
		||||
    TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
 | 
			
		||||
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
 | 
			
		||||
@ -547,7 +567,7 @@ mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head
 | 
			
		||||
        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
 | 
			
		||||
        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
 | 
			
		||||
    }
 | 
			
		||||
    return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
 | 
			
		||||
    return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), std::move(_unused), std::move(p)};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 | 
			
		||||
@ -852,7 +872,6 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
 | 
			
		||||
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
 | 
			
		||||
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
 | 
			
		||||
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
 | 
			
		||||
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
 | 
			
		||||
 | 
			
		||||
    const auto sizes = q.sizes();
 | 
			
		||||
 | 
			
		||||
@ -863,6 +882,20 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
 | 
			
		||||
    const int head_size = sizes[3];
 | 
			
		||||
    const int seqlen_k = k.size(1);
 | 
			
		||||
    const int num_heads_k = k.size(2);
 | 
			
		||||
 | 
			
		||||
    if (batch_size == 0) {
 | 
			
		||||
        auto opts = q.options();
 | 
			
		||||
        at::Tensor dq = at::empty_like(q);
 | 
			
		||||
        at::Tensor dk = at::empty_like(k);
 | 
			
		||||
        at::Tensor dv = at::empty_like(v);
 | 
			
		||||
        auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
 | 
			
		||||
        const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
 | 
			
		||||
        at::Tensor softmax_d = at::empty({0, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
 | 
			
		||||
        return {dq, dk, dv, softmax_d};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
 | 
			
		||||
 | 
			
		||||
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
 | 
			
		||||
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
 | 
			
		||||
    TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
 | 
			
		||||
 | 
			
		||||
@ -76,14 +76,21 @@ bool priority_order_init_ = false;
 | 
			
		||||
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
 | 
			
		||||
// Needs to be kept in-sync with test_fused_chocie in test_transformers.py
 | 
			
		||||
bool check_prefer_cudnn_attention() {
 | 
			
		||||
  static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") != false;
 | 
			
		||||
  static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_DEPRIORITIZED") != true;
 | 
			
		||||
  if (!prefer_cudnn) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
#if (defined(CUDNN_VERSION) && (CUDNN_VERSION >= 90900))
 | 
			
		||||
  auto dprops = at::cuda::getCurrentDeviceProperties();
 | 
			
		||||
  auto major = dprops->major;
 | 
			
		||||
  return (major == 9 || major == 10) && !dprops->minor;
 | 
			
		||||
  try {
 | 
			
		||||
    auto dprops = at::cuda::getCurrentDeviceProperties();
 | 
			
		||||
    auto major = dprops->major;
 | 
			
		||||
    return (major == 9 || major == 10) && !dprops->minor;
 | 
			
		||||
  } catch (c10::Error const& e) {
 | 
			
		||||
#ifdef DEBUG
 | 
			
		||||
    TORCH_WARN("check_prefer_cudnn_attention() caught exception ", e.what());
 | 
			
		||||
#endif
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  return false;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
@ -37,6 +37,10 @@ TEST(SingletonOrSharedTypePtr, Comparison) {
 | 
			
		||||
 | 
			
		||||
  EXPECT_NE(empty, p);
 | 
			
		||||
  EXPECT_NE(p, p2);
 | 
			
		||||
 | 
			
		||||
  EXPECT_EQ(empty, empty);
 | 
			
		||||
  EXPECT_EQ(p, p);
 | 
			
		||||
  EXPECT_EQ(p2, p2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(SingletonOrSharedTypePtr, SingletonComparison) {
 | 
			
		||||
@ -47,6 +51,8 @@ TEST(SingletonOrSharedTypePtr, SingletonComparison) {
 | 
			
		||||
  c10::TypePtr type = c10::NoneType::get();
 | 
			
		||||
  EXPECT_NE(type, c10::StringType::get());
 | 
			
		||||
  EXPECT_NE(type, c10::DeviceObjType::get());
 | 
			
		||||
  EXPECT_EQ(type, type);
 | 
			
		||||
  EXPECT_EQ(type, c10::NoneType::get());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -526,6 +526,41 @@ namespace {
 | 
			
		||||
            [](const vec& v) { return v.expm1(); },
 | 
			
		||||
            createDefaultUnaryTestCase<vec>(TestSeed(), false, true));
 | 
			
		||||
    }
 | 
			
		||||
    TYPED_TEST(Exponents, ExpU20) {
 | 
			
		||||
        using vec = TypeParam;
 | 
			
		||||
        using VT = ValueType<TypeParam>;
 | 
			
		||||
        using UVT = UvalueType<TypeParam>;
 | 
			
		||||
 | 
			
		||||
        // Explicit edge values
 | 
			
		||||
        VT v_too_small = VT(-100.0); // much less than -87.3
 | 
			
		||||
        VT exp_too_small = std::exp(v_too_small);
 | 
			
		||||
        VT v_neg_edge = VT(-0x1.5d5e2ap+6f);   // just at the edge
 | 
			
		||||
        VT exp_neg_edge = std::exp(v_neg_edge);
 | 
			
		||||
        VT v_zero = VT(0.0);         // middle, normal case
 | 
			
		||||
        VT exp_zero = std::exp(v_zero);
 | 
			
		||||
        VT v_pos_edge = VT(0x1.5d5e2ap+6f);    // just at the edge
 | 
			
		||||
        VT exp_pos_edge = std::exp(v_pos_edge);
 | 
			
		||||
        VT v_too_large = VT(100.0);  // much more than 87.3
 | 
			
		||||
        VT exp_too_large = std::exp(v_too_large);
 | 
			
		||||
 | 
			
		||||
        auto test_case = TestingCase<vec>::getBuilder()
 | 
			
		||||
            // Randoms in normal range, but the .addCustom() below guarantees we hit the special/fallback cases
 | 
			
		||||
            .addDomain(CheckWithinDomains<UVT>{{{-100, 100}}, false, getDefaultTolerance<UVT>()})
 | 
			
		||||
            .addCustom({ {v_too_small}, exp_too_small })
 | 
			
		||||
            .addCustom({ {v_neg_edge}, exp_neg_edge })
 | 
			
		||||
            .addCustom({ {v_zero}, exp_zero })
 | 
			
		||||
            .addCustom({ {v_pos_edge}, exp_pos_edge })
 | 
			
		||||
            .addCustom({ {v_too_large}, exp_too_large })
 | 
			
		||||
            .setTrialCount(65536)
 | 
			
		||||
            .setTestSeed(TestSeed());
 | 
			
		||||
 | 
			
		||||
        test_unary<vec>(
 | 
			
		||||
            NAME_INFO(exp_u20_edge_cases),
 | 
			
		||||
            RESOLVE_OVERLOAD(std::exp),
 | 
			
		||||
            [](const vec& v) { return v.exp_u20(); },
 | 
			
		||||
            test_case
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
    TYPED_TEST(ErrorFunctions, Erf) {
 | 
			
		||||
        using vec = TypeParam;
 | 
			
		||||
        test_unary<vec>(
 | 
			
		||||
 | 
			
		||||
@ -58,8 +58,7 @@ def list_benchmarks():
 | 
			
		||||
 | 
			
		||||
def run_benchmark(
 | 
			
		||||
    benchmark_name: str,
 | 
			
		||||
    should_visualize: bool = False,
 | 
			
		||||
    compile_mode: str = "max-autotune-no-cudagraphs",
 | 
			
		||||
    script_args,
 | 
			
		||||
):
 | 
			
		||||
    """Run a specific benchmark."""
 | 
			
		||||
    if benchmark_name not in BENCHMARK_REGISTRY:
 | 
			
		||||
@ -68,29 +67,29 @@ def run_benchmark(
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    print(f"Running benchmark: {benchmark_name}")
 | 
			
		||||
    print(f"Torch compile mode: {compile_mode}")
 | 
			
		||||
    print(f"Torch compile mode: {script_args.compile_mode}")
 | 
			
		||||
    print("=" * 60)
 | 
			
		||||
 | 
			
		||||
    benchmark_class = BENCHMARK_REGISTRY[benchmark_name]
 | 
			
		||||
    benchmark = benchmark_class(compile_mode)
 | 
			
		||||
    benchmark = benchmark_class(script_args)
 | 
			
		||||
    benchmark.benchmark()
 | 
			
		||||
    if should_visualize:
 | 
			
		||||
    if script_args.visualize:
 | 
			
		||||
        benchmark.visualize()
 | 
			
		||||
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_all_benchmarks(should_visualize: bool = False, compile_mode: str = "default"):
 | 
			
		||||
def run_all_benchmarks(script_args):
 | 
			
		||||
    """Run all available benchmarks."""
 | 
			
		||||
    print("Running all benchmarks...")
 | 
			
		||||
    print(f"Torch compile mode: {compile_mode}")
 | 
			
		||||
    print(f"Torch compile mode: {script_args.compile_mode}")
 | 
			
		||||
    print("=" * 60)
 | 
			
		||||
 | 
			
		||||
    for name, cls in BENCHMARK_REGISTRY.items():
 | 
			
		||||
        print(f"\n{'=' * 20} {name.upper()} {'=' * 20}")
 | 
			
		||||
        benchmark = cls(compile_mode)
 | 
			
		||||
        benchmark = cls(script_args)
 | 
			
		||||
        benchmark.benchmark()
 | 
			
		||||
        if should_visualize:
 | 
			
		||||
        if script_args.visualize:
 | 
			
		||||
            benchmark.visualize()
 | 
			
		||||
        print()
 | 
			
		||||
 | 
			
		||||
@ -137,6 +136,19 @@ Examples:
 | 
			
		||||
        help="Torch compile mode to use (default: default)",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tolerance",
 | 
			
		||||
        type=float,
 | 
			
		||||
        default=None,
 | 
			
		||||
        help="Tolerance for the accuracy check",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--exit-on-accuracy-failure",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        help="Whether to exit with an error message for accuracy failure",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    # Handle list option
 | 
			
		||||
@ -146,7 +158,7 @@ Examples:
 | 
			
		||||
 | 
			
		||||
    # Handle all option
 | 
			
		||||
    if args.all:
 | 
			
		||||
        run_all_benchmarks(args.visualize, args.compile_mode)
 | 
			
		||||
        run_all_benchmarks(args)
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    # Handle specific benchmarks
 | 
			
		||||
@ -157,7 +169,7 @@ Examples:
 | 
			
		||||
        sys.exit(1)
 | 
			
		||||
 | 
			
		||||
    for benchmark_name in args.benchmarks:
 | 
			
		||||
        run_benchmark(benchmark_name, args.visualize, args.compile_mode)
 | 
			
		||||
        run_benchmark(benchmark_name, args)
 | 
			
		||||
        print()  # Add spacing between benchmarks
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,8 +9,8 @@ import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CrossEntropyForward(BenchmarkKernel):
 | 
			
		||||
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
 | 
			
		||||
        super().__init__(compile_mode)
 | 
			
		||||
    def __init__(self, script_args):
 | 
			
		||||
        super().__init__(script_args)
 | 
			
		||||
        self.available_backends = ["eager", "compiled", "quack", "liger"]
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
 | 
			
		||||
@ -106,8 +106,8 @@ class CrossEntropyForward(BenchmarkKernel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CrossEntropyBackward(BenchmarkKernel):
 | 
			
		||||
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
 | 
			
		||||
        super().__init__(compile_mode)
 | 
			
		||||
    def __init__(self, script_args):
 | 
			
		||||
        super().__init__(script_args)
 | 
			
		||||
        self.available_backends = ["eager", "compiled", "quack", "liger"]
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
 | 
			
		||||
@ -194,8 +194,8 @@ class CrossEntropyBackward(BenchmarkKernel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SoftmaxForward(BenchmarkKernel):
 | 
			
		||||
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
 | 
			
		||||
        super().__init__(compile_mode)
 | 
			
		||||
    def __init__(self, script_args):
 | 
			
		||||
        super().__init__(script_args)
 | 
			
		||||
        self.available_backends = ["eager", "compiled", "quack", "liger"]
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
 | 
			
		||||
@ -259,8 +259,8 @@ class SoftmaxForward(BenchmarkKernel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SoftmaxBackward(BenchmarkKernel):
 | 
			
		||||
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
 | 
			
		||||
        super().__init__(compile_mode)
 | 
			
		||||
    def __init__(self, script_args):
 | 
			
		||||
        super().__init__(script_args)
 | 
			
		||||
        self.available_backends = ["eager", "compiled", "quack", "liger"]
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
 | 
			
		||||
@ -329,8 +329,8 @@ class SoftmaxBackward(BenchmarkKernel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RMSNormForward(BenchmarkKernel):
 | 
			
		||||
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
 | 
			
		||||
        super().__init__(compile_mode)
 | 
			
		||||
    def __init__(self, script_args):
 | 
			
		||||
        super().__init__(script_args)
 | 
			
		||||
        self.available_backends = ["eager", "compiled", "quack", "liger"]
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
 | 
			
		||||
@ -383,7 +383,22 @@ class RMSNormForward(BenchmarkKernel):
 | 
			
		||||
        from quack.rmsnorm import _rmsnorm_fwd
 | 
			
		||||
 | 
			
		||||
        x, w = args
 | 
			
		||||
        return lambda: _rmsnorm_fwd(x, w, eps=1e-6)
 | 
			
		||||
        y = torch.empty_like(x)
 | 
			
		||||
 | 
			
		||||
        def quack_fwd():
 | 
			
		||||
            _rmsnorm_fwd(
 | 
			
		||||
                x,
 | 
			
		||||
                w,
 | 
			
		||||
                out=y,
 | 
			
		||||
                bias=None,
 | 
			
		||||
                rstd=None,
 | 
			
		||||
                residual=None,
 | 
			
		||||
                residual_out=None,
 | 
			
		||||
                eps=1e-6,
 | 
			
		||||
            )
 | 
			
		||||
            return y
 | 
			
		||||
 | 
			
		||||
        return quack_fwd
 | 
			
		||||
 | 
			
		||||
    def liger(self, args, kwargs) -> Any:
 | 
			
		||||
        from liger_kernel.transformers.rms_norm import LigerRMSNorm
 | 
			
		||||
@ -404,9 +419,14 @@ class RMSNormForward(BenchmarkKernel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RMSNormBackward(BenchmarkKernel):
 | 
			
		||||
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
 | 
			
		||||
        super().__init__(compile_mode)
 | 
			
		||||
        self.available_backends = ["eager", "compiled", "quack", "liger"]
 | 
			
		||||
    def __init__(self, script_args):
 | 
			
		||||
        super().__init__(script_args)
 | 
			
		||||
        self.available_backends = [
 | 
			
		||||
            "eager",
 | 
			
		||||
            "compiled",
 | 
			
		||||
            "quack",
 | 
			
		||||
            "liger",
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
 | 
			
		||||
        # TODO: OOM for (32768, 65536) on h100
 | 
			
		||||
@ -454,8 +474,11 @@ class RMSNormBackward(BenchmarkKernel):
 | 
			
		||||
            y, [x, w], grad_outputs=dy, retain_graph=True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def compute_rstd(self, x, eps):
 | 
			
		||||
        return torch.rsqrt(torch.mean(x.float().square(), dim=-1, keepdim=True) + eps)
 | 
			
		||||
 | 
			
		||||
    def quack(self, args, kwargs=None) -> Any:
 | 
			
		||||
        from quack.rmsnorm import _rmsnorm_backward
 | 
			
		||||
        from quack.rmsnorm import _get_sm_count, _rmsnorm_bwd
 | 
			
		||||
 | 
			
		||||
        (
 | 
			
		||||
            x,
 | 
			
		||||
@ -463,15 +486,40 @@ class RMSNormBackward(BenchmarkKernel):
 | 
			
		||||
            dy,
 | 
			
		||||
        ) = args
 | 
			
		||||
        M, N = x.shape
 | 
			
		||||
        rstd = torch.randn(M, device="cuda", dtype=torch.float32)
 | 
			
		||||
        return lambda: _rmsnorm_backward(x, w, dy, rstd)
 | 
			
		||||
 | 
			
		||||
        rstd = self.compute_rstd(x, eps=1e-6)
 | 
			
		||||
        dx = torch.empty_like(x)
 | 
			
		||||
        sm_count = _get_sm_count(x.size(1), x.device)
 | 
			
		||||
        dw_partial = torch.empty(
 | 
			
		||||
            sm_count, x.size(1), device=x.device, dtype=torch.float32
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        def quack_bwd():
 | 
			
		||||
            _rmsnorm_bwd(
 | 
			
		||||
                x,
 | 
			
		||||
                w,
 | 
			
		||||
                dy,
 | 
			
		||||
                rstd,
 | 
			
		||||
                dx,
 | 
			
		||||
                dw_partial,
 | 
			
		||||
                db_partial=None,
 | 
			
		||||
                dresidual_out=None,
 | 
			
		||||
                dresidual=None,
 | 
			
		||||
                sm_count=sm_count,
 | 
			
		||||
            )
 | 
			
		||||
            dw = dw_partial.sum(dim=0).to(w.dtype)
 | 
			
		||||
            return dx, dw
 | 
			
		||||
 | 
			
		||||
        return quack_bwd
 | 
			
		||||
 | 
			
		||||
    def liger(self, args, kwargs=None) -> Any:
 | 
			
		||||
        from liger_kernel.transformers.rms_norm import LigerRMSNorm
 | 
			
		||||
 | 
			
		||||
        x, w, dy = args
 | 
			
		||||
        M, N = x.shape
 | 
			
		||||
        liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda()
 | 
			
		||||
        liger_rmsnorm = LigerRMSNorm(
 | 
			
		||||
            hidden_size=N, eps=1e-6, casting_mode="gemma"
 | 
			
		||||
        ).cuda()
 | 
			
		||||
        liger_rmsnorm.weight.data.copy_(w)
 | 
			
		||||
        y = liger_rmsnorm(x)
 | 
			
		||||
        return lambda: torch.autograd.grad(
 | 
			
		||||
@ -489,8 +537,8 @@ class RMSNormBackward(BenchmarkKernel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNormForward(BenchmarkKernel):
 | 
			
		||||
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
 | 
			
		||||
        super().__init__(compile_mode)
 | 
			
		||||
    def __init__(self, script_args):
 | 
			
		||||
        super().__init__(script_args)
 | 
			
		||||
        self.available_backends = ["eager", "compiled", "quack", "liger"]
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
 | 
			
		||||
@ -563,8 +611,8 @@ class LayerNormForward(BenchmarkKernel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNormBackward(BenchmarkKernel):
 | 
			
		||||
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
 | 
			
		||||
        super().__init__(compile_mode)
 | 
			
		||||
    def __init__(self, script_args):
 | 
			
		||||
        super().__init__(script_args)
 | 
			
		||||
        self.available_backends = ["eager", "compiled", "liger"]
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
 | 
			
		||||
@ -614,20 +662,31 @@ class LayerNormBackward(BenchmarkKernel):
 | 
			
		||||
            y, [x, w], grad_outputs=dy, retain_graph=True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def compute_mean_rstd(self, x, eps):
 | 
			
		||||
        x = x.float()
 | 
			
		||||
 | 
			
		||||
        var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
 | 
			
		||||
        rstd = torch.rsqrt(var + eps)
 | 
			
		||||
        return mean, rstd
 | 
			
		||||
 | 
			
		||||
    def liger(self, args, kwargs) -> Any:
 | 
			
		||||
        from liger_kernel.transformers.layer_norm import LigerLayerNorm
 | 
			
		||||
        """
 | 
			
		||||
        Call layer_norm_backward directly rather than calling
 | 
			
		||||
        liger_kernel.transformers.layer_norm.LigerLayerNorm and
 | 
			
		||||
        torch.autograd.grad.
 | 
			
		||||
 | 
			
		||||
        The latter fashion saves mean/rstd in x.dtype which can fail
 | 
			
		||||
        accuracy test. We call layer_norm_backward with fp32 mean and
 | 
			
		||||
        rstd.
 | 
			
		||||
        """
 | 
			
		||||
        from liger_kernel.ops.layer_norm import layer_norm_backward
 | 
			
		||||
 | 
			
		||||
        x, w, dy = args
 | 
			
		||||
        eps = 1e-6
 | 
			
		||||
        mean, rstd = self.compute_mean_rstd(x, eps)
 | 
			
		||||
        M, N = x.shape
 | 
			
		||||
        liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda()
 | 
			
		||||
        liger_layernorm.weight.data.copy_(w)
 | 
			
		||||
        liger_layernorm.bias.data.copy_(
 | 
			
		||||
            torch.zeros(N, device="cuda", dtype=torch.float32)
 | 
			
		||||
        )
 | 
			
		||||
        y = liger_layernorm(x)
 | 
			
		||||
        return lambda: torch.autograd.grad(
 | 
			
		||||
            y, [x, liger_layernorm.weight], grad_outputs=dy, retain_graph=True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return lambda: layer_norm_backward(dy, x, w, None, mean, rstd)[0:2]
 | 
			
		||||
 | 
			
		||||
    def benchmark(self):
 | 
			
		||||
        for M, N in self.get_shapes():
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,5 @@
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
@ -43,10 +44,11 @@ class Performance:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BenchmarkKernel:
 | 
			
		||||
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
 | 
			
		||||
    def __init__(self, script_args):
 | 
			
		||||
        self.script_args = script_args
 | 
			
		||||
        self.name = self.__class__.__name__
 | 
			
		||||
        self.available_backends: list[str] = []
 | 
			
		||||
        self.compile_mode: str = compile_mode
 | 
			
		||||
        self.compile_mode: str = script_args.compile_mode
 | 
			
		||||
 | 
			
		||||
        # mapping from backend to list of performance results
 | 
			
		||||
        self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list)
 | 
			
		||||
@ -106,14 +108,21 @@ class BenchmarkKernel:
 | 
			
		||||
            args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
 | 
			
		||||
            res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
 | 
			
		||||
        gold = res["eager"]
 | 
			
		||||
 | 
			
		||||
        tol = {}
 | 
			
		||||
        if self.script_args.tolerance:
 | 
			
		||||
            tol = {
 | 
			
		||||
                "atol": self.script_args.tolerance,
 | 
			
		||||
                "rtol": self.script_args.tolerance,
 | 
			
		||||
            }
 | 
			
		||||
        for backend in self.available_backends:
 | 
			
		||||
            if backend == "eager":
 | 
			
		||||
                continue
 | 
			
		||||
            try:
 | 
			
		||||
                torch.testing.assert_close(res[backend], gold)
 | 
			
		||||
                torch.testing.assert_close(res[backend], gold, **tol)
 | 
			
		||||
                for t, gold_t in zip(res[backend], gold):
 | 
			
		||||
                    if t.requires_grad:
 | 
			
		||||
                        torch.testing.assert_close(t.grad, gold_t.grad)
 | 
			
		||||
                        torch.testing.assert_close(t.grad, gold_t.grad, **tol)
 | 
			
		||||
                print(
 | 
			
		||||
                    f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel"
 | 
			
		||||
                )
 | 
			
		||||
@ -121,6 +130,9 @@ class BenchmarkKernel:
 | 
			
		||||
                print(
 | 
			
		||||
                    f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
 | 
			
		||||
                )
 | 
			
		||||
                if self.script_args.exit_on_accuracy_failure:
 | 
			
		||||
                    print("Exit right away since --exit-on-accuracy-failure is set")
 | 
			
		||||
                    sys.exit(1)
 | 
			
		||||
 | 
			
		||||
    def benchmark_single_shape(
 | 
			
		||||
        self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,8 @@
 | 
			
		||||
add_loop_eager,compile_time_instruction_count,3070000000,0.1
 | 
			
		||||
add_loop_eager,compile_time_instruction_count,3184000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
 | 
			
		||||
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
 | 
			
		||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
 | 
			
		||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
update_hint_regression,compile_time_instruction_count,1719000000,0.1
 | 
			
		||||
update_hint_regression,compile_time_instruction_count,1645000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
 | 
			
		||||
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
 | 
			
		||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
 | 
			
		||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
 | 
			
		||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
 | 
			
		||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
 | 
			
		||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
 | 
			
		||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
 | 
			
		||||
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
 | 
			
		||||
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
 | 
			
		||||
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
 | 
			
		||||
 | 
			
		||||
		
		
			
  | 
@ -43,6 +43,7 @@ tolerance:
 | 
			
		||||
    - doctr_reco_predictor
 | 
			
		||||
    - drq
 | 
			
		||||
    - phlippe_resnet
 | 
			
		||||
    - pytorch_CycleGAN_and_pix2pix
 | 
			
		||||
 | 
			
		||||
  higher_bf16:
 | 
			
		||||
    - doctr_reco_predictor
 | 
			
		||||
 | 
			
		||||
@ -44,21 +44,101 @@ PyTorch,div_,div__M1_N1_K1_cpu_dtype_onetorch.float32_dtype_twotorch.float32,sho
 | 
			
		||||
PyTorch,div_,div__M64_N64_K64_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.241161,0.000000
 | 
			
		||||
PyTorch,div_,div__M64_N64_K128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.852816,0.000000
 | 
			
		||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,57.006677,0.000000
 | 
			
		||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,88.167000,0.000000
 | 
			
		||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.519000,0.000000
 | 
			
		||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,55.606088,0.000000
 | 
			
		||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,86.551000,0.000000
 | 
			
		||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.864088,0.000000
 | 
			
		||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
 | 
			
		||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,71.641000,0.000000
 | 
			
		||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,83.073000,0.000000
 | 
			
		||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
 | 
			
		||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,67.570000,0.000000
 | 
			
		||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.895000,0.000000
 | 
			
		||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
 | 
			
		||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
 | 
			
		||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
 | 
			
		||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
 | 
			
		||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
 | 
			
		||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
 | 
			
		||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
 | 
			
		||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
 | 
			
		||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
 | 
			
		||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
 | 
			
		||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
 | 
			
		||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
 | 
			
		||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
 | 
			
		||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.0240000,0.000000
 | 
			
		||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
 | 
			
		||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
 | 
			
		||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
 | 
			
		||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
 | 
			
		||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
 | 
			
		||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
 | 
			
		||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
 | 
			
		||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
 | 
			
		||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
 | 
			
		||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
 | 
			
		||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
 | 
			
		||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
 | 
			
		||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
 | 
			
		||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
 | 
			
		||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
 | 
			
		||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
 | 
			
		||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
 | 
			
		||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
 | 
			
		||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
 | 
			
		||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
 | 
			
		||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
 | 
			
		||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
 | 
			
		||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
 | 
			
		||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
 | 
			
		||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
 | 
			
		||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
 | 
			
		||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
 | 
			
		||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
 | 
			
		||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
 | 
			
		||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
 | 
			
		||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
 | 
			
		||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
 | 
			
		||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
 | 
			
		||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
 | 
			
		||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
 | 
			
		||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
 | 
			
		||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
 | 
			
		||||
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
 | 
			
		||||
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
 | 
			
		||||
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
 | 
			
		||||
@ -71,6 +151,9 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
 | 
			
		||||
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
 | 
			
		||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
 | 
			
		||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
 | 
			
		||||
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
 | 
			
		||||
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
 | 
			
		||||
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
 | 
			
		||||
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
 | 
			
		||||
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
 | 
			
		||||
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000
 | 
			
		||||
 | 
			
		||||
		
		
			
  | 
@ -25,7 +25,7 @@ binary_configs_broadcast = op_bench.config_list(
 | 
			
		||||
    ],
 | 
			
		||||
    cross_product_configs={
 | 
			
		||||
        "device": ["cpu"],
 | 
			
		||||
        "dtype": [torch.float],
 | 
			
		||||
        "dtype": [torch.float, torch.bfloat16, torch.float64],
 | 
			
		||||
    },
 | 
			
		||||
    tags=["short"],
 | 
			
		||||
)
 | 
			
		||||
@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
 | 
			
		||||
    ],
 | 
			
		||||
    cross_product_configs={
 | 
			
		||||
        "device": ["cpu", "cuda"],
 | 
			
		||||
        "dtype_one": [torch.int32],
 | 
			
		||||
        "dtype_two": [torch.int32],
 | 
			
		||||
        "dtype_one": [torch.int32, torch.uint8],
 | 
			
		||||
        "dtype_two": [torch.int32, torch.uint8],
 | 
			
		||||
    },
 | 
			
		||||
    tags=["short"],
 | 
			
		||||
)
 | 
			
		||||
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
 | 
			
		||||
    N=[32, 64],
 | 
			
		||||
    K=[256, 512],
 | 
			
		||||
    device=["cpu", "cuda"],
 | 
			
		||||
    dtype_one=[torch.int8, torch.int32],
 | 
			
		||||
    dtype_two=[torch.int8, torch.int32],
 | 
			
		||||
    dtype_one=[torch.int8, torch.int32, torch.uint8],
 | 
			
		||||
    dtype_two=[torch.int8, torch.int32, torch.uint8],
 | 
			
		||||
    tags=["long"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -176,8 +176,8 @@ THIRD_PARTY_LIBS = {
 | 
			
		||||
    "omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
 | 
			
		||||
    "pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
 | 
			
		||||
    "psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
 | 
			
		||||
    "pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
 | 
			
		||||
    "pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
 | 
			
		||||
    "pthreadpool": ["fbsource//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
 | 
			
		||||
    "pthreadpool_header": ["fbsource//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
 | 
			
		||||
    "moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
 | 
			
		||||
    "pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
 | 
			
		||||
    "rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
 | 
			
		||||
@ -1729,8 +1729,10 @@ def define_buck_targets(
 | 
			
		||||
            "torch/csrc/jit/backends/backend_debug_info.cpp",
 | 
			
		||||
            "torch/csrc/jit/backends/backend_interface.cpp",
 | 
			
		||||
        ],
 | 
			
		||||
        compiler_flags = get_pt_compiler_flags(),
 | 
			
		||||
        fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
 | 
			
		||||
        compiler_flags = get_pt_compiler_flags() + select({
 | 
			
		||||
            "DEFAULT": [],
 | 
			
		||||
            "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
 | 
			
		||||
        }),
 | 
			
		||||
        # @lint-ignore BUCKLINT link_whole
 | 
			
		||||
        link_whole = True,
 | 
			
		||||
        linker_flags = get_no_as_needed_linker_flag(),
 | 
			
		||||
@ -2023,6 +2025,9 @@ def define_buck_targets(
 | 
			
		||||
                "ovr_config//os:android-x86_64": [
 | 
			
		||||
                    "-mssse3",
 | 
			
		||||
                ],
 | 
			
		||||
            }) + select({
 | 
			
		||||
                "DEFAULT": [],
 | 
			
		||||
                "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
 | 
			
		||||
            }),
 | 
			
		||||
            exported_preprocessor_flags = get_aten_preprocessor_flags(),
 | 
			
		||||
            exported_deps = [
 | 
			
		||||
 | 
			
		||||
@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [
 | 
			
		||||
    "torch/csrc/cuda/Stream.cpp",
 | 
			
		||||
    "torch/csrc/cuda/Graph.cpp",
 | 
			
		||||
    "torch/csrc/cuda/MemPool.cpp",
 | 
			
		||||
    "torch/csrc/cuda/GreenContext.cpp",
 | 
			
		||||
    "torch/csrc/cuda/shared/cudart.cpp",
 | 
			
		||||
    "torch/csrc/cuda/shared/nvtx.cpp",
 | 
			
		||||
    "torch/csrc/cuda/utils.cpp",
 | 
			
		||||
 | 
			
		||||
@ -13,7 +13,17 @@
 | 
			
		||||
namespace c10::CachingAllocator {
 | 
			
		||||
 | 
			
		||||
// "large" allocations may be packed in 20 MiB blocks
 | 
			
		||||
const size_t kLargeBuffer = 20971520;
 | 
			
		||||
constexpr size_t kLargeBuffer = 20971520;
 | 
			
		||||
// "small" allocations are packed in 2 MiB blocks
 | 
			
		||||
constexpr size_t kSmallBuffer = 2097152;
 | 
			
		||||
// all sizes are rounded to at least 512 bytes
 | 
			
		||||
constexpr size_t kMinBlockSize = 512;
 | 
			
		||||
// largest "small" allocation is 1 MiB
 | 
			
		||||
constexpr size_t kSmallSize = 1048576;
 | 
			
		||||
// allocations between 1 and 10 MiB may use kLargeBuffer
 | 
			
		||||
constexpr size_t kMinLargeAlloc = 10485760;
 | 
			
		||||
// round up large allocations to 2 MiB
 | 
			
		||||
constexpr size_t kRoundLarge = 2097152;
 | 
			
		||||
 | 
			
		||||
// A utility class for tokenizing allocator configuration strings into discrete
 | 
			
		||||
// parts. For example, the config string:
 | 
			
		||||
 | 
			
		||||
@ -223,7 +223,7 @@ inline DispatchKey backendToDispatchKey(Backend b) {
 | 
			
		||||
    case Backend::PrivateUse1:
 | 
			
		||||
      return DispatchKey::PrivateUse1;
 | 
			
		||||
    default:
 | 
			
		||||
      throw std::runtime_error("Unknown backend");
 | 
			
		||||
      TORCH_CHECK(false, "Unknown backend");
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -336,7 +336,7 @@ class C10_API Scalar {
 | 
			
		||||
    } else if (isBoolean()) {
 | 
			
		||||
      return ScalarType::Bool;
 | 
			
		||||
    } else {
 | 
			
		||||
      throw std::runtime_error("Unknown scalar type.");
 | 
			
		||||
      TORCH_CHECK(false, "Unknown scalar type.");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -228,7 +228,7 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
 | 
			
		||||
    case c10::ScalarType::Float4_e2m1fn_x2:
 | 
			
		||||
      return std::make_pair("float4_e2m1fn_x2", "");
 | 
			
		||||
    default:
 | 
			
		||||
      throw std::runtime_error("Unimplemented scalar type");
 | 
			
		||||
      TORCH_CHECK(false, "Unimplemented scalar type");
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -52,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
 | 
			
		||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
 | 
			
		||||
#undef DEFINE_CONSTANT
 | 
			
		||||
 | 
			
		||||
inline const char* toString(ScalarType t) {
 | 
			
		||||
#define DEFINE_CASE(_, name) \
 | 
			
		||||
  case ScalarType::name:     \
 | 
			
		||||
    return #name;
 | 
			
		||||
 | 
			
		||||
  switch (t) {
 | 
			
		||||
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
 | 
			
		||||
    default:
 | 
			
		||||
      return "UNKNOWN_SCALAR";
 | 
			
		||||
  }
 | 
			
		||||
#undef DEFINE_CASE
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline size_t elementSize(ScalarType t) {
 | 
			
		||||
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
 | 
			
		||||
  case ScalarType::name:                   \
 | 
			
		||||
@ -150,22 +137,6 @@ inline ScalarType toQIntType(ScalarType t) {
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline ScalarType toUnderlying(ScalarType t) {
 | 
			
		||||
  switch (t) {
 | 
			
		||||
    case ScalarType::QUInt8:
 | 
			
		||||
    case ScalarType::QUInt4x2:
 | 
			
		||||
      [[fallthrough]];
 | 
			
		||||
    case ScalarType::QUInt2x4:
 | 
			
		||||
      return ScalarType::Byte;
 | 
			
		||||
    case ScalarType::QInt8:
 | 
			
		||||
      return ScalarType::Char;
 | 
			
		||||
    case ScalarType::QInt32:
 | 
			
		||||
      return ScalarType::Int;
 | 
			
		||||
    default:
 | 
			
		||||
      return t;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline bool isSignedType(ScalarType t) {
 | 
			
		||||
#define CASE_ISSIGNED(name)     \
 | 
			
		||||
  case ScalarType::name:        \
 | 
			
		||||
@ -308,12 +279,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
 | 
			
		||||
 | 
			
		||||
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
 | 
			
		||||
 | 
			
		||||
inline std::ostream& operator<<(
 | 
			
		||||
    std::ostream& stream,
 | 
			
		||||
    at::ScalarType scalar_type) {
 | 
			
		||||
  return stream << toString(scalar_type);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns a pair of strings representing the names for each dtype.
 | 
			
		||||
// The returned pair is (name, legacy_name_if_applicable)
 | 
			
		||||
C10_API std::pair<std::string, std::string> getDtypeNames(
 | 
			
		||||
 | 
			
		||||
@ -87,9 +87,7 @@ bool ThreadPool::inThreadPool() const {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ThreadPool::run(std::function<void()> func) {
 | 
			
		||||
  if (threads_.empty()) {
 | 
			
		||||
    throw std::runtime_error("No threads to run a task");
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK(threads_.size() > 0, "No threads to run a task");
 | 
			
		||||
  std::unique_lock<std::mutex> lock(mutex_);
 | 
			
		||||
 | 
			
		||||
  // Set task and signal condition variable so that a worker thread will
 | 
			
		||||
 | 
			
		||||
@ -131,15 +131,6 @@ namespace Native {
 | 
			
		||||
 *                  notifyCaptureDestroy.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
constexpr size_t kMinBlockSize =
 | 
			
		||||
    512; // all sizes are rounded to at least 512 bytes
 | 
			
		||||
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
 | 
			
		||||
constexpr size_t kSmallBuffer =
 | 
			
		||||
    2097152; // "small" allocations are packed in 2 MiB blocks
 | 
			
		||||
constexpr size_t kMinLargeAlloc =
 | 
			
		||||
    10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
 | 
			
		||||
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
 | 
			
		||||
 | 
			
		||||
static char SHAREABLE_HANDLE_VERSION = 2;
 | 
			
		||||
enum ShareableHandleType : char {
 | 
			
		||||
  SHAREABLE_CUDA_MALLOC = 'c',
 | 
			
		||||
@ -4478,7 +4469,10 @@ struct BackendStaticInitializer {
 | 
			
		||||
        if (key == "backend") {
 | 
			
		||||
          tokenizer.checkToken(++i, ":");
 | 
			
		||||
          i++; // Move to the value after the colon
 | 
			
		||||
          if (tokenizer[i] == "cudaMallocAsync"
 | 
			
		||||
          // break up token to trick hipify
 | 
			
		||||
          if (tokenizer[i] ==
 | 
			
		||||
                  "c"
 | 
			
		||||
                  "udaMallocAsync"
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
              // convenience for ROCm users to allow either CUDA or HIP env var
 | 
			
		||||
              || tokenizer[i] == "hipMallocAsync"
 | 
			
		||||
 | 
			
		||||
@ -913,7 +913,9 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  std::string name() override {
 | 
			
		||||
    return "cudaMallocAsync";
 | 
			
		||||
    // break up token to trick hipify
 | 
			
		||||
    return "c"
 | 
			
		||||
           "udaMallocAsync";
 | 
			
		||||
  }
 | 
			
		||||
  void copy_data(void* dest, const void* src, std::size_t count) const final {
 | 
			
		||||
    C10_CUDA_CHECK(
 | 
			
		||||
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user