mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-03 23:45:05 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			revert-cpp
			...
			gh/jgong5/
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 09360d17ba | 
@ -150,7 +150,7 @@ function install_130 {
 | 
			
		||||
  CUDNN_VERSION=9.13.0.50
 | 
			
		||||
  echo "Installing CUDA 13.0 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
 | 
			
		||||
  # install CUDA 13.0 in the same container
 | 
			
		||||
  install_cuda 13.0.2 cuda_13.0.2_580.95.05_linux
 | 
			
		||||
  install_cuda 13.0.0 cuda_13.0.0_580.65.06_linux
 | 
			
		||||
 | 
			
		||||
  # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement
 | 
			
		||||
  install_cudnn 13 $CUDNN_VERSION
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ pip_install \
 | 
			
		||||
  transformers==4.36.2
 | 
			
		||||
 | 
			
		||||
pip_install coloredlogs packaging
 | 
			
		||||
pip_install onnxruntime==1.23.1
 | 
			
		||||
pip_install onnxruntime==1.23.0
 | 
			
		||||
pip_install onnxscript==0.5.4
 | 
			
		||||
 | 
			
		||||
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 | 
			
		||||
 | 
			
		||||
@ -334,12 +334,12 @@ sympy==1.13.3
 | 
			
		||||
#Pinned versions:
 | 
			
		||||
#test that import:
 | 
			
		||||
 | 
			
		||||
onnx==1.19.1
 | 
			
		||||
onnx==1.18.0
 | 
			
		||||
#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
 | 
			
		||||
#Pinned versions:
 | 
			
		||||
#test that import:
 | 
			
		||||
 | 
			
		||||
onnxscript==0.5.4
 | 
			
		||||
onnxscript==0.5.3
 | 
			
		||||
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
 | 
			
		||||
#Pinned versions:
 | 
			
		||||
#test that import:
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ dependencies = [
 | 
			
		||||
    "GitPython==3.1.45",
 | 
			
		||||
    "docker==7.1.0",
 | 
			
		||||
    "pytest==7.3.2",
 | 
			
		||||
    "uv==0.9.5"
 | 
			
		||||
    "uv==0.8.6"
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[tool.setuptools]
 | 
			
		||||
 | 
			
		||||
@ -163,13 +163,8 @@ if [[ "$(uname)" != Darwin ]]; then
 | 
			
		||||
  MEMORY_LIMIT_MAX_JOBS=12
 | 
			
		||||
  NUM_CPUS=$(( $(nproc) - 2 ))
 | 
			
		||||
 | 
			
		||||
  if [[ "$(uname)" == Linux ]]; then
 | 
			
		||||
    # Defaults here for **binary** linux builds so they can be changed in one place
 | 
			
		||||
    export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
 | 
			
		||||
  else
 | 
			
		||||
    # For other builds
 | 
			
		||||
    export MAX_JOBS=${NUM_CPUS}
 | 
			
		||||
  fi
 | 
			
		||||
  # Defaults here for **binary** linux builds so they can be changed in one place
 | 
			
		||||
  export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
 | 
			
		||||
 | 
			
		||||
  cat >>"$envfile" <<EOL
 | 
			
		||||
  export MAX_JOBS="${MAX_JOBS}"
 | 
			
		||||
 | 
			
		||||
@ -1,359 +0,0 @@
 | 
			
		||||
---
 | 
			
		||||
name: docstring
 | 
			
		||||
description: Write docstrings for PyTorch functions and methods following PyTorch conventions. Use when writing or updating docstrings in PyTorch code.
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
# PyTorch Docstring Writing Guide
 | 
			
		||||
 | 
			
		||||
This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.
 | 
			
		||||
 | 
			
		||||
## General Principles
 | 
			
		||||
 | 
			
		||||
- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
 | 
			
		||||
- Follow **Sphinx/reStructuredText** (reST) format for documentation
 | 
			
		||||
- Be **concise but complete** - include all essential information
 | 
			
		||||
- Always include **examples** when possible
 | 
			
		||||
- Use **cross-references** to related functions/classes
 | 
			
		||||
 | 
			
		||||
## Docstring Structure
 | 
			
		||||
 | 
			
		||||
### 1. Function Signature (First Line)
 | 
			
		||||
 | 
			
		||||
Start with the function signature showing all parameters:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Notes:**
 | 
			
		||||
- Include the function name
 | 
			
		||||
- Show positional and keyword-only arguments (use `*` separator)
 | 
			
		||||
- Include default values
 | 
			
		||||
- Show return type annotation
 | 
			
		||||
- This line should NOT end with a period
 | 
			
		||||
 | 
			
		||||
### 2. Brief Description
 | 
			
		||||
 | 
			
		||||
Provide a one-line description of what the function does:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
 | 
			
		||||
 | 
			
		||||
Applies a 2D convolution over an input image composed of several input
 | 
			
		||||
planes.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 3. Mathematical Formulas (if applicable)
 | 
			
		||||
 | 
			
		||||
Use Sphinx math directives for mathematical expressions:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
.. math::
 | 
			
		||||
    \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Or inline math: `:math:\`x^2\``
 | 
			
		||||
 | 
			
		||||
### 4. Cross-References
 | 
			
		||||
 | 
			
		||||
Link to related classes and functions using Sphinx roles:
 | 
			
		||||
 | 
			
		||||
- `:class:\`~torch.nn.ModuleName\`` - Link to a class
 | 
			
		||||
- `:func:\`torch.function_name\`` - Link to a function
 | 
			
		||||
- `:meth:\`~Tensor.method_name\`` - Link to a method
 | 
			
		||||
- `:attr:\`attribute_name\`` - Reference an attribute
 | 
			
		||||
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)
 | 
			
		||||
 | 
			
		||||
**Example:**
 | 
			
		||||
```python
 | 
			
		||||
See :class:`~torch.nn.Conv2d` for details and output shape.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 5. Notes and Warnings
 | 
			
		||||
 | 
			
		||||
Use admonitions for important information:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
.. note::
 | 
			
		||||
    This function doesn't work directly with NLLLoss,
 | 
			
		||||
    which expects the Log to be computed between the Softmax and itself.
 | 
			
		||||
    Use log_softmax instead (it's faster and has better numerical properties).
 | 
			
		||||
 | 
			
		||||
.. warning::
 | 
			
		||||
    :func:`new_tensor` always copies :attr:`data`. If you have a Tensor
 | 
			
		||||
    ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
 | 
			
		||||
    or :func:`torch.Tensor.detach`.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 6. Args Section
 | 
			
		||||
 | 
			
		||||
Document all parameters with type annotations and descriptions:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
Args:
 | 
			
		||||
    input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
 | 
			
		||||
    weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)`
 | 
			
		||||
    bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
 | 
			
		||||
    stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
 | 
			
		||||
      tuple `(sH, sW)`. Default: 1
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Formatting rules:**
 | 
			
		||||
- Parameter name in **lowercase**
 | 
			
		||||
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
 | 
			
		||||
- Description follows the type
 | 
			
		||||
- For optional parameters, include "Default: ``value``" at the end
 | 
			
		||||
- Use double backticks for inline code: ``` ``None`` ```
 | 
			
		||||
- Indent continuation lines by 2 spaces
 | 
			
		||||
 | 
			
		||||
### 7. Keyword Args Section (if applicable)
 | 
			
		||||
 | 
			
		||||
Sometimes keyword arguments are documented separately:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
Keyword args:
 | 
			
		||||
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
 | 
			
		||||
        Default: if None, same :class:`torch.dtype` as this tensor.
 | 
			
		||||
    device (:class:`torch.device`, optional): the desired device of returned tensor.
 | 
			
		||||
        Default: if None, same :class:`torch.device` as this tensor.
 | 
			
		||||
    requires_grad (bool, optional): If autograd should record operations on the
 | 
			
		||||
        returned tensor. Default: ``False``.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 8. Returns Section (if needed)
 | 
			
		||||
 | 
			
		||||
Document the return value:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
Returns:
 | 
			
		||||
    Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
 | 
			
		||||
        If ``hard=True``, the returned samples will be one-hot, otherwise they will
 | 
			
		||||
        be probability distributions that sum to 1 across `dim`.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Or simply include it in the function signature line if obvious from context.
 | 
			
		||||
 | 
			
		||||
### 9. Examples Section
 | 
			
		||||
 | 
			
		||||
Always include examples when possible:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
Examples::
 | 
			
		||||
 | 
			
		||||
    >>> inputs = torch.randn(33, 16, 30)
 | 
			
		||||
    >>> filters = torch.randn(20, 16, 5)
 | 
			
		||||
    >>> F.conv1d(inputs, filters)
 | 
			
		||||
 | 
			
		||||
    >>> # With square kernels and equal stride
 | 
			
		||||
    >>> filters = torch.randn(8, 4, 3, 3)
 | 
			
		||||
    >>> inputs = torch.randn(1, 4, 5, 5)
 | 
			
		||||
    >>> F.conv2d(inputs, filters, padding=1)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Formatting rules:**
 | 
			
		||||
- Use `Examples::` with double colon
 | 
			
		||||
- Use `>>>` prompt for Python code
 | 
			
		||||
- Include comments with `#` when helpful
 | 
			
		||||
- Show actual output when it helps understanding (indent without `>>>`)
 | 
			
		||||
 | 
			
		||||
### 10. External References
 | 
			
		||||
 | 
			
		||||
Link to papers or external documentation:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
.. _Link Name:
 | 
			
		||||
    https://arxiv.org/abs/1611.00712
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Reference them in text: ```See `Link Name`_```
 | 
			
		||||
 | 
			
		||||
## Method Types
 | 
			
		||||
 | 
			
		||||
### Native Python Functions
 | 
			
		||||
 | 
			
		||||
For regular Python functions, use a standard docstring:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
def relu(input: Tensor, inplace: bool = False) -> Tensor:
 | 
			
		||||
    r"""relu(input, inplace=False) -> Tensor
 | 
			
		||||
 | 
			
		||||
    Applies the rectified linear unit function element-wise. See
 | 
			
		||||
    :class:`~torch.nn.ReLU` for more details.
 | 
			
		||||
    """
 | 
			
		||||
    # implementation
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### C-Bound Functions (using add_docstr)
 | 
			
		||||
 | 
			
		||||
For C-bound functions, use `_add_docstr`:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
conv1d = _add_docstr(
 | 
			
		||||
    torch.conv1d,
 | 
			
		||||
    r"""
 | 
			
		||||
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
 | 
			
		||||
 | 
			
		||||
Applies a 1D convolution over an input signal composed of several input
 | 
			
		||||
planes.
 | 
			
		||||
 | 
			
		||||
See :class:`~torch.nn.Conv1d` for details and output shape.
 | 
			
		||||
 | 
			
		||||
Args:
 | 
			
		||||
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
 | 
			
		||||
    weight: filters of shape :math:`(\text{out\_channels} , kW)`
 | 
			
		||||
    ...
 | 
			
		||||
""",
 | 
			
		||||
)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### In-Place Variants
 | 
			
		||||
 | 
			
		||||
For in-place operations (ending with `_`), reference the original:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
add_docstr_all(
 | 
			
		||||
    "abs_",
 | 
			
		||||
    r"""
 | 
			
		||||
abs_() -> Tensor
 | 
			
		||||
 | 
			
		||||
In-place version of :meth:`~Tensor.abs`
 | 
			
		||||
""",
 | 
			
		||||
)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Alias Functions
 | 
			
		||||
 | 
			
		||||
For aliases, simply reference the original:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
add_docstr_all(
 | 
			
		||||
    "absolute",
 | 
			
		||||
    r"""
 | 
			
		||||
absolute() -> Tensor
 | 
			
		||||
 | 
			
		||||
Alias for :func:`abs`
 | 
			
		||||
""",
 | 
			
		||||
)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Common Patterns
 | 
			
		||||
 | 
			
		||||
### Shape Documentation
 | 
			
		||||
 | 
			
		||||
Use LaTeX math notation for tensor shapes:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Reusable Argument Definitions
 | 
			
		||||
 | 
			
		||||
For commonly used arguments, define them once and reuse:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
common_args = parse_kwargs(
 | 
			
		||||
    """
 | 
			
		||||
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
 | 
			
		||||
        Default: if None, same as this tensor.
 | 
			
		||||
"""
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Then use with .format():
 | 
			
		||||
r"""
 | 
			
		||||
...
 | 
			
		||||
 | 
			
		||||
Keyword args:
 | 
			
		||||
    {dtype}
 | 
			
		||||
    {device}
 | 
			
		||||
""".format(**common_args)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Template Insertion
 | 
			
		||||
 | 
			
		||||
Insert reproducibility notes or other common text:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
r"""
 | 
			
		||||
{tf32_note}
 | 
			
		||||
 | 
			
		||||
{cudnn_reproducibility_note}
 | 
			
		||||
""".format(**reproducibility_notes, **tf32_notes)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Complete Example
 | 
			
		||||
 | 
			
		||||
Here's a complete example showing all elements:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
def gumbel_softmax(
 | 
			
		||||
    logits: Tensor,
 | 
			
		||||
    tau: float = 1,
 | 
			
		||||
    hard: bool = False,
 | 
			
		||||
    eps: float = 1e-10,
 | 
			
		||||
    dim: int = -1,
 | 
			
		||||
) -> Tensor:
 | 
			
		||||
    r"""
 | 
			
		||||
    Sample from the Gumbel-Softmax distribution and optionally discretize.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        logits (Tensor): `[..., num_features]` unnormalized log probabilities
 | 
			
		||||
        tau (float): non-negative scalar temperature
 | 
			
		||||
        hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
 | 
			
		||||
              but will be differentiated as if it is the soft sample in autograd. Default: ``False``
 | 
			
		||||
        dim (int): A dimension along which softmax will be computed. Default: -1
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
 | 
			
		||||
            If ``hard=True``, the returned samples will be one-hot, otherwise they will
 | 
			
		||||
            be probability distributions that sum to 1 across `dim`.
 | 
			
		||||
 | 
			
		||||
    .. note::
 | 
			
		||||
        This function is here for legacy reasons, may be removed from nn.Functional in the future.
 | 
			
		||||
 | 
			
		||||
    Examples::
 | 
			
		||||
        >>> logits = torch.randn(20, 32)
 | 
			
		||||
        >>> # Sample soft categorical using reparametrization trick:
 | 
			
		||||
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
 | 
			
		||||
        >>> # Sample hard categorical using "Straight-through" trick:
 | 
			
		||||
        >>> F.gumbel_softmax(logits, tau=1, hard=True)
 | 
			
		||||
 | 
			
		||||
    .. _Link 1:
 | 
			
		||||
        https://arxiv.org/abs/1611.00712
 | 
			
		||||
    """
 | 
			
		||||
    # implementation
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Quick Checklist
 | 
			
		||||
 | 
			
		||||
When writing a PyTorch docstring, ensure:
 | 
			
		||||
 | 
			
		||||
- [ ] Use raw string (`r"""`)
 | 
			
		||||
- [ ] Include function signature on first line
 | 
			
		||||
- [ ] Provide brief description
 | 
			
		||||
- [ ] Document all parameters in Args section with types
 | 
			
		||||
- [ ] Include default values for optional parameters
 | 
			
		||||
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
 | 
			
		||||
- [ ] Add mathematical formulas if applicable
 | 
			
		||||
- [ ] Include at least one example in Examples section
 | 
			
		||||
- [ ] Add warnings/notes for important caveats
 | 
			
		||||
- [ ] Link to related module class with `:class:`
 | 
			
		||||
- [ ] Use proper math notation for tensor shapes
 | 
			
		||||
- [ ] Follow consistent formatting and indentation
 | 
			
		||||
 | 
			
		||||
## Common Sphinx Roles Reference
 | 
			
		||||
 | 
			
		||||
- `:class:\`~torch.nn.Module\`` - Class reference
 | 
			
		||||
- `:func:\`torch.function\`` - Function reference
 | 
			
		||||
- `:meth:\`~Tensor.method\`` - Method reference
 | 
			
		||||
- `:attr:\`attribute\`` - Attribute reference
 | 
			
		||||
- `:math:\`equation\`` - Inline math
 | 
			
		||||
- `:ref:\`label\`` - Internal reference
 | 
			
		||||
- ``` ``code`` ``` - Inline code (use double backticks)
 | 
			
		||||
 | 
			
		||||
## Additional Notes
 | 
			
		||||
 | 
			
		||||
- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
 | 
			
		||||
- **Line length**: Try to keep lines under 100 characters when possible
 | 
			
		||||
- **Periods**: End sentences with periods, but not the signature line
 | 
			
		||||
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
 | 
			
		||||
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.
 | 
			
		||||
@ -1,385 +0,0 @@
 | 
			
		||||
---
 | 
			
		||||
name: skill-writer
 | 
			
		||||
description: Guide users through creating Agent Skills for Claude Code. Use when the user wants to create, write, author, or design a new Skill, or needs help with SKILL.md files, frontmatter, or skill structure.
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
# Skill Writer
 | 
			
		||||
 | 
			
		||||
This Skill helps you create well-structured Agent Skills for Claude Code that follow best practices and validation requirements.
 | 
			
		||||
 | 
			
		||||
## When to use this Skill
 | 
			
		||||
 | 
			
		||||
Use this Skill when:
 | 
			
		||||
- Creating a new Agent Skill
 | 
			
		||||
- Writing or updating SKILL.md files
 | 
			
		||||
- Designing skill structure and frontmatter
 | 
			
		||||
- Troubleshooting skill discovery issues
 | 
			
		||||
- Converting existing prompts or workflows into Skills
 | 
			
		||||
 | 
			
		||||
## Instructions
 | 
			
		||||
 | 
			
		||||
### Step 1: Determine Skill scope
 | 
			
		||||
 | 
			
		||||
First, understand what the Skill should do:
 | 
			
		||||
 | 
			
		||||
1. **Ask clarifying questions**:
 | 
			
		||||
   - What specific capability should this Skill provide?
 | 
			
		||||
   - When should Claude use this Skill?
 | 
			
		||||
   - What tools or resources does it need?
 | 
			
		||||
   - Is this for personal use or team sharing?
 | 
			
		||||
 | 
			
		||||
2. **Keep it focused**: One Skill = one capability
 | 
			
		||||
   - Good: "PDF form filling", "Excel data analysis"
 | 
			
		||||
   - Too broad: "Document processing", "Data tools"
 | 
			
		||||
 | 
			
		||||
### Step 2: Choose Skill location
 | 
			
		||||
 | 
			
		||||
Determine where to create the Skill:
 | 
			
		||||
 | 
			
		||||
**Personal Skills** (`~/.claude/skills/`):
 | 
			
		||||
- Individual workflows and preferences
 | 
			
		||||
- Experimental Skills
 | 
			
		||||
- Personal productivity tools
 | 
			
		||||
 | 
			
		||||
**Project Skills** (`.claude/skills/`):
 | 
			
		||||
- Team workflows and conventions
 | 
			
		||||
- Project-specific expertise
 | 
			
		||||
- Shared utilities (committed to git)
 | 
			
		||||
 | 
			
		||||
### Step 3: Create Skill structure
 | 
			
		||||
 | 
			
		||||
Create the directory and files:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
# Personal
 | 
			
		||||
mkdir -p ~/.claude/skills/skill-name
 | 
			
		||||
 | 
			
		||||
# Project
 | 
			
		||||
mkdir -p .claude/skills/skill-name
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
For multi-file Skills:
 | 
			
		||||
```
 | 
			
		||||
skill-name/
 | 
			
		||||
├── SKILL.md (required)
 | 
			
		||||
├── reference.md (optional)
 | 
			
		||||
├── examples.md (optional)
 | 
			
		||||
├── scripts/
 | 
			
		||||
│   └── helper.py (optional)
 | 
			
		||||
└── templates/
 | 
			
		||||
    └── template.txt (optional)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 4: Write SKILL.md frontmatter
 | 
			
		||||
 | 
			
		||||
Create YAML frontmatter with required fields:
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
---
 | 
			
		||||
name: skill-name
 | 
			
		||||
description: Brief description of what this does and when to use it
 | 
			
		||||
---
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Field requirements**:
 | 
			
		||||
 | 
			
		||||
- **name**:
 | 
			
		||||
  - Lowercase letters, numbers, hyphens only
 | 
			
		||||
  - Max 64 characters
 | 
			
		||||
  - Must match directory name
 | 
			
		||||
  - Good: `pdf-processor`, `git-commit-helper`
 | 
			
		||||
  - Bad: `PDF_Processor`, `Git Commits!`
 | 
			
		||||
 | 
			
		||||
- **description**:
 | 
			
		||||
  - Max 1024 characters
 | 
			
		||||
  - Include BOTH what it does AND when to use it
 | 
			
		||||
  - Use specific trigger words users would say
 | 
			
		||||
  - Mention file types, operations, and context
 | 
			
		||||
 | 
			
		||||
**Optional frontmatter fields**:
 | 
			
		||||
 | 
			
		||||
- **allowed-tools**: Restrict tool access (comma-separated list)
 | 
			
		||||
  ```yaml
 | 
			
		||||
  allowed-tools: Read, Grep, Glob
 | 
			
		||||
  ```
 | 
			
		||||
  Use for:
 | 
			
		||||
  - Read-only Skills
 | 
			
		||||
  - Security-sensitive workflows
 | 
			
		||||
  - Limited-scope operations
 | 
			
		||||
 | 
			
		||||
### Step 5: Write effective descriptions
 | 
			
		||||
 | 
			
		||||
The description is critical for Claude to discover your Skill.
 | 
			
		||||
 | 
			
		||||
**Formula**: `[What it does] + [When to use it] + [Key triggers]`
 | 
			
		||||
 | 
			
		||||
**Examples**:
 | 
			
		||||
 | 
			
		||||
✅ **Good**:
 | 
			
		||||
```yaml
 | 
			
		||||
description: Extract text and tables from PDF files, fill forms, merge documents. Use when working with PDF files or when the user mentions PDFs, forms, or document extraction.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
✅ **Good**:
 | 
			
		||||
```yaml
 | 
			
		||||
description: Analyze Excel spreadsheets, create pivot tables, and generate charts. Use when working with Excel files, spreadsheets, or analyzing tabular data in .xlsx format.
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
❌ **Too vague**:
 | 
			
		||||
```yaml
 | 
			
		||||
description: Helps with documents
 | 
			
		||||
description: For data analysis
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Tips**:
 | 
			
		||||
- Include specific file extensions (.pdf, .xlsx, .json)
 | 
			
		||||
- Mention common user phrases ("analyze", "extract", "generate")
 | 
			
		||||
- List concrete operations (not generic verbs)
 | 
			
		||||
- Add context clues ("Use when...", "For...")
 | 
			
		||||
 | 
			
		||||
### Step 6: Structure the Skill content
 | 
			
		||||
 | 
			
		||||
Use clear Markdown sections:
 | 
			
		||||
 | 
			
		||||
```markdown
 | 
			
		||||
# Skill Name
 | 
			
		||||
 | 
			
		||||
Brief overview of what this Skill does.
 | 
			
		||||
 | 
			
		||||
## Quick start
 | 
			
		||||
 | 
			
		||||
Provide a simple example to get started immediately.
 | 
			
		||||
 | 
			
		||||
## Instructions
 | 
			
		||||
 | 
			
		||||
Step-by-step guidance for Claude:
 | 
			
		||||
1. First step with clear action
 | 
			
		||||
2. Second step with expected outcome
 | 
			
		||||
3. Handle edge cases
 | 
			
		||||
 | 
			
		||||
## Examples
 | 
			
		||||
 | 
			
		||||
Show concrete usage examples with code or commands.
 | 
			
		||||
 | 
			
		||||
## Best practices
 | 
			
		||||
 | 
			
		||||
- Key conventions to follow
 | 
			
		||||
- Common pitfalls to avoid
 | 
			
		||||
- When to use vs. not use
 | 
			
		||||
 | 
			
		||||
## Requirements
 | 
			
		||||
 | 
			
		||||
List any dependencies or prerequisites:
 | 
			
		||||
```bash
 | 
			
		||||
pip install package-name
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Advanced usage
 | 
			
		||||
 | 
			
		||||
For complex scenarios, see [reference.md](reference.md).
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 7: Add supporting files (optional)
 | 
			
		||||
 | 
			
		||||
Create additional files for progressive disclosure:
 | 
			
		||||
 | 
			
		||||
**reference.md**: Detailed API docs, advanced options
 | 
			
		||||
**examples.md**: Extended examples and use cases
 | 
			
		||||
**scripts/**: Helper scripts and utilities
 | 
			
		||||
**templates/**: File templates or boilerplate
 | 
			
		||||
 | 
			
		||||
Reference them from SKILL.md:
 | 
			
		||||
```markdown
 | 
			
		||||
For advanced usage, see [reference.md](reference.md).
 | 
			
		||||
 | 
			
		||||
Run the helper script:
 | 
			
		||||
\`\`\`bash
 | 
			
		||||
python scripts/helper.py input.txt
 | 
			
		||||
\`\`\`
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 8: Validate the Skill
 | 
			
		||||
 | 
			
		||||
Check these requirements:
 | 
			
		||||
 | 
			
		||||
✅ **File structure**:
 | 
			
		||||
- [ ] SKILL.md exists in correct location
 | 
			
		||||
- [ ] Directory name matches frontmatter `name`
 | 
			
		||||
 | 
			
		||||
✅ **YAML frontmatter**:
 | 
			
		||||
- [ ] Opening `---` on line 1
 | 
			
		||||
- [ ] Closing `---` before content
 | 
			
		||||
- [ ] Valid YAML (no tabs, correct indentation)
 | 
			
		||||
- [ ] `name` follows naming rules
 | 
			
		||||
- [ ] `description` is specific and < 1024 chars
 | 
			
		||||
 | 
			
		||||
✅ **Content quality**:
 | 
			
		||||
- [ ] Clear instructions for Claude
 | 
			
		||||
- [ ] Concrete examples provided
 | 
			
		||||
- [ ] Edge cases handled
 | 
			
		||||
- [ ] Dependencies listed (if any)
 | 
			
		||||
 | 
			
		||||
✅ **Testing**:
 | 
			
		||||
- [ ] Description matches user questions
 | 
			
		||||
- [ ] Skill activates on relevant queries
 | 
			
		||||
- [ ] Instructions are clear and actionable
 | 
			
		||||
 | 
			
		||||
### Step 9: Test the Skill
 | 
			
		||||
 | 
			
		||||
1. **Restart Claude Code** (if running) to load the Skill
 | 
			
		||||
 | 
			
		||||
2. **Ask relevant questions** that match the description:
 | 
			
		||||
   ```
 | 
			
		||||
   Can you help me extract text from this PDF?
 | 
			
		||||
   ```
 | 
			
		||||
 | 
			
		||||
3. **Verify activation**: Claude should use the Skill automatically
 | 
			
		||||
 | 
			
		||||
4. **Check behavior**: Confirm Claude follows the instructions correctly
 | 
			
		||||
 | 
			
		||||
### Step 10: Debug if needed
 | 
			
		||||
 | 
			
		||||
If Claude doesn't use the Skill:
 | 
			
		||||
 | 
			
		||||
1. **Make description more specific**:
 | 
			
		||||
   - Add trigger words
 | 
			
		||||
   - Include file types
 | 
			
		||||
   - Mention common user phrases
 | 
			
		||||
 | 
			
		||||
2. **Check file location**:
 | 
			
		||||
   ```bash
 | 
			
		||||
   ls ~/.claude/skills/skill-name/SKILL.md
 | 
			
		||||
   ls .claude/skills/skill-name/SKILL.md
 | 
			
		||||
   ```
 | 
			
		||||
 | 
			
		||||
3. **Validate YAML**:
 | 
			
		||||
   ```bash
 | 
			
		||||
   cat SKILL.md | head -n 10
 | 
			
		||||
   ```
 | 
			
		||||
 | 
			
		||||
4. **Run debug mode**:
 | 
			
		||||
   ```bash
 | 
			
		||||
   claude --debug
 | 
			
		||||
   ```
 | 
			
		||||
 | 
			
		||||
## Common patterns
 | 
			
		||||
 | 
			
		||||
### Read-only Skill
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
---
 | 
			
		||||
name: code-reader
 | 
			
		||||
description: Read and analyze code without making changes. Use for code review, understanding codebases, or documentation.
 | 
			
		||||
allowed-tools: Read, Grep, Glob
 | 
			
		||||
---
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Script-based Skill
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
---
 | 
			
		||||
name: data-processor
 | 
			
		||||
description: Process CSV and JSON data files with Python scripts. Use when analyzing data files or transforming datasets.
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
# Data Processor
 | 
			
		||||
 | 
			
		||||
## Instructions
 | 
			
		||||
 | 
			
		||||
1. Use the processing script:
 | 
			
		||||
\`\`\`bash
 | 
			
		||||
python scripts/process.py input.csv --output results.json
 | 
			
		||||
\`\`\`
 | 
			
		||||
 | 
			
		||||
2. Validate output with:
 | 
			
		||||
\`\`\`bash
 | 
			
		||||
python scripts/validate.py results.json
 | 
			
		||||
\`\`\`
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Multi-file Skill with progressive disclosure
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
---
 | 
			
		||||
name: api-designer
 | 
			
		||||
description: Design REST APIs following best practices. Use when creating API endpoints, designing routes, or planning API architecture.
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
# API Designer
 | 
			
		||||
 | 
			
		||||
Quick start: See [examples.md](examples.md)
 | 
			
		||||
 | 
			
		||||
Detailed reference: See [reference.md](reference.md)
 | 
			
		||||
 | 
			
		||||
## Instructions
 | 
			
		||||
 | 
			
		||||
1. Gather requirements
 | 
			
		||||
2. Design endpoints (see examples.md)
 | 
			
		||||
3. Document with OpenAPI spec
 | 
			
		||||
4. Review against best practices (see reference.md)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Best practices for Skill authors
 | 
			
		||||
 | 
			
		||||
1. **One Skill, one purpose**: Don't create mega-Skills
 | 
			
		||||
2. **Specific descriptions**: Include trigger words users will say
 | 
			
		||||
3. **Clear instructions**: Write for Claude, not humans
 | 
			
		||||
4. **Concrete examples**: Show real code, not pseudocode
 | 
			
		||||
5. **List dependencies**: Mention required packages in description
 | 
			
		||||
6. **Test with teammates**: Verify activation and clarity
 | 
			
		||||
7. **Version your Skills**: Document changes in content
 | 
			
		||||
8. **Use progressive disclosure**: Put advanced details in separate files
 | 
			
		||||
 | 
			
		||||
## Validation checklist
 | 
			
		||||
 | 
			
		||||
Before finalizing a Skill, verify:
 | 
			
		||||
 | 
			
		||||
- [ ] Name is lowercase, hyphens only, max 64 chars
 | 
			
		||||
- [ ] Description is specific and < 1024 chars
 | 
			
		||||
- [ ] Description includes "what" and "when"
 | 
			
		||||
- [ ] YAML frontmatter is valid
 | 
			
		||||
- [ ] Instructions are step-by-step
 | 
			
		||||
- [ ] Examples are concrete and realistic
 | 
			
		||||
- [ ] Dependencies are documented
 | 
			
		||||
- [ ] File paths use forward slashes
 | 
			
		||||
- [ ] Skill activates on relevant queries
 | 
			
		||||
- [ ] Claude follows instructions correctly
 | 
			
		||||
 | 
			
		||||
## Troubleshooting
 | 
			
		||||
 | 
			
		||||
**Skill doesn't activate**:
 | 
			
		||||
- Make description more specific with trigger words
 | 
			
		||||
- Include file types and operations in description
 | 
			
		||||
- Add "Use when..." clause with user phrases
 | 
			
		||||
 | 
			
		||||
**Multiple Skills conflict**:
 | 
			
		||||
- Make descriptions more distinct
 | 
			
		||||
- Use different trigger words
 | 
			
		||||
- Narrow the scope of each Skill
 | 
			
		||||
 | 
			
		||||
**Skill has errors**:
 | 
			
		||||
- Check YAML syntax (no tabs, proper indentation)
 | 
			
		||||
- Verify file paths (use forward slashes)
 | 
			
		||||
- Ensure scripts have execute permissions
 | 
			
		||||
- List all dependencies
 | 
			
		||||
 | 
			
		||||
## Examples
 | 
			
		||||
 | 
			
		||||
See the documentation for complete examples:
 | 
			
		||||
- Simple single-file Skill (commit-helper)
 | 
			
		||||
- Skill with tool permissions (code-reviewer)
 | 
			
		||||
- Multi-file Skill (pdf-processing)
 | 
			
		||||
 | 
			
		||||
## Output format
 | 
			
		||||
 | 
			
		||||
When creating a Skill, I will:
 | 
			
		||||
 | 
			
		||||
1. Ask clarifying questions about scope and requirements
 | 
			
		||||
2. Suggest a Skill name and location
 | 
			
		||||
3. Create the SKILL.md file with proper frontmatter
 | 
			
		||||
4. Include clear instructions and examples
 | 
			
		||||
5. Add supporting files if needed
 | 
			
		||||
6. Provide testing instructions
 | 
			
		||||
7. Validate against all requirements
 | 
			
		||||
 | 
			
		||||
The result will be a complete, working Skill that follows all best practices and validation rules.
 | 
			
		||||
							
								
								
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							@ -124,10 +124,3 @@ runs:
 | 
			
		||||
      id: login-ecr
 | 
			
		||||
      continue-on-error: true
 | 
			
		||||
      uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
 | 
			
		||||
 | 
			
		||||
    - name: Preserve github env variables for use in docker
 | 
			
		||||
      shell: bash
 | 
			
		||||
      run: |
 | 
			
		||||
        env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
 | 
			
		||||
        env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
 | 
			
		||||
        env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							@ -1 +1 @@
 | 
			
		||||
1752fe6809b74921644866275ab80244b96e80bc
 | 
			
		||||
faffd5cf673615583da6517275e361cb3dbc77e6
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							@ -1 +1 @@
 | 
			
		||||
df6798dfb931ce7c7fe5bed2447cd1092a5981af
 | 
			
		||||
0fa6e3129e61143224663e1ec67980d12b7ec4eb
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										5
									
								
								.github/ci_configs/vllm/Dockerfile
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.github/ci_configs/vllm/Dockerfile
									
									
									
									
										vendored
									
									
								
							@ -283,9 +283,6 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
 | 
			
		||||
        uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
 | 
			
		||||
    fi
 | 
			
		||||
 | 
			
		||||
RUN --mount=type=cache,target=/root/.cache/uv \
 | 
			
		||||
    uv pip install --system --pre apache-tvm-ffi==0.1.0b15
 | 
			
		||||
 | 
			
		||||
# Install the vllm wheel from previous stage
 | 
			
		||||
RUN --mount=type=cache,target=/root/.cache/uv \
 | 
			
		||||
    uv pip install --system /wheels/vllm/*.whl --verbose
 | 
			
		||||
@ -298,8 +295,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 | 
			
		||||
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
 | 
			
		||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
 | 
			
		||||
 | 
			
		||||
# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
 | 
			
		||||
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
 | 
			
		||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
 | 
			
		||||
ARG FLASHINFER_GIT_REF="v0.2.14.post1"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										9
									
								
								.github/label_to_label.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								.github/label_to_label.yml
									
									
									
									
										vendored
									
									
								
							@ -15,11 +15,6 @@
 | 
			
		||||
  - "module: reinplacing"
 | 
			
		||||
  then:
 | 
			
		||||
  - "module: pt2-dispatcher"
 | 
			
		||||
- any:
 | 
			
		||||
  - "vllm-compile"
 | 
			
		||||
  then:
 | 
			
		||||
  - "module: vllm"
 | 
			
		||||
  - "oncall: pt2"
 | 
			
		||||
- any:
 | 
			
		||||
  - "module: vmap"
 | 
			
		||||
  then:
 | 
			
		||||
@ -32,6 +27,10 @@
 | 
			
		||||
  - "module: pt2 optimizer"
 | 
			
		||||
  then:
 | 
			
		||||
  - "module: dynamo"
 | 
			
		||||
- any:
 | 
			
		||||
  - "module: flex attention"
 | 
			
		||||
  then:
 | 
			
		||||
  - "module: higher order operators"
 | 
			
		||||
- any:
 | 
			
		||||
  - "module: aotinductor"
 | 
			
		||||
  then:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										22
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										22
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							@ -22,7 +22,7 @@ CUDA_ARCHES_FULL_VERSION = {
 | 
			
		||||
    "12.6": "12.6.3",
 | 
			
		||||
    "12.8": "12.8.1",
 | 
			
		||||
    "12.9": "12.9.1",
 | 
			
		||||
    "13.0": "13.0.2",
 | 
			
		||||
    "13.0": "13.0.0",
 | 
			
		||||
}
 | 
			
		||||
CUDA_ARCHES_CUDNN_VERSION = {
 | 
			
		||||
    "12.6": "9",
 | 
			
		||||
@ -96,21 +96,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
 | 
			
		||||
        "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'"
 | 
			
		||||
    ),
 | 
			
		||||
    "13.0": (
 | 
			
		||||
        "nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cublas==13.1.0.3; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cufft==12.0.0.61; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cublas==13.0.0.19; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cufft==12.0.0.15; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-curand==10.4.0.35; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-nvtx==13.0.85; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cufile==1.15.1.6; platform_system == 'Linux'"
 | 
			
		||||
        "nvidia-nvtx==13.0.39; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cufile==1.15.0.42; platform_system == 'Linux'"
 | 
			
		||||
    ),
 | 
			
		||||
    "xpu": (
 | 
			
		||||
        "intel-cmplr-lib-rt==2025.2.1 | "
 | 
			
		||||
 | 
			
		||||
@ -79,9 +79,9 @@ jobs:
 | 
			
		||||
    runs-on: "windows-11-arm64-preview"
 | 
			
		||||
    {%- else %}
 | 
			
		||||
    {%- if branches == "nightly" %}
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    {%- else %}
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
 | 
			
		||||
    {%- endif %}
 | 
			
		||||
    {%- endif %}
 | 
			
		||||
    timeout-minutes: !{{ common.timeout_minutes_windows_binary }}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -270,7 +270,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_10-cuda-aarch64-13_0
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -519,7 +519,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_11-cuda-aarch64-13_0
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -768,7 +768,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_12-cuda-aarch64-13_0
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1017,7 +1017,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_13-cuda-aarch64-13_0
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1266,7 +1266,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_13t-cuda-aarch64-13_0
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1515,7 +1515,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_14-cuda-aarch64-13_0
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1764,7 +1764,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_14t-cuda-aarch64-13_0
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -325,7 +325,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_10-cuda13_0
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_10-cuda13_0-test:  # Testing
 | 
			
		||||
@ -991,7 +991,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_11-cuda13_0
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_11-cuda13_0-test:  # Testing
 | 
			
		||||
@ -1657,7 +1657,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_12-cuda13_0
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_12-cuda13_0-test:  # Testing
 | 
			
		||||
@ -2323,7 +2323,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_13-cuda13_0
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_13-cuda13_0-test:  # Testing
 | 
			
		||||
@ -2989,7 +2989,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_13t-cuda13_0
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_13t-cuda13_0-test:  # Testing
 | 
			
		||||
@ -3655,7 +3655,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_14-cuda13_0
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_14-cuda13_0-test:  # Testing
 | 
			
		||||
@ -4321,7 +4321,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_14t-cuda13_0
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_14t-cuda13_0-test:  # Testing
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -44,7 +44,7 @@ jobs:
 | 
			
		||||
  libtorch-cpu-shared-with-deps-debug-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -291,7 +291,7 @@ jobs:
 | 
			
		||||
  libtorch-cuda12_6-shared-with-deps-debug-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -541,7 +541,7 @@ jobs:
 | 
			
		||||
  libtorch-cuda12_8-shared-with-deps-debug-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -791,7 +791,7 @@ jobs:
 | 
			
		||||
  libtorch-cuda13_0-shared-with-deps-debug-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -44,7 +44,7 @@ jobs:
 | 
			
		||||
  libtorch-cpu-shared-with-deps-release-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -291,7 +291,7 @@ jobs:
 | 
			
		||||
  libtorch-cuda12_6-shared-with-deps-release-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -541,7 +541,7 @@ jobs:
 | 
			
		||||
  libtorch-cuda12_8-shared-with-deps-release-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -791,7 +791,7 @@ jobs:
 | 
			
		||||
  libtorch-cuda13_0-shared-with-deps-release-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										70
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										70
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -44,7 +44,7 @@ jobs:
 | 
			
		||||
  wheel-py3_10-cpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -279,7 +279,7 @@ jobs:
 | 
			
		||||
  wheel-py3_10-cuda12_6-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -517,7 +517,7 @@ jobs:
 | 
			
		||||
  wheel-py3_10-cuda12_8-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -755,7 +755,7 @@ jobs:
 | 
			
		||||
  wheel-py3_10-cuda13_0-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -993,7 +993,7 @@ jobs:
 | 
			
		||||
  wheel-py3_10-xpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -1229,7 +1229,7 @@ jobs:
 | 
			
		||||
  wheel-py3_11-cpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -1464,7 +1464,7 @@ jobs:
 | 
			
		||||
  wheel-py3_11-cuda12_6-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -1702,7 +1702,7 @@ jobs:
 | 
			
		||||
  wheel-py3_11-cuda12_8-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -1940,7 +1940,7 @@ jobs:
 | 
			
		||||
  wheel-py3_11-cuda13_0-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -2178,7 +2178,7 @@ jobs:
 | 
			
		||||
  wheel-py3_11-xpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -2414,7 +2414,7 @@ jobs:
 | 
			
		||||
  wheel-py3_12-cpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -2649,7 +2649,7 @@ jobs:
 | 
			
		||||
  wheel-py3_12-cuda12_6-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -2887,7 +2887,7 @@ jobs:
 | 
			
		||||
  wheel-py3_12-cuda12_8-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -3125,7 +3125,7 @@ jobs:
 | 
			
		||||
  wheel-py3_12-cuda13_0-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -3363,7 +3363,7 @@ jobs:
 | 
			
		||||
  wheel-py3_12-xpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -3599,7 +3599,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13-cpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -3834,7 +3834,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13-cuda12_6-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -4072,7 +4072,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13-cuda12_8-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -4310,7 +4310,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13-cuda13_0-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -4548,7 +4548,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13-xpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -4784,7 +4784,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13t-cpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -5019,7 +5019,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13t-cuda12_6-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -5257,7 +5257,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13t-cuda12_8-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -5495,7 +5495,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13t-cuda13_0-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -5733,7 +5733,7 @@ jobs:
 | 
			
		||||
  wheel-py3_13t-xpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -5969,7 +5969,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14-cpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -6204,7 +6204,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14-cuda12_6-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -6442,7 +6442,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14-cuda12_8-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -6680,7 +6680,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14-cuda13_0-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -6918,7 +6918,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14-xpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -7154,7 +7154,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14t-cpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -7389,7 +7389,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14t-cuda12_6-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -7627,7 +7627,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14t-cuda12_8-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -7865,7 +7865,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14t-cuda13_0-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
@ -8103,7 +8103,7 @@ jobs:
 | 
			
		||||
  wheel-py3_14t-xpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							@ -88,6 +88,7 @@ jobs:
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3_10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							@ -147,16 +147,15 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
 | 
			
		||||
      cuda-arch-list: 8.9
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							@ -347,8 +347,7 @@ jobs:
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      # This should sync with the build in xpu.yml but xpu uses a larger runner
 | 
			
		||||
      # sync-tag: linux-xpu-n-build
 | 
			
		||||
      sync-tag: linux-xpu-n-build
 | 
			
		||||
      runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
 | 
			
		||||
      build-environment: linux-jammy-xpu-n-py3.10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							@ -45,6 +45,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-noble-rocm-py3.12-mi300
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							@ -42,6 +42,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-noble-rocm-py3.12-mi355
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										12
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
								
							@ -26,23 +26,11 @@ jobs:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
 | 
			
		||||
  get-label-type:
 | 
			
		||||
    name: get-label-type
 | 
			
		||||
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
 | 
			
		||||
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
 | 
			
		||||
    with:
 | 
			
		||||
      triggering_actor: ${{ github.triggering_actor }}
 | 
			
		||||
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
 | 
			
		||||
      curr_branch: ${{ github.head_ref || github.ref_name }}
 | 
			
		||||
      curr_ref_type: ${{ github.ref_type }}
 | 
			
		||||
 | 
			
		||||
  linux-jammy-rocm-py3_10-build:
 | 
			
		||||
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
 | 
			
		||||
    name: linux-jammy-rocm-py3.10
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3.10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										12
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							@ -26,23 +26,11 @@ jobs:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
 | 
			
		||||
  get-label-type:
 | 
			
		||||
    name: get-label-type
 | 
			
		||||
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
 | 
			
		||||
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
 | 
			
		||||
    with:
 | 
			
		||||
      triggering_actor: ${{ github.triggering_actor }}
 | 
			
		||||
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
 | 
			
		||||
      curr_branch: ${{ github.head_ref || github.ref_name }}
 | 
			
		||||
      curr_ref_type: ${{ github.ref_type }}
 | 
			
		||||
 | 
			
		||||
  linux-jammy-rocm-py3_10-build:
 | 
			
		||||
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
 | 
			
		||||
    name: linux-jammy-rocm-py3.10
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3.10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										149
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										149
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							@ -58,10 +58,8 @@ jobs:
 | 
			
		||||
          else
 | 
			
		||||
            COMMIT_SHA="${{ github.sha }}"
 | 
			
		||||
          fi
 | 
			
		||||
          {
 | 
			
		||||
            echo "sha=${COMMIT_SHA}"
 | 
			
		||||
            echo "tag_name=trunk/${COMMIT_SHA}"
 | 
			
		||||
          } >> "${GITHUB_OUTPUT}"
 | 
			
		||||
          echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
          echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
 | 
			
		||||
      - name: Validate commit SHA
 | 
			
		||||
        run: |
 | 
			
		||||
@ -89,7 +87,7 @@ jobs:
 | 
			
		||||
            echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
      - name: Create and push tag(s) with retry
 | 
			
		||||
      - name: Create and push tag with retry
 | 
			
		||||
        id: check_tag
 | 
			
		||||
        env:
 | 
			
		||||
          TAG_NAME: ${{ steps.commit.outputs.tag_name }}
 | 
			
		||||
@ -114,23 +112,14 @@ jobs:
 | 
			
		||||
            return 1
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          # Counters for summary reporting
 | 
			
		||||
          created_count=0
 | 
			
		||||
          skipped_count=0
 | 
			
		||||
          failed_count=0
 | 
			
		||||
          # Exit early if tag already exists
 | 
			
		||||
          if check_tag_exists; then
 | 
			
		||||
            echo "✅ Tag already exists - no action needed"
 | 
			
		||||
            echo "exists=true" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
            exit 0
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
          # Always write outputs once on exit
 | 
			
		||||
          finish() {
 | 
			
		||||
            set +e
 | 
			
		||||
            if [ -n "${GITHUB_OUTPUT:-}" ]; then
 | 
			
		||||
              {
 | 
			
		||||
                echo "created_count=${created_count}"
 | 
			
		||||
                echo "skipped_count=${skipped_count}"
 | 
			
		||||
                echo "failed_count=${failed_count}"
 | 
			
		||||
              } >> "${GITHUB_OUTPUT}"
 | 
			
		||||
            fi
 | 
			
		||||
          }
 | 
			
		||||
          trap finish EXIT
 | 
			
		||||
          echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
 | 
			
		||||
 | 
			
		||||
          # Retry configuration
 | 
			
		||||
          MAX_RETRIES=5
 | 
			
		||||
@ -205,111 +194,31 @@ jobs:
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          # New behavior for push events: enumerate commits in the push and tag each one.
 | 
			
		||||
          # For workflow_dispatch, retain existing single-SHA behavior.
 | 
			
		||||
 | 
			
		||||
          # Always fetch tags once up front to improve idempotency in loops
 | 
			
		||||
          git fetch origin --tags --quiet || true
 | 
			
		||||
 | 
			
		||||
          if [ "${{ github.event_name }}" = "push" ]; then
 | 
			
		||||
            BEFORE_SHA="${{ github.event.before }}"
 | 
			
		||||
            AFTER_SHA="${{ github.sha }}"  # same as event.after
 | 
			
		||||
 | 
			
		||||
            # List commits introduced by this push (old..new), oldest first for stable ordering
 | 
			
		||||
            commits_file="$(mktemp)"
 | 
			
		||||
            git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            if [ ! -s "${commits_file}" ]; then
 | 
			
		||||
              echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
 | 
			
		||||
              rm -f "${commits_file}"
 | 
			
		||||
              exit 0
 | 
			
		||||
            fi
 | 
			
		||||
 | 
			
		||||
            commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
 | 
			
		||||
            echo "Found ${commit_count} commit(s) to tag for push:"
 | 
			
		||||
            while IFS= read -r sha; do
 | 
			
		||||
              printf '  %s\n' "${sha}"
 | 
			
		||||
            done < "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            while IFS= read -r sha; do
 | 
			
		||||
              TAG_NAME="trunk/${sha}"
 | 
			
		||||
              COMMIT_SHA="${sha}"
 | 
			
		||||
 | 
			
		||||
              # If tag already exists locally or remotely, skip (idempotent)
 | 
			
		||||
              if check_tag_exists; then
 | 
			
		||||
                echo "✅ Tag ${TAG_NAME} already exists - skipping"
 | 
			
		||||
                skipped_count=$((skipped_count + 1))
 | 
			
		||||
                continue
 | 
			
		||||
              fi
 | 
			
		||||
 | 
			
		||||
              echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
 | 
			
		||||
 | 
			
		||||
              if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
 | 
			
		||||
                created_count=$((created_count + 1))
 | 
			
		||||
              else
 | 
			
		||||
                echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
 | 
			
		||||
                failed_count=$((failed_count + 1))
 | 
			
		||||
              fi
 | 
			
		||||
            done < "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            rm -f "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            if [ "${failed_count}" -gt 0 ]; then
 | 
			
		||||
              exit 1
 | 
			
		||||
            fi
 | 
			
		||||
          # Execute with retry
 | 
			
		||||
          if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
 | 
			
		||||
            echo "exists=false" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
            exit 0
 | 
			
		||||
          else
 | 
			
		||||
            # workflow_dispatch path (single SHA tagging preserved)
 | 
			
		||||
 | 
			
		||||
            # Exit early if tag already exists
 | 
			
		||||
            if check_tag_exists; then
 | 
			
		||||
              echo "✅ Tag already exists - no action needed"
 | 
			
		||||
              skipped_count=1
 | 
			
		||||
              exit 0
 | 
			
		||||
            fi
 | 
			
		||||
 | 
			
		||||
            echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
 | 
			
		||||
 | 
			
		||||
            if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
 | 
			
		||||
              created_count=1
 | 
			
		||||
              exit 0
 | 
			
		||||
            else
 | 
			
		||||
              echo "Tag creation failed after all retry attempts"
 | 
			
		||||
              failed_count=1
 | 
			
		||||
              exit 1
 | 
			
		||||
            fi
 | 
			
		||||
            echo "Tag creation failed after all retry attempts"
 | 
			
		||||
            exit 1
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
      - name: Tag creation summary
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          if [ "${{ github.event_name }}" = "push" ]; then
 | 
			
		||||
            echo "Trigger: push on main"
 | 
			
		||||
            echo "Created: ${{ steps.check_tag.outputs.created_count }}"
 | 
			
		||||
            echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
 | 
			
		||||
            echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
 | 
			
		||||
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
 | 
			
		||||
              echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
 | 
			
		||||
            else
 | 
			
		||||
              echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
 | 
			
		||||
            fi
 | 
			
		||||
          if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
 | 
			
		||||
            echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
 | 
			
		||||
          elif [ "${{ job.status }}" = "success" ]; then
 | 
			
		||||
            echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
          else
 | 
			
		||||
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
 | 
			
		||||
              if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
 | 
			
		||||
                echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
 | 
			
		||||
              else
 | 
			
		||||
                echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
              fi
 | 
			
		||||
            else
 | 
			
		||||
              echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
            fi
 | 
			
		||||
 | 
			
		||||
            echo ""
 | 
			
		||||
            echo "Tag details:"
 | 
			
		||||
            echo "  Name: ${{ steps.commit.outputs.tag_name }}"
 | 
			
		||||
            echo "  Commit: ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
            echo "  Trigger: ${{ github.event_name }}"
 | 
			
		||||
            if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
 | 
			
		||||
              echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
 | 
			
		||||
            fi
 | 
			
		||||
            echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
          echo ""
 | 
			
		||||
          echo "Tag details:"
 | 
			
		||||
          echo "  Name: ${{ steps.commit.outputs.tag_name }}"
 | 
			
		||||
          echo "  Commit: ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
          echo "  Trigger: ${{ github.event_name }}"
 | 
			
		||||
          if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
 | 
			
		||||
            echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
@ -833,7 +833,8 @@ exclude_patterns = [
 | 
			
		||||
command = [
 | 
			
		||||
    'python3',
 | 
			
		||||
    'tools/linter/adapters/grep_linter.py',
 | 
			
		||||
    '--pattern=(cudaSetDevice|cudaGetDevice)\\(',
 | 
			
		||||
    '--pattern=cudaSetDevice(',
 | 
			
		||||
    '--pattern=cudaGetDevice(',
 | 
			
		||||
    '--linter-name=RAWCUDADEVICE',
 | 
			
		||||
    '--error-name=raw CUDA API usage',
 | 
			
		||||
    """--error-description=\
 | 
			
		||||
@ -1137,8 +1138,11 @@ command = [
 | 
			
		||||
[[linter]]
 | 
			
		||||
code = 'WORKFLOWSYNC'
 | 
			
		||||
include_patterns = [
 | 
			
		||||
    '.github/workflows/*.yml',
 | 
			
		||||
    '.github/workflows/*.yaml',
 | 
			
		||||
    '.github/workflows/pull.yml',
 | 
			
		||||
    '.github/workflows/trunk.yml',
 | 
			
		||||
    '.github/workflows/periodic.yml',
 | 
			
		||||
    '.github/workflows/mac-mps.yml',
 | 
			
		||||
    '.github/workflows/slow.yml',
 | 
			
		||||
]
 | 
			
		||||
command = [
 | 
			
		||||
    'python3',
 | 
			
		||||
 | 
			
		||||
@ -31,9 +31,9 @@ Be careful when running untrusted models. This classification includes models cr
 | 
			
		||||
 | 
			
		||||
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
 | 
			
		||||
 | 
			
		||||
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
 | 
			
		||||
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) with `weights_only=True` is also secure to our knowledge even though it offers significantly larger surface of attack. Loading un-trusted checkpoint with `weights_only=False` MUST never be done.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
 | 
			
		||||
 | 
			
		||||
Important Note: The trustworthiness of a model is not binary. You must always determine the proper level of caution depending on the specific model and how it matches your use case and risk tolerance.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -38,7 +38,7 @@ set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT)
 | 
			
		||||
 | 
			
		||||
configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
 | 
			
		||||
# TODO: Do not generate CUDAConfig.h for ROCm BUILDS
 | 
			
		||||
# At the moment, `jit_macros.h` include CUDAConfig.h for both CUDA and HIP builds
 | 
			
		||||
# At the moment, `jit_macors.h` include CUDAConfig.h for both CUDA and HIP builds
 | 
			
		||||
if(USE_CUDA OR USE_ROCM)
 | 
			
		||||
  configure_file(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h")
 | 
			
		||||
endif()
 | 
			
		||||
@ -289,15 +289,14 @@ IF(USE_FBGEMM_GENAI)
 | 
			
		||||
 | 
			
		||||
    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
 | 
			
		||||
 | 
			
		||||
    set(fbgemm_genai_cuh
 | 
			
		||||
    set(fbgemm_genai_mx8mx8bf16_grouped
 | 
			
		||||
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
 | 
			
		||||
      "${FBGEMM_GENAI_SRCS}/"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    target_include_directories(fbgemm_genai PRIVATE
 | 
			
		||||
      ${FBGEMM_THIRD_PARTY}/cutlass/include
 | 
			
		||||
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
 | 
			
		||||
      ${fbgemm_genai_cuh}
 | 
			
		||||
      ${fbgemm_genai_mx8mx8bf16_grouped}
 | 
			
		||||
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
 | 
			
		||||
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
 | 
			
		||||
    )
 | 
			
		||||
@ -314,14 +313,13 @@ IF(USE_FBGEMM_GENAI)
 | 
			
		||||
 | 
			
		||||
    # Add additional HIPCC compiler flags for performance
 | 
			
		||||
    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
 | 
			
		||||
      -mllvm
 | 
			
		||||
      -amdgpu-coerce-illegal-types=1
 | 
			
		||||
      -mllvm
 | 
			
		||||
      -enable-post-misched=0
 | 
			
		||||
      -mllvm
 | 
			
		||||
      -greedy-reverse-local-assignment=1
 | 
			
		||||
      -fhip-new-launch-api)
 | 
			
		||||
    if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0")
 | 
			
		||||
        list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
 | 
			
		||||
      endif()
 | 
			
		||||
 | 
			
		||||
    # Only compile for gfx942 for now.
 | 
			
		||||
    # This is rather hacky, I could not figure out a clean solution :(
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,6 @@
 | 
			
		||||
#include <ATen/detail/MPSHooksInterface.h>
 | 
			
		||||
#include <ATen/detail/MTIAHooksInterface.h>
 | 
			
		||||
#include <ATen/detail/PrivateUse1HooksInterface.h>
 | 
			
		||||
#include <ATen/detail/XLAHooksInterface.h>
 | 
			
		||||
#include <ATen/detail/XPUHooksInterface.h>
 | 
			
		||||
#include <c10/core/QEngine.h>
 | 
			
		||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
 | 
			
		||||
@ -89,8 +88,6 @@ class TORCH_API Context {
 | 
			
		||||
      return at::detail::getHIPHooks();
 | 
			
		||||
    } else if (opt_device_type == at::kHPU) {
 | 
			
		||||
      return at::detail::getHPUHooks();
 | 
			
		||||
    } else if (opt_device_type == at::kXLA) {
 | 
			
		||||
      return at::detail::getXLAHooks();
 | 
			
		||||
    } else {
 | 
			
		||||
      TORCH_CHECK(
 | 
			
		||||
          false,
 | 
			
		||||
@ -199,7 +196,7 @@ class TORCH_API Context {
 | 
			
		||||
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
 | 
			
		||||
  }
 | 
			
		||||
  static bool hasXLA() {
 | 
			
		||||
    return detail::getXLAHooks().hasXLA();
 | 
			
		||||
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
 | 
			
		||||
  }
 | 
			
		||||
  static bool hasXPU() {
 | 
			
		||||
    return detail::getXPUHooks().hasXPU();
 | 
			
		||||
 | 
			
		||||
@ -122,7 +122,7 @@ void FunctionalTensorWrapper::freeze_storage() const {
 | 
			
		||||
//          |   have their own storages, but backends like functorch      |
 | 
			
		||||
//         \/   are allowed to re-alias underneath the pass               \/
 | 
			
		||||
// . - - - - - - - - - - - - - .                             . - - - - - - - - - - - - - - - .
 | 
			
		||||
// |    underlying_storage     |                             |      underlying_storage       |
 | 
			
		||||
// |    underyling_storage     |                             |      underyling_storage       |
 | 
			
		||||
// . - - - - - - - - - - - - - .                             . - - - - - - - - - - - - - - - .
 | 
			
		||||
//
 | 
			
		||||
// This constructor is only used by view ops.
 | 
			
		||||
 | 
			
		||||
@ -1534,7 +1534,7 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
 | 
			
		||||
 | 
			
		||||
  // XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
 | 
			
		||||
  // Nothing beyond this point is important for meta functions, so it's fine to exit early here.
 | 
			
		||||
  // Extend the condition to MAIA tensors as MAIA tensors also don't have storage.
 | 
			
		||||
  // Extend the condition to MAIA tesnors as MAIA tensors also don't have storage.
 | 
			
		||||
  if (privateuse1_without_storage  ||
 | 
			
		||||
      common_device_.type() == DeviceType::XLA  ||
 | 
			
		||||
      common_device_.type() == DeviceType::IPU  ||
 | 
			
		||||
 | 
			
		||||
@ -39,7 +39,7 @@ struct HostBlock {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <typename B>
 | 
			
		||||
struct alignas(hardware_destructive_interference_size) FreeBlockList {
 | 
			
		||||
struct alignas(64) FreeBlockList {
 | 
			
		||||
  std::mutex mutex_;
 | 
			
		||||
  std::deque<B*> list_;
 | 
			
		||||
};
 | 
			
		||||
@ -94,11 +94,11 @@ struct PinnedReserveSegment {
 | 
			
		||||
struct TORCH_API HostStats {
 | 
			
		||||
  // COUNT: total allocations (active)
 | 
			
		||||
  Stat active_requests;
 | 
			
		||||
  // SUM: bytes allocated/reserved by this memory allocator. (active)
 | 
			
		||||
  // SUM: bytes allocated/reserved by this memory alocator. (active)
 | 
			
		||||
  Stat active_bytes;
 | 
			
		||||
  // COUNT: total allocations (active + free)
 | 
			
		||||
  Stat allocations;
 | 
			
		||||
  // SUM: bytes allocated/reserved by this memory allocator. This accounts
 | 
			
		||||
  // SUM: bytes allocated/reserved by this memory alocator. This accounts
 | 
			
		||||
  // for both free and in-use blocks.
 | 
			
		||||
  Stat allocated_bytes;
 | 
			
		||||
 | 
			
		||||
@ -122,12 +122,12 @@ struct TORCH_API HostStats {
 | 
			
		||||
// Struct containing memory allocator summary statistics for host, as they
 | 
			
		||||
// are staged for reporting. This is a temporary struct that is used to
 | 
			
		||||
// avoid locking the allocator while collecting stats.
 | 
			
		||||
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
 | 
			
		||||
struct alignas(64) HostStatsStaged {
 | 
			
		||||
  std::mutex timing_mutex_;
 | 
			
		||||
  // COUNT: total allocations (active + free)
 | 
			
		||||
  // LOCK: access to this stat is protected by the allocator's blocks_mutex_
 | 
			
		||||
  Stat allocations;
 | 
			
		||||
  // SUM: bytes allocated/reserved by this memory allocator. This accounts
 | 
			
		||||
  // SUM: bytes allocated/reserved by this memory alocator. This accounts
 | 
			
		||||
  // for both free and in-use blocks.
 | 
			
		||||
  Stat allocated_bytes;
 | 
			
		||||
  // COUNT: number of allocations per bucket (active)
 | 
			
		||||
@ -455,7 +455,7 @@ struct CachingHostAllocatorImpl {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void resetAccumulatedStats() {
 | 
			
		||||
    // Resetting accumulated memory stats requires concurrently holding both the
 | 
			
		||||
    // Reseting accumulated memory stats requires concurrently holding both the
 | 
			
		||||
    // free list mutexes and the blocks mutex. Previously, this was only done in
 | 
			
		||||
    // empty_cache function.
 | 
			
		||||
    for (size_t i = 0; i < free_list_.size(); ++i) {
 | 
			
		||||
@ -482,7 +482,7 @@ struct CachingHostAllocatorImpl {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void resetPeakStats() {
 | 
			
		||||
    // Resetting peak memory stats requires concurrently holding both the
 | 
			
		||||
    // Reseting peak memory stats requires concurrently holding both the
 | 
			
		||||
    // free list mutexes and the blocks mutex. Previously, this was only done in
 | 
			
		||||
    // empty_cache function.
 | 
			
		||||
    for (size_t i = 0; i < free_list_.size(); ++i) {
 | 
			
		||||
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
 | 
			
		||||
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
 | 
			
		||||
  alignas(64) std::mutex blocks_mutex_;
 | 
			
		||||
  ska::flat_hash_set<B*> blocks_; // block list
 | 
			
		||||
  ska::flat_hash_map<void*, B*> ptr_to_block_;
 | 
			
		||||
 | 
			
		||||
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
 | 
			
		||||
  // size. This allows us to quickly find a free block of the right size.
 | 
			
		||||
  // We use deque to store per size free list and guard the list with its own
 | 
			
		||||
  // mutex.
 | 
			
		||||
  alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
 | 
			
		||||
  alignas(64) std::vector<FreeBlockList<B>> free_list_ =
 | 
			
		||||
      std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
 | 
			
		||||
 | 
			
		||||
  alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
 | 
			
		||||
  alignas(64) std::mutex events_mutex_;
 | 
			
		||||
  std::deque<std::pair<E, B*>> events_; // event queue paired with block
 | 
			
		||||
 | 
			
		||||
  // Indicates whether the object is active.
 | 
			
		||||
  // Set to false in the destructor to signal background threads to stop.
 | 
			
		||||
  std::atomic<bool> active_{true};
 | 
			
		||||
protected:
 | 
			
		||||
  alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
 | 
			
		||||
  alignas(64) HostStatsStaged stats_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct TORCH_API HostAllocator : public at::Allocator {
 | 
			
		||||
 | 
			
		||||
@ -59,7 +59,9 @@ struct TORCH_API Generator {
 | 
			
		||||
 | 
			
		||||
  explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
 | 
			
		||||
   : impl_(std::move(gen_impl)) {
 | 
			
		||||
    TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
 | 
			
		||||
    if (impl_.get() == nullptr) {
 | 
			
		||||
      throw std::runtime_error("GeneratorImpl with nullptr is not supported");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool operator==(const Generator& rhs) const {
 | 
			
		||||
 | 
			
		||||
@ -111,7 +111,9 @@ class TORCH_API TensorBase {
 | 
			
		||||
  explicit TensorBase(
 | 
			
		||||
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
 | 
			
		||||
      : impl_(std::move(tensor_impl)) {
 | 
			
		||||
    TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
 | 
			
		||||
    if (impl_.get() == nullptr) {
 | 
			
		||||
      throw std::runtime_error("TensorImpl with nullptr is not supported");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  TensorBase(const TensorBase&) = default;
 | 
			
		||||
  TensorBase(TensorBase&&) noexcept = default;
 | 
			
		||||
 | 
			
		||||
@ -109,10 +109,6 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
 | 
			
		||||
  m.fallback(AUTOGRAD_FALLBACK);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
 | 
			
		||||
  m.fallback(AUTOGRAD_FALLBACK);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#undef AUTOGRAD_FALLBACK
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
@ -148,7 +148,7 @@ struct TORCH_API ClassType : public NamedType {
 | 
			
		||||
 | 
			
		||||
  void checkNotExist(const std::string& name, const std::string& what) const;
 | 
			
		||||
 | 
			
		||||
  // Attributes are stored in a specific slot at runtime for efficiency.
 | 
			
		||||
  // Attributes are stored in a specific slot at runtime for effiency.
 | 
			
		||||
  // When emitting instructions we specify the slot so that attribute access is
 | 
			
		||||
  // a constant lookup
 | 
			
		||||
  std::optional<size_t> findAttributeSlot(const std::string& name) const {
 | 
			
		||||
@ -412,7 +412,7 @@ struct TORCH_API ClassType : public NamedType {
 | 
			
		||||
  // Holds method attributes
 | 
			
		||||
  std::weak_ptr<CompilationUnit> compilation_unit_;
 | 
			
		||||
 | 
			
		||||
  // Holds all attributes, attribute details are found on ClassAttribute
 | 
			
		||||
  // Holds all atrributes, attribute details are found on ClassAttribute
 | 
			
		||||
  std::vector<ClassAttribute> attributes_;
 | 
			
		||||
  // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef.
 | 
			
		||||
  // Never fill this without using the appropriate provideNewClassAttribute method
 | 
			
		||||
 | 
			
		||||
@ -442,17 +442,11 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
 | 
			
		||||
 | 
			
		||||
  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
 | 
			
		||||
  TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
 | 
			
		||||
  // NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
 | 
			
		||||
  // refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      dispatchKey == DispatchKey::AutogradPrivateUse1 ||
 | 
			
		||||
          !backendFallbackKernels_[idx].kernel.isValid(),
 | 
			
		||||
      "Tried to register multiple backend fallbacks for the same dispatch key ",
 | 
			
		||||
      dispatchKey,
 | 
			
		||||
      "; previous registration ",
 | 
			
		||||
      backendFallbackKernels_[idx].debug,
 | 
			
		||||
      ", new registration ",
 | 
			
		||||
      debug);
 | 
			
		||||
    !backendFallbackKernels_[idx].kernel.isValid(),
 | 
			
		||||
    "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
 | 
			
		||||
    backendFallbackKernels_[idx].debug, ", new registration ", debug
 | 
			
		||||
  );
 | 
			
		||||
  // NB: inferred function schema is always nullptr for fallbacks, as fallbacks
 | 
			
		||||
  // cannot be unboxed
 | 
			
		||||
  backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
 | 
			
		||||
@ -537,7 +531,7 @@ int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchK
 | 
			
		||||
 | 
			
		||||
  // Note: this records a sequence number for both Autograd keys, and for
 | 
			
		||||
  // non-Autograd keys where the dispatchKeySet still contains an autograd key.
 | 
			
		||||
  // This means that we might collect the same sequence number two different
 | 
			
		||||
  // This means that we might collect the same sequence nubmer two different
 | 
			
		||||
  // events if they all occurred above Autograd and still had the Autograd
 | 
			
		||||
  // dispatch key in the dispatch key set.
 | 
			
		||||
  // However, this usually doesn't happen: normally the first call will
 | 
			
		||||
 | 
			
		||||
@ -585,7 +585,7 @@ class TORCH_API OperatorHandle {
 | 
			
		||||
 | 
			
		||||
  // We need to store this iterator in order to make
 | 
			
		||||
  // Dispatcher::cleanup() fast -- it runs a lot on program
 | 
			
		||||
  // termination (and presumably library unloading).
 | 
			
		||||
  // termination (and presuambly library unloading).
 | 
			
		||||
  std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -365,7 +365,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
 | 
			
		||||
  //          For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration
 | 
			
		||||
  //          to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd].
 | 
			
		||||
  //          For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of
 | 
			
		||||
  //          its backends and ask backend extender to request a dedicated Autograd key for the backend.
 | 
			
		||||
  //          its backends and ask backend extender to request a decicated Autograd key for the backend.
 | 
			
		||||
  //          See Note [Ambiguity in AutogradOther kernel] for more details.
 | 
			
		||||
  //          A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't
 | 
			
		||||
  //          cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available)
 | 
			
		||||
 | 
			
		||||
@ -261,7 +261,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
 | 
			
		||||
    //
 | 
			
		||||
    // There are 2 cases
 | 
			
		||||
    // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'.
 | 
			
		||||
    // without the extra parenthesis, the c++ scheme parser can not parse it.
 | 
			
		||||
    // without the extra parenthesis, the c++ schem parser can not parse it.
 | 
			
		||||
    // 2. something like '-> ((str, str))'. Need extra parenthesis so the return
 | 
			
		||||
    // type is a single tuple rather than two strings.
 | 
			
		||||
    // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about
 | 
			
		||||
 | 
			
		||||
@ -68,7 +68,11 @@ Symbol InternedStrings::_symbol(const std::string& s) {
 | 
			
		||||
    return it->second;
 | 
			
		||||
 | 
			
		||||
  auto pos = s.find("::");
 | 
			
		||||
  TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
 | 
			
		||||
  if (pos == std::string::npos) {
 | 
			
		||||
    std::stringstream ss;
 | 
			
		||||
    ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
 | 
			
		||||
    throw std::runtime_error(ss.str());
 | 
			
		||||
  }
 | 
			
		||||
  Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
 | 
			
		||||
 | 
			
		||||
  Symbol sym(sym_to_info_.size());
 | 
			
		||||
@ -117,7 +121,12 @@ std::string Symbol::domainString() const {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
 | 
			
		||||
  TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
 | 
			
		||||
  if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
 | 
			
		||||
    std::ostringstream ss;
 | 
			
		||||
    ss << "Symbol: domain string is expected to be prefixed with '"
 | 
			
		||||
       << domain_prefix() << "', e.g. 'org.pytorch.aten'";
 | 
			
		||||
    throw std::runtime_error(ss.str());
 | 
			
		||||
  }
 | 
			
		||||
  std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
 | 
			
		||||
  return fromQualString(qualString);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,6 @@
 | 
			
		||||
#include <ATen/core/jit_type.h>
 | 
			
		||||
#include <ATen/core/stack.h>
 | 
			
		||||
#include <ATen/core/type_factory.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/StringUtil.h>
 | 
			
		||||
#include <c10/util/hash.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
@ -413,7 +412,7 @@ size_t IValue::hash(const IValue& v) {
 | 
			
		||||
    case Tag::Enum:
 | 
			
		||||
    case Tag::Stream:
 | 
			
		||||
    case Tag::Uninitialized:
 | 
			
		||||
      TORCH_CHECK(false,
 | 
			
		||||
      throw std::runtime_error(
 | 
			
		||||
          "unhashable type: '" + v.type()->repr_str() + "'");
 | 
			
		||||
  }
 | 
			
		||||
  // the above switch should be exhaustive
 | 
			
		||||
 | 
			
		||||
@ -1176,7 +1176,7 @@ struct TORCH_API IValue final {
 | 
			
		||||
  using HashIdentityIValueMap =
 | 
			
		||||
      std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;
 | 
			
		||||
 | 
			
		||||
  // Checks if this and rhs has a subvalues in common.
 | 
			
		||||
  // Chechs if this and rhs has a subvalues in common.
 | 
			
		||||
  // [t1,t2] and [t2, t3] returns true.
 | 
			
		||||
  bool overlaps(const IValue& rhs) const;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1501,7 +1501,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
 | 
			
		||||
  // However, the CompilationUnit holds ownership of the type's graphs, so
 | 
			
		||||
  // inserting a constant object into a Graph would create a reference cycle if
 | 
			
		||||
  // that constant object held a shared_ptr to its CU. For these objects we
 | 
			
		||||
  // instantiate them with non-owning references to its CU
 | 
			
		||||
  // instatiate them with non-owning references to its CU
 | 
			
		||||
  Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
 | 
			
		||||
    slots_.resize(numSlots);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,6 @@
 | 
			
		||||
#include <ATen/core/type_factory.h>
 | 
			
		||||
#include <ATen/core/qualified_name.h>
 | 
			
		||||
#include <c10/util/TypeList.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <c10/core/SymFloat.h>
 | 
			
		||||
#include <c10/core/SymBool.h>
 | 
			
		||||
@ -117,8 +116,10 @@ struct SingleElementType : public SharedType {
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
 | 
			
		||||
    TORCH_CHECK(this->elem, c10::str(
 | 
			
		||||
    if (!this->elem) {
 | 
			
		||||
      throw std::runtime_error(c10::str(
 | 
			
		||||
            "Can not create ", typeKindToString(Kind), " with None type"));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
@ -373,7 +374,7 @@ struct TORCH_API SymbolicShape {
 | 
			
		||||
  // Unranked shape constructor.
 | 
			
		||||
  SymbolicShape() : dims_(std::nullopt) {}
 | 
			
		||||
 | 
			
		||||
  // Known rank but unknown dimensions.
 | 
			
		||||
  // Known rank but unknown dimentions.
 | 
			
		||||
  SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) {
 | 
			
		||||
    if(!rank) {
 | 
			
		||||
      return;
 | 
			
		||||
@ -415,12 +416,16 @@ struct TORCH_API SymbolicShape {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ShapeSymbol operator[](size_t i) const {
 | 
			
		||||
    TORCH_CHECK(dims_, "Rank isn't fixed");
 | 
			
		||||
    if (!dims_) {
 | 
			
		||||
      throw std::runtime_error("Rank isn't fixed");
 | 
			
		||||
    }
 | 
			
		||||
    return (*dims_).at(i);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ShapeSymbol at(size_t i) const {
 | 
			
		||||
    TORCH_CHECK(dims_, "Rank isn't fixed");
 | 
			
		||||
    if (!dims_) {
 | 
			
		||||
      throw std::runtime_error("Rank isn't fixed");
 | 
			
		||||
    }
 | 
			
		||||
    return (*dims_).at(i);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -515,7 +520,9 @@ struct VaryingShape {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const std::optional<T> &operator[](size_t i) const {
 | 
			
		||||
    TORCH_CHECK(dims_, "Rank isn't fixed");
 | 
			
		||||
    if (!dims_) {
 | 
			
		||||
      throw std::runtime_error("Rank isn't fixed");
 | 
			
		||||
    }
 | 
			
		||||
    return (*dims_).at(i);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -884,9 +891,9 @@ struct TORCH_API ListType
 | 
			
		||||
 | 
			
		||||
  // global singleton
 | 
			
		||||
  // Given an inner type T and an identifier,
 | 
			
		||||
  // this function will return the global singleton type pointer
 | 
			
		||||
  // this function wil return the global singleton type pointer
 | 
			
		||||
  // the type List<T>.
 | 
			
		||||
  // The extra "identifier" argument is needed because we have multiple container types
 | 
			
		||||
  // The extra "identifier" argument is needed beccause we have multiple container types
 | 
			
		||||
  // that all re-use this function (List<T>, array<T, N>, etc.)
 | 
			
		||||
  static TypePtr get(const std::string& identifier, TypePtr inner);
 | 
			
		||||
 | 
			
		||||
@ -950,7 +957,9 @@ struct TORCH_API DictType : public SharedType {
 | 
			
		||||
 | 
			
		||||
  TypePtr createWithContained(
 | 
			
		||||
      std::vector<TypePtr> contained_types) const override {
 | 
			
		||||
    TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
 | 
			
		||||
    if (contained_types.size() != 2) {
 | 
			
		||||
      throw std::runtime_error("Expected 2 contained types");
 | 
			
		||||
    }
 | 
			
		||||
    return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -185,11 +185,11 @@ struct TORCH_API Type {
 | 
			
		||||
        : repr_(nullptr) {}
 | 
			
		||||
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
 | 
			
		||||
        : repr_(makeSingletonSharedPtr(p.get())) {}
 | 
			
		||||
        : repr_(p) {}
 | 
			
		||||
 | 
			
		||||
    template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
 | 
			
		||||
        : repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {}
 | 
			
		||||
        : repr_(SingletonTypePtr<T>(p.get())) {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    // We need to support construction from T* for pybind. The problem
 | 
			
		||||
@ -202,8 +202,8 @@ struct TORCH_API Type {
 | 
			
		||||
    // Case 2: if T is exactly Type, we need to do a dynamic_cast to
 | 
			
		||||
    // check if it's a SharedType and do the right thing.
 | 
			
		||||
    //
 | 
			
		||||
    // Case 3: Otherwise, T is not a SharedType. Use a singleton
 | 
			
		||||
    // pointer.
 | 
			
		||||
    // Case 3: Otherwise, T is not a SharedType. (debug-check this
 | 
			
		||||
    // assumption!) Use a singleton pointer.
 | 
			
		||||
 | 
			
		||||
    template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
 | 
			
		||||
@ -211,15 +211,15 @@ struct TORCH_API Type {
 | 
			
		||||
    template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(T* p) {
 | 
			
		||||
      if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
 | 
			
		||||
        repr_ = shared_p->shared_from_this();
 | 
			
		||||
        repr_ = Repr(shared_p->shared_from_this());
 | 
			
		||||
      } else {
 | 
			
		||||
        repr_ = makeSingletonSharedPtr(p);
 | 
			
		||||
        repr_ = Repr(p);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
 | 
			
		||||
    /* implicit */ SingletonOrSharedTypePtr(T* p)
 | 
			
		||||
        : repr_(makeSingletonSharedPtr(p)) {
 | 
			
		||||
        : repr_(p) {
 | 
			
		||||
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -230,19 +230,19 @@ struct TORCH_API Type {
 | 
			
		||||
    ~SingletonOrSharedTypePtr() = default;
 | 
			
		||||
 | 
			
		||||
    T* get() const {
 | 
			
		||||
      return repr_.get();
 | 
			
		||||
      return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    operator bool() const {
 | 
			
		||||
      return repr_ != nullptr;
 | 
			
		||||
      return repr_.isNonNull();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool operator==(std::nullptr_t) const {
 | 
			
		||||
      return repr_ == nullptr;
 | 
			
		||||
      return !repr_.isNonNull();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool operator!=(std::nullptr_t) const {
 | 
			
		||||
      return repr_ != nullptr;
 | 
			
		||||
      return repr_.isNonNull();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
 | 
			
		||||
@ -255,14 +255,138 @@ struct TORCH_API Type {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  private:
 | 
			
		||||
    // Use shared_ptr's aliasing constructor to create a non-owning pointer
 | 
			
		||||
    // to a singleton. The lifetime is tied to the null shared_ptr, so there's
 | 
			
		||||
    // no reference counting overhead for the singleton itself.
 | 
			
		||||
    static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) {
 | 
			
		||||
      return std::shared_ptr<T>(std::shared_ptr<T>(), ptr);
 | 
			
		||||
    }
 | 
			
		||||
    // NOTE: SharedPtrWrapper exists to work around a baffling bug in
 | 
			
		||||
    // nvcc; see comment in destroy() below.
 | 
			
		||||
    struct SharedPtrWrapper {
 | 
			
		||||
      SharedPtrWrapper(std::shared_ptr<T> &&x)
 | 
			
		||||
          : repr_(std::move(x)) {}
 | 
			
		||||
      std::shared_ptr<T> repr_;
 | 
			
		||||
    };
 | 
			
		||||
    union Repr {
 | 
			
		||||
      Repr() : Repr(nullptr) {}
 | 
			
		||||
 | 
			
		||||
    std::shared_ptr<T> repr_;
 | 
			
		||||
      explicit Repr(std::shared_ptr<T> x)
 | 
			
		||||
          : shared_(std::move(x)) {}
 | 
			
		||||
 | 
			
		||||
      explicit Repr(std::nullptr_t)
 | 
			
		||||
          : singletonRepr_(nullptr) {}
 | 
			
		||||
 | 
			
		||||
      explicit Repr(SingletonTypePtr<T> p)
 | 
			
		||||
          : singletonRepr_(p.get()) {}
 | 
			
		||||
 | 
			
		||||
      ~Repr() {
 | 
			
		||||
        destroy();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // NOTE: the only non-UB way to access our null state is through
 | 
			
		||||
      // rawRepr(), because our copy operation doesn't preserve which
 | 
			
		||||
      // union member is active for null pointers.
 | 
			
		||||
      Repr(const Repr& rhs) {
 | 
			
		||||
        if (rhs.isSharedAndNonNull()) {
 | 
			
		||||
          new (&shared_) SharedPtrWrapper(rhs.shared_);
 | 
			
		||||
        } else {
 | 
			
		||||
          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
 | 
			
		||||
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
 | 
			
		||||
          singletonRepr_.unused_ = nullptr;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      Repr(Repr&& rhs) noexcept {
 | 
			
		||||
        if (rhs.isSharedAndNonNull()) {
 | 
			
		||||
          new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
 | 
			
		||||
        } else {
 | 
			
		||||
          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
 | 
			
		||||
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
 | 
			
		||||
          singletonRepr_.unused_ = nullptr;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      Repr& operator=(const Repr& rhs) {
 | 
			
		||||
        if (&rhs == this) {
 | 
			
		||||
          return *this;
 | 
			
		||||
        }
 | 
			
		||||
        if (rhs.isSharedAndNonNull()) {
 | 
			
		||||
          if (isSharedAndNonNull()) {
 | 
			
		||||
            shared_ = rhs.shared_;
 | 
			
		||||
          } else {
 | 
			
		||||
            new (&shared_) SharedPtrWrapper(rhs.shared_);
 | 
			
		||||
          }
 | 
			
		||||
        } else {
 | 
			
		||||
          if (isSharedAndNonNull()) {
 | 
			
		||||
            destroy();
 | 
			
		||||
          }
 | 
			
		||||
          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
 | 
			
		||||
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
 | 
			
		||||
          singletonRepr_.unused_ = nullptr;
 | 
			
		||||
        }
 | 
			
		||||
        return *this;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      Repr& operator=(Repr&& rhs) noexcept {
 | 
			
		||||
        if (&rhs == this) {
 | 
			
		||||
          return *this;
 | 
			
		||||
        }
 | 
			
		||||
        if (rhs.isSharedAndNonNull()) {
 | 
			
		||||
          if (isSharedAndNonNull()) {
 | 
			
		||||
            shared_ = std::move(rhs.shared_);
 | 
			
		||||
          } else {
 | 
			
		||||
            new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
 | 
			
		||||
          }
 | 
			
		||||
        } else {
 | 
			
		||||
          if (isSharedAndNonNull()) {
 | 
			
		||||
            destroy();
 | 
			
		||||
          }
 | 
			
		||||
          singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
 | 
			
		||||
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
 | 
			
		||||
          singletonRepr_.unused_ = nullptr;
 | 
			
		||||
        }
 | 
			
		||||
        return *this;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      SharedPtrWrapper shared_;
 | 
			
		||||
 | 
			
		||||
      struct SingletonRepr {
 | 
			
		||||
        explicit SingletonRepr(T* s) : singleton_(s) {}
 | 
			
		||||
        T* singleton_;
 | 
			
		||||
        void* unused_ = nullptr;
 | 
			
		||||
      } singletonRepr_;
 | 
			
		||||
      struct RawRepr {
 | 
			
		||||
        void* first;
 | 
			
		||||
        void* nullIfSingleton_;
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
      // It is UB to read the singleton part of Repr if it was
 | 
			
		||||
      // constructed as a shared_ptr and vice versa, but memcpying out
 | 
			
		||||
      // the representation is always OK, so here's an accessor to obey
 | 
			
		||||
      // the letter of the law.
 | 
			
		||||
      RawRepr rawRepr() const {
 | 
			
		||||
        RawRepr repr{};
 | 
			
		||||
        memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
 | 
			
		||||
        return repr;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      bool isNonNull() const {
 | 
			
		||||
        auto repr = rawRepr();
 | 
			
		||||
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
 | 
			
		||||
        return repr.first != nullptr;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      bool isSharedAndNonNull() const {
 | 
			
		||||
        return rawRepr().nullIfSingleton_ != nullptr;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
     private:
 | 
			
		||||
      void destroy() {
 | 
			
		||||
        if (isSharedAndNonNull()) {
 | 
			
		||||
          // Without SharedPtrWrapper, this line would read
 | 
			
		||||
          // `shared_.~shared_ptr()` and nvcc would complain with
 | 
			
		||||
          // "error: expected primary-expression before '>' token"
 | 
			
		||||
          // referring to the "t" in "shared_ptr". SharedPtrWrapper
 | 
			
		||||
          // exists to work around this compiler bug.
 | 
			
		||||
          shared_.~SharedPtrWrapper();
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    } repr_;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  using TypePtr = SingletonOrSharedTypePtr<Type>;
 | 
			
		||||
 | 
			
		||||
@ -21,7 +21,7 @@ namespace c10 {
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
// The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
 | 
			
		||||
// We do this because every argument in a function schema is expected to be convertible
 | 
			
		||||
// We do this because every argument in a function schema is expected to be convertable
 | 
			
		||||
// to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of.
 | 
			
		||||
// See Note [Plumbing Keys Through The Dispatcher]
 | 
			
		||||
template<class KernelFunctor>
 | 
			
		||||
 | 
			
		||||
@ -251,7 +251,7 @@ TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnbox
 | 
			
		||||
  callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA));
 | 
			
		||||
  EXPECT_TRUE(called_kernel_cpu);
 | 
			
		||||
 | 
			
		||||
  // Ensure that dispatch key from tensor is not used here.
 | 
			
		||||
  // Ensure that disptach key from tensor is not used here.
 | 
			
		||||
  called_kernel_cpu = false;
 | 
			
		||||
  expectThrows<c10::Error>([&] {
 | 
			
		||||
    callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU));
 | 
			
		||||
 | 
			
		||||
@ -172,7 +172,7 @@ VaryingShape<Stride> TensorType::computeStrideProps(
 | 
			
		||||
  // The logic below follows what TensorIterator uses in its logic:
 | 
			
		||||
  //   1. Fast_set_up is the short-cut to identify a. channels_last and
 | 
			
		||||
  //      b. contiguous format, which is what we have in the below logic.
 | 
			
		||||
  //   2. In more general cases, it does best effort to preserve permutatoin.
 | 
			
		||||
  //   2. In more generla cases, it does best effort to preserve permutatoin.
 | 
			
		||||
  if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) {
 | 
			
		||||
    // case 1.a. short cut channels last
 | 
			
		||||
    std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,6 @@
 | 
			
		||||
#include <ATen/core/jit_type.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/env.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/flat_hash_map.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
#include <array>
 | 
			
		||||
@ -827,7 +826,9 @@ TupleType::TupleType(
 | 
			
		||||
    : NamedType(TypeKind::TupleType, std::move(name)),
 | 
			
		||||
      elements_(std::move(elements)),
 | 
			
		||||
      has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
 | 
			
		||||
        TORCH_CHECK(v, "Can not create tuple with None type");
 | 
			
		||||
        if (!v) {
 | 
			
		||||
          throw std::runtime_error("Can not create tuple with None type");
 | 
			
		||||
        }
 | 
			
		||||
        return v->hasFreeVariables();
 | 
			
		||||
      })), schema_(std::move(schema)) {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -104,6 +104,71 @@ class Vectorized<float> {
 | 
			
		||||
    }
 | 
			
		||||
    return b;
 | 
			
		||||
  }
 | 
			
		||||
  // Implementation is picked from
 | 
			
		||||
  // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
 | 
			
		||||
  inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
 | 
			
		||||
    const auto c1 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
 | 
			
		||||
    const auto c2 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
 | 
			
		||||
    const auto c3 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
 | 
			
		||||
    const auto c4 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
 | 
			
		||||
    const auto c5 =
 | 
			
		||||
        svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
 | 
			
		||||
    const auto shift = svreinterpret_f32_u32(
 | 
			
		||||
        svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
 | 
			
		||||
    const auto inv_ln2 = svreinterpret_f32_u32(
 | 
			
		||||
        svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
 | 
			
		||||
    const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
 | 
			
		||||
        0xbf317200)); // -ln(2) from bits  -1 to -19: -0x1.62e400p-1f
 | 
			
		||||
    const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
 | 
			
		||||
        0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
 | 
			
		||||
    const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
 | 
			
		||||
    const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
 | 
			
		||||
    const auto zero = svdup_n_f32(0.f);
 | 
			
		||||
    const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
 | 
			
		||||
    // Range reduction:
 | 
			
		||||
    //   e^x = 2^n * e^r
 | 
			
		||||
    // where:
 | 
			
		||||
    //   n = floor(x / ln(2))
 | 
			
		||||
    //   r = x - n * ln(2)
 | 
			
		||||
    //
 | 
			
		||||
    // By adding x / ln(2) with 2^23 + 127 (shift):
 | 
			
		||||
    //   * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
 | 
			
		||||
    //   forces decimal part
 | 
			
		||||
    //     of x / ln(2) out of the result. The integer part of x / ln(2) (i.e.
 | 
			
		||||
    //     n) + 127 will occupy the whole fraction part of z in FP32 format.
 | 
			
		||||
    //     Subtracting 2^23 + 127 (shift) from z will result in the integer part
 | 
			
		||||
    //     of x / ln(2) (i.e. n) because the decimal part has been pushed out
 | 
			
		||||
    //     and lost.
 | 
			
		||||
    //   * The addition of 127 makes the FP32 fraction part of z ready to be
 | 
			
		||||
    //   used as the exponent
 | 
			
		||||
    //     in FP32 format. Left shifting z by 23 bits will result in 2^n.
 | 
			
		||||
    const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
 | 
			
		||||
    const auto n = svsub_f32_z(pg, z, shift);
 | 
			
		||||
    const auto scale = svreinterpret_f32_u32(
 | 
			
		||||
        svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
 | 
			
		||||
    // The calculation of n * ln(2) is done using 2 steps to achieve accuracy
 | 
			
		||||
    // beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in
 | 
			
		||||
    // term of accuracy and performance.
 | 
			
		||||
    const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
 | 
			
		||||
    const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
 | 
			
		||||
    // Compute the truncated Taylor series of e^r.
 | 
			
		||||
    //   poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
 | 
			
		||||
    const auto r2 = svmul_f32_z(pg, r, r);
 | 
			
		||||
    const auto p1 = svmul_f32_z(pg, c1, r);
 | 
			
		||||
    const auto p23 = svmla_f32_z(pg, c2, c3, r);
 | 
			
		||||
    const auto p45 = svmla_f32_z(pg, c4, c5, r);
 | 
			
		||||
    const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
 | 
			
		||||
    const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
 | 
			
		||||
    auto poly = svmla_f32_z(pg, scale, p12345, scale);
 | 
			
		||||
    // Handle underflow and overflow.
 | 
			
		||||
    poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
 | 
			
		||||
    poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
 | 
			
		||||
    return poly;
 | 
			
		||||
  }
 | 
			
		||||
  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
 | 
			
		||||
    if (count == size())
 | 
			
		||||
      return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
 | 
			
		||||
@ -248,41 +313,11 @@ class Vectorized<float> {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
 | 
			
		||||
  }
 | 
			
		||||
  // Implementation copied from Arm Optimized Routines:
 | 
			
		||||
  // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
 | 
			
		||||
  Vectorized<float> exp_u20() const {
 | 
			
		||||
    // special case to handle special inputs that are too large or too small
 | 
			
		||||
    // i.e. where there's at least one element x, s.t. |x| >= 87.3...
 | 
			
		||||
    svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f);
 | 
			
		||||
    if (svptest_any(svptrue_b32(), is_special_case)) {
 | 
			
		||||
      return exp();
 | 
			
		||||
    }
 | 
			
		||||
    const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
 | 
			
		||||
    const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
 | 
			
		||||
    const svfloat32_t c1 = svdup_n_f32(0.5f);
 | 
			
		||||
    const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
 | 
			
		||||
 | 
			
		||||
    const float shift = 0x1.803f8p17f;
 | 
			
		||||
 | 
			
		||||
    /* n = round(x/(ln2/N)).  */
 | 
			
		||||
    svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift);
 | 
			
		||||
    svfloat32_t n = svsub_x(svptrue_b32(), z, shift);
 | 
			
		||||
 | 
			
		||||
    /* r = x - n*ln2/N.  */
 | 
			
		||||
    svfloat32_t r = values;
 | 
			
		||||
    r = svmls_x(svptrue_b32(), r, n, ln2_hi);
 | 
			
		||||
    r = svmls_x(svptrue_b32(), r, n, ln2_lo);
 | 
			
		||||
 | 
			
		||||
    /* scale = 2^(n/N).  */
 | 
			
		||||
    svfloat32_t scale = svexpa(svreinterpret_u32(z));
 | 
			
		||||
 | 
			
		||||
    /* poly(r) = exp(r) - 1 ~= r + 0.5 r^2.  */
 | 
			
		||||
    svfloat32_t r2 = svmul_x(svptrue_b32(), r, r);
 | 
			
		||||
    svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
 | 
			
		||||
    return svmla_x(svptrue_b32(), scale, scale, poly);
 | 
			
		||||
    return exp();
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<float> fexp_u20() const {
 | 
			
		||||
    return exp_u20();
 | 
			
		||||
    return exp();
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
 | 
			
		||||
@ -418,11 +453,9 @@ class Vectorized<float> {
 | 
			
		||||
        ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
 | 
			
		||||
 | 
			
		||||
    // Step 2: Calculate exp(2 * x), where x is the clamped value.
 | 
			
		||||
    // svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of
 | 
			
		||||
    // the result (via Vectorized<float>, then auto-converts back to
 | 
			
		||||
    // svfloat32_t).
 | 
			
		||||
    svfloat32_t exp2x =
 | 
			
		||||
        Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20();
 | 
			
		||||
    // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
 | 
			
		||||
    // the result.
 | 
			
		||||
    svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
 | 
			
		||||
 | 
			
		||||
    // Step 3: Calculate the numerator of the tanh function, which is exp(2x)
 | 
			
		||||
    // - 1.
 | 
			
		||||
 | 
			
		||||
@ -6,11 +6,9 @@
 | 
			
		||||
#ifdef __aarch64__
 | 
			
		||||
#if !defined(CPU_CAPABILITY_SVE)
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
 | 
			
		||||
 | 
			
		||||
@ -354,47 +354,9 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
 | 
			
		||||
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
 | 
			
		||||
  Vectorized frac() const;
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
 | 
			
		||||
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  Vectorized<c10::BFloat16> neg() const {
 | 
			
		||||
    return -values;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<c10::BFloat16> reciprocal() const {
 | 
			
		||||
    return 1.0f / values;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<c10::BFloat16> operator==(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values == other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator!=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values != other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator<(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values < other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator<=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values <= other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator>(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values > other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator>=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values >= other.values;
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
 | 
			
		||||
@ -402,7 +364,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 | 
			
		||||
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 | 
			
		||||
@ -451,52 +412,28 @@ template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator+(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x + y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator-(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x - y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator*(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x * y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator/(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x / y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// frac. Implement this here so we can use subtraction
 | 
			
		||||
@ -607,19 +544,12 @@ Vectorized<c10::BFloat16> inline fmadd(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return x * y + z;
 | 
			
		||||
#else
 | 
			
		||||
  // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16!  Also,
 | 
			
		||||
  // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
 | 
			
		||||
  // elements, not the bottom and top half, so they don't seem
 | 
			
		||||
  // particularly useful here. Ideally we would include dot product in
 | 
			
		||||
  // the Vectorized interface...
 | 
			
		||||
  return a * b + c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -627,15 +557,8 @@ Vectorized<c10::BFloat16> inline fnmadd(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return (-x) * y + z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return -a * b + c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -643,15 +566,8 @@ Vectorized<c10::BFloat16> inline fmsub(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return x * y - z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return a * b - c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -659,15 +575,8 @@ Vectorized<c10::BFloat16> inline fnmsub(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return (-x) * y - z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return -a * b - c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
 | 
			
		||||
 | 
			
		||||
@ -5,129 +5,6 @@
 | 
			
		||||
namespace at::vec {
 | 
			
		||||
inline namespace CPU_CAPABILITY {
 | 
			
		||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
 | 
			
		||||
 | 
			
		||||
// Enable auto-vectorization for GCC-13+ and clang-17+
 | 
			
		||||
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
 | 
			
		||||
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
 | 
			
		||||
 | 
			
		||||
template <typename from_type, typename to_type>
 | 
			
		||||
inline void convertImpl(
 | 
			
		||||
    const from_type* __restrict src,
 | 
			
		||||
    to_type* __restrict dst,
 | 
			
		||||
    int64_t n) {
 | 
			
		||||
  uint64_t len = static_cast<uint64_t>(n);
 | 
			
		||||
  for (uint64_t i = 0; i < len; i++) {
 | 
			
		||||
    dst[i] = static_cast<to_type>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define CONVERT_TEMPLATE(from_type, to_type)                           \
 | 
			
		||||
  template <>                                                          \
 | 
			
		||||
  inline void convert(const from_type* src, to_type* dst, int64_t n) { \
 | 
			
		||||
    return convertImpl<from_type, to_type>(src, dst, n);               \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(float, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, float)
 | 
			
		||||
CONVERT_TEMPLATE(float, double)
 | 
			
		||||
CONVERT_TEMPLATE(double, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, float)
 | 
			
		||||
CONVERT_TEMPLATE(double, double)
 | 
			
		||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 | 
			
		||||
 | 
			
		||||
#define CONVERT_FROM_FP16_TEMPLATE(to_type)                            \
 | 
			
		||||
  template <>                                                          \
 | 
			
		||||
  inline void convert(const at::Half* src, to_type* dst, int64_t n) {  \
 | 
			
		||||
    const float16_t* srcPtr = reinterpret_cast<const float16_t*>(src); \
 | 
			
		||||
    return convertImpl<float16_t, to_type>(srcPtr, dst, n);            \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#define CONVERT_TO_FP16_TEMPLATE(from_type)                             \
 | 
			
		||||
  template <>                                                           \
 | 
			
		||||
  inline void convert(const from_type* src, at::Half* dst, int64_t n) { \
 | 
			
		||||
    float16_t* dstPtr = reinterpret_cast<float16_t*>(dst);              \
 | 
			
		||||
    return convertImpl<from_type, float16_t>(src, dstPtr, n);           \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
CONVERT_FROM_FP16_TEMPLATE(uint8_t)
 | 
			
		||||
CONVERT_FROM_FP16_TEMPLATE(int8_t)
 | 
			
		||||
CONVERT_FROM_FP16_TEMPLATE(int16_t)
 | 
			
		||||
CONVERT_FROM_FP16_TEMPLATE(int32_t)
 | 
			
		||||
CONVERT_FROM_FP16_TEMPLATE(int64_t)
 | 
			
		||||
CONVERT_FROM_FP16_TEMPLATE(float16_t)
 | 
			
		||||
CONVERT_FROM_FP16_TEMPLATE(float)
 | 
			
		||||
CONVERT_FROM_FP16_TEMPLATE(double)
 | 
			
		||||
CONVERT_TO_FP16_TEMPLATE(uint8_t)
 | 
			
		||||
CONVERT_TO_FP16_TEMPLATE(int8_t)
 | 
			
		||||
CONVERT_TO_FP16_TEMPLATE(int16_t)
 | 
			
		||||
CONVERT_TO_FP16_TEMPLATE(int32_t)
 | 
			
		||||
CONVERT_TO_FP16_TEMPLATE(int64_t)
 | 
			
		||||
CONVERT_TO_FP16_TEMPLATE(float)
 | 
			
		||||
CONVERT_TO_FP16_TEMPLATE(double)
 | 
			
		||||
#endif
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, float)
 | 
			
		||||
CONVERT_TEMPLATE(bfloat16_t, double)
 | 
			
		||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(float, bfloat16_t)
 | 
			
		||||
CONVERT_TEMPLATE(double, bfloat16_t)
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
template <typename src_t>
 | 
			
		||||
struct VecConvert<
 | 
			
		||||
    float,
 | 
			
		||||
 | 
			
		||||
@ -1,586 +0,0 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/intrinsics.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec_base.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
#include <cmath>
 | 
			
		||||
 | 
			
		||||
namespace at::vec {
 | 
			
		||||
// Note [CPU_CAPABILITY namespace]
 | 
			
		||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
// This header, and all of its subheaders, will be compiled with
 | 
			
		||||
// different architecture flags for each supported set of vector
 | 
			
		||||
// intrinsics. So we need to make sure they aren't inadvertently
 | 
			
		||||
// linked together. We do this by declaring objects in an `inline
 | 
			
		||||
// namespace` which changes the name mangling, but can still be
 | 
			
		||||
// accessed as `at::vec`.
 | 
			
		||||
inline namespace CPU_CAPABILITY {
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
class Vectorized<double> {
 | 
			
		||||
 private:
 | 
			
		||||
  float64x2_t values;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  using value_type = double;
 | 
			
		||||
  using size_type = int;
 | 
			
		||||
  static constexpr size_type size() {
 | 
			
		||||
    return 2;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized() {
 | 
			
		||||
    values = vdupq_n_f64(0.0);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized(float64x2_t v) : values(v) {}
 | 
			
		||||
  Vectorized(double val) {
 | 
			
		||||
    values = vdupq_n_f64(val);
 | 
			
		||||
  }
 | 
			
		||||
  template <
 | 
			
		||||
      typename... Args,
 | 
			
		||||
      typename = std::enable_if_t<(sizeof...(Args) == size())>>
 | 
			
		||||
  Vectorized(Args... vals) {
 | 
			
		||||
    __at_align__ double buffer[size()] = {vals...};
 | 
			
		||||
    values = vld1q_f64(buffer);
 | 
			
		||||
  }
 | 
			
		||||
  operator float64x2_t() const {
 | 
			
		||||
    return values;
 | 
			
		||||
  }
 | 
			
		||||
  template <int64_t mask>
 | 
			
		||||
  static Vectorized<double> blend(
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b) {
 | 
			
		||||
    // Build an array of flags: each bit of element is 1 if the corresponding
 | 
			
		||||
    // bit in 'mask' is set, 0 otherwise.
 | 
			
		||||
    uint64x2_t maskArray = {
 | 
			
		||||
        (mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0,
 | 
			
		||||
        (mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0};
 | 
			
		||||
    // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
    return vbslq_f64(maskArray, b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
  static Vectorized<double> blendv(
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b,
 | 
			
		||||
      const Vectorized<double>& mask_) {
 | 
			
		||||
    return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename step_t>
 | 
			
		||||
  static Vectorized<double> arange(
 | 
			
		||||
      double base = 0.,
 | 
			
		||||
      step_t step = static_cast<step_t>(1)) {
 | 
			
		||||
    return {base, base + static_cast<double>(step)};
 | 
			
		||||
  }
 | 
			
		||||
  static inline Vectorized<double> set(
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b,
 | 
			
		||||
      int64_t count = size()) {
 | 
			
		||||
    if (count == 0) {
 | 
			
		||||
      return a;
 | 
			
		||||
    } else if (count >= 2) {
 | 
			
		||||
      return b;
 | 
			
		||||
    } else {
 | 
			
		||||
      float64x2_t c = {b.values[0], a.values[1]};
 | 
			
		||||
      return c;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
 | 
			
		||||
    if (count == size()) {
 | 
			
		||||
      return vld1q_f64(reinterpret_cast<const double*>(ptr));
 | 
			
		||||
    } else if (count == 1) {
 | 
			
		||||
      float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr));
 | 
			
		||||
      float64x1_t z = {0.0};
 | 
			
		||||
      return vcombine_f64(x, z);
 | 
			
		||||
    } else {
 | 
			
		||||
      return vdupq_n_f64(0.0);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  void store(void* ptr, int64_t count = size()) const {
 | 
			
		||||
    if (count == size()) {
 | 
			
		||||
      vst1q_f64(reinterpret_cast<double*>(ptr), values);
 | 
			
		||||
    } else if (count == 1) {
 | 
			
		||||
      vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  const double& operator[](int idx) const = delete;
 | 
			
		||||
  double& operator[](int idx) = delete;
 | 
			
		||||
  int64_t zero_mask() const {
 | 
			
		||||
    // returns an integer mask where all zero elements are translated to 1-bit
 | 
			
		||||
    // and others are translated to 0-bit
 | 
			
		||||
    uint64x2_t cmpReg = vceqzq_f64(values);
 | 
			
		||||
    uint64x2_t mask = {1, 2};
 | 
			
		||||
    uint64x2_t res = vandq_u64(cmpReg, mask);
 | 
			
		||||
    return res[0] | res[1];
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> isnan() const {
 | 
			
		||||
    // NaN check
 | 
			
		||||
    return vreinterpretq_f64_u32(
 | 
			
		||||
        vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values))));
 | 
			
		||||
  }
 | 
			
		||||
  bool has_inf_nan() const {
 | 
			
		||||
    Vectorized<double> x = vsubq_f64(values, values);
 | 
			
		||||
    float64x2_t r = x.isnan();
 | 
			
		||||
    uint64x2_t u = vreinterpretq_u64_f64(r);
 | 
			
		||||
    return u[0] | u[1];
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> map(double (*f)(double)) const {
 | 
			
		||||
    float64x2_t result;
 | 
			
		||||
    result[0] = f(values[0]);
 | 
			
		||||
    result[1] = f(values[1]);
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> map2(
 | 
			
		||||
      const Vectorized<double>& second,
 | 
			
		||||
      double (*const f)(double, double)) const {
 | 
			
		||||
    float64x2_t result;
 | 
			
		||||
    result[0] = f(values[0], second.values[0]);
 | 
			
		||||
    result[1] = f(values[1], second.values[1]);
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> abs() const {
 | 
			
		||||
    return vabsq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> angle() const {
 | 
			
		||||
    auto zero = Vectorized<double>(0.0);
 | 
			
		||||
    auto pi = Vectorized<double>(c10::pi<double>);
 | 
			
		||||
    auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values)));
 | 
			
		||||
    return blendv(tmp, *this, isnan());
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> real() const {
 | 
			
		||||
    return *this;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> imag() const {
 | 
			
		||||
    return Vectorized<double>(0.0);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> conj() const {
 | 
			
		||||
    return *this;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> acos() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> acosh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> asin() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> asinh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> atan() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> atanh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_atan2d2_u10(values, b)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_b[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        b.store(tmp_b);
 | 
			
		||||
        for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
          tmp[i] = std::atan2(tmp[i], tmp_b[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} Vectorized<double> copysign(const Vectorized<double>& sign) const {
 | 
			
		||||
      USE_SLEEF(
 | 
			
		||||
          { return Vectorized<double>(Sleef_copysignd2(values, sign)); },
 | 
			
		||||
          {
 | 
			
		||||
            __at_align__ double tmp[size()];
 | 
			
		||||
            __at_align__ double tmp_sign[size()];
 | 
			
		||||
            store(tmp);
 | 
			
		||||
            sign.store(tmp_sign);
 | 
			
		||||
            for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
              tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
 | 
			
		||||
            }
 | 
			
		||||
            return loadu(tmp);
 | 
			
		||||
          })} Vectorized<double> erf() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> erfc() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> exp() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> exp2() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> expm1() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_fmodd2(values, q)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_q[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        q.store(tmp_q);
 | 
			
		||||
        for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
          tmp[i] = std::fmod(tmp[i], tmp_q[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} Vectorized<double> hypot(const Vectorized<double>& b) const {
 | 
			
		||||
      USE_SLEEF(
 | 
			
		||||
          { return Vectorized<double>(Sleef_hypotd2_u05(values, b)); },
 | 
			
		||||
          {
 | 
			
		||||
            __at_align__ double tmp[size()];
 | 
			
		||||
            __at_align__ double tmp_b[size()];
 | 
			
		||||
            store(tmp);
 | 
			
		||||
            b.store(tmp_b);
 | 
			
		||||
            for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
              tmp[i] = std::hypot(tmp[i], tmp_b[i]);
 | 
			
		||||
            }
 | 
			
		||||
            return loadu(tmp);
 | 
			
		||||
          })} Vectorized<double> i0() const {
 | 
			
		||||
    return map(calc_i0);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_nextafterd2(values, b)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_b[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        b.store(tmp_b);
 | 
			
		||||
        for (int64_t i = 0; i < size(); ++i) {
 | 
			
		||||
          tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} Vectorized<double> log() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_logd2_u10(values)), map(std::log));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> log2() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> log10() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> log1p() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> frac() const;
 | 
			
		||||
  Vectorized<double> sin() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> sinh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> cos() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> cosh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_powd2_u10(values, b)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_b[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        b.store(tmp_b);
 | 
			
		||||
        for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
          tmp[i] = std::pow(tmp[i], tmp_b[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} // Comparison using the _CMP_**_OQ predicate.
 | 
			
		||||
          //   `O`: get false if an operand is NaN
 | 
			
		||||
          //   `Q`: do not raise if an operand is NaN
 | 
			
		||||
  Vectorized<double> tan() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> tanh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> lgamma() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> erfinv() const {
 | 
			
		||||
    return map(calc_erfinv);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> exp_u20() const {
 | 
			
		||||
    return exp();
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> fexp_u20() const {
 | 
			
		||||
    return exp();
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> i0e() const {
 | 
			
		||||
    return map(calc_i0e);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> digamma() const {
 | 
			
		||||
    return map(calc_digamma);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> igamma(const Vectorized<double>& x) const {
 | 
			
		||||
    __at_align__ double tmp[size()];
 | 
			
		||||
    __at_align__ double tmp_x[size()];
 | 
			
		||||
    store(tmp);
 | 
			
		||||
    x.store(tmp_x);
 | 
			
		||||
    for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
 | 
			
		||||
    }
 | 
			
		||||
    return loadu(tmp);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> igammac(const Vectorized<double>& x) const {
 | 
			
		||||
    __at_align__ double tmp[size()];
 | 
			
		||||
    __at_align__ double tmp_x[size()];
 | 
			
		||||
    store(tmp);
 | 
			
		||||
    x.store(tmp_x);
 | 
			
		||||
    for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
 | 
			
		||||
    }
 | 
			
		||||
    return loadu(tmp);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> ceil() const {
 | 
			
		||||
    return vrndpq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> floor() const {
 | 
			
		||||
    return vrndmq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> neg() const {
 | 
			
		||||
    return vnegq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> round() const {
 | 
			
		||||
    return vrndiq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> trunc() const {
 | 
			
		||||
    return vrndq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> sqrt() const {
 | 
			
		||||
    return vsqrtq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> reciprocal() const {
 | 
			
		||||
    return vdivq_f64(vdupq_n_f64(1.0), values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> rsqrt() const {
 | 
			
		||||
    return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values));
 | 
			
		||||
  }
 | 
			
		||||
  double reduce_add() const {
 | 
			
		||||
    return vaddvq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  double reduce_max() const {
 | 
			
		||||
    return vmaxvq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> operator==(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vceqq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator!=(const Vectorized<double>& other) const {
 | 
			
		||||
    float64x2_t r0 = vreinterpretq_f64_u32(
 | 
			
		||||
        vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values))));
 | 
			
		||||
    return Vectorized<double>(r0);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator<(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vcltq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator<=(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vcleq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator>(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vcgtq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> operator>=(const Vectorized<double>& other) const {
 | 
			
		||||
    return Vectorized<double>(
 | 
			
		||||
        vreinterpretq_f64_u64(vcgeq_f64(values, other.values)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<double> eq(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> ne(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> gt(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> ge(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> lt(const Vectorized<double>& other) const;
 | 
			
		||||
  Vectorized<double> le(const Vectorized<double>& other) const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator+(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vaddq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator-(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vsubq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator*(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vmulq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator/(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vdivq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// frac. Implement this here so we can use subtraction
 | 
			
		||||
Vectorized<double> inline Vectorized<double>::frac() const {
 | 
			
		||||
  return *this - this->trunc();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
 | 
			
		||||
// either input is a NaN.
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline maximum(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vmaxq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
 | 
			
		||||
// either input is a NaN.
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline minimum(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vminq_f64(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline clamp(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& min,
 | 
			
		||||
    const Vectorized<double>& max) {
 | 
			
		||||
  return vminq_f64(max, vmaxq_f64(min, a));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline clamp_max(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& max) {
 | 
			
		||||
  return vminq_f64(max, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline clamp_min(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& min) {
 | 
			
		||||
  return vmaxq_f64(min, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator&(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vreinterpretq_f64_u64(
 | 
			
		||||
      vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator|(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vreinterpretq_f64_u64(
 | 
			
		||||
      vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline operator^(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b) {
 | 
			
		||||
  return vreinterpretq_f64_u64(
 | 
			
		||||
      veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::eq(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this == other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::ne(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this != other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::gt(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this > other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::ge(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this >= other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::lt(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this < other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<double> Vectorized<double>::le(
 | 
			
		||||
    const Vectorized<double>& other) const {
 | 
			
		||||
  return (*this <= other) & Vectorized<double>(1.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline fmadd(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b,
 | 
			
		||||
    const Vectorized<double>& c) {
 | 
			
		||||
  return vfmaq_f64(c, a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline fnmadd(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b,
 | 
			
		||||
    const Vectorized<double>& c) {
 | 
			
		||||
  return vfmsq_f64(c, a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline fmsub(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b,
 | 
			
		||||
    const Vectorized<double>& c) {
 | 
			
		||||
  return vfmaq_f64(vnegq_f64(c), a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<double> inline fnmsub(
 | 
			
		||||
    const Vectorized<double>& a,
 | 
			
		||||
    const Vectorized<double>& b,
 | 
			
		||||
    const Vectorized<double>& c) {
 | 
			
		||||
  return vfmsq_f64(vnegq_f64(c), a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace CPU_CAPABILITY
 | 
			
		||||
} // namespace at::vec
 | 
			
		||||
@ -307,49 +307,11 @@ class Vectorized<float> {
 | 
			
		||||
  DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
 | 
			
		||||
  DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
 | 
			
		||||
  DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
 | 
			
		||||
  // Implementation copied from Arm Optimized Routine
 | 
			
		||||
  // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
 | 
			
		||||
  Vectorized<float> exp_u20() const {
 | 
			
		||||
    // bail out to sleef if it's a special case:
 | 
			
		||||
    // i.e. there's an input s.t. |input| > 87.3....
 | 
			
		||||
    const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
 | 
			
		||||
    uint32x4_t cmp = vcagtq_f32(values, special_bound);
 | 
			
		||||
    if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) {
 | 
			
		||||
      return exp();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
 | 
			
		||||
    const float ln2_hi = 0x1.62e4p-1f;
 | 
			
		||||
    const float ln2_lo = 0x1.7f7d1cp-20f;
 | 
			
		||||
    const float c0 = 0x1.0e4020p-7f;
 | 
			
		||||
    const float c2 = 0x1.555e66p-3f;
 | 
			
		||||
    const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
 | 
			
		||||
 | 
			
		||||
    const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
 | 
			
		||||
    const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
 | 
			
		||||
    const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
 | 
			
		||||
    const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
 | 
			
		||||
 | 
			
		||||
    /* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
 | 
			
		||||
      x = ln2*n + r, with r in [-ln2/2, ln2/2].  */
 | 
			
		||||
 | 
			
		||||
    float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
 | 
			
		||||
    float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0);
 | 
			
		||||
    r = vfmsq_laneq_f32(r, n, ln2_c02, 1);
 | 
			
		||||
    uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
 | 
			
		||||
    float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias));
 | 
			
		||||
 | 
			
		||||
    float32x4_t r2 = vmulq_f32(r, r);
 | 
			
		||||
    float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2);
 | 
			
		||||
    float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3);
 | 
			
		||||
    q = vfmaq_f32(q, p, r2);
 | 
			
		||||
    p = vmulq_f32(c4, r);
 | 
			
		||||
    float32x4_t poly = vfmaq_f32(p, q, r2);
 | 
			
		||||
 | 
			
		||||
    return vfmaq_f32(scale, poly, scale);
 | 
			
		||||
    return exp();
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<float> fexp_u20() const {
 | 
			
		||||
    return exp_u20();
 | 
			
		||||
    return exp();
 | 
			
		||||
  }
 | 
			
		||||
  DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
 | 
			
		||||
      fmod,
 | 
			
		||||
@ -578,6 +540,42 @@ inline Vectorized<float> Vectorized<float>::le(
 | 
			
		||||
  return (*this <= other) & Vectorized<float>(1.0f);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline void convert(const float* src, int32_t* dst, int64_t n) {
 | 
			
		||||
  int64_t i;
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (i = 0; i <= (n - Vectorized<float>::size());
 | 
			
		||||
       i += Vectorized<float>::size()) {
 | 
			
		||||
    vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
 | 
			
		||||
  }
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (; i < n; i++) {
 | 
			
		||||
    dst[i] = static_cast<int32_t>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline void convert(const int32_t* src, float* dst, int64_t n) {
 | 
			
		||||
  int64_t i;
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (i = 0; i <= (n - Vectorized<float>::size());
 | 
			
		||||
       i += Vectorized<float>::size()) {
 | 
			
		||||
    vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
 | 
			
		||||
  }
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (; i < n; i++) {
 | 
			
		||||
    dst[i] = static_cast<float>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<float> inline fmadd(
 | 
			
		||||
    const Vectorized<float>& a,
 | 
			
		||||
@ -634,7 +632,8 @@ inline Vectorized<float> Vectorized<float>::erf() const {
 | 
			
		||||
  // - exp(- x * x)
 | 
			
		||||
  auto pow_2 = (*this) * (*this);
 | 
			
		||||
  auto neg_pow_2 = pow_2 ^ neg_zero_vec;
 | 
			
		||||
  auto tmp4 = neg_pow_2.exp();
 | 
			
		||||
  auto tmp4 = neg_pow_2.map(
 | 
			
		||||
      std::exp); // This can be swapped for a faster implementation of exp.
 | 
			
		||||
  auto tmp5 = tmp4 ^ neg_zero_vec;
 | 
			
		||||
  // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
 | 
			
		||||
  auto tmp6 = t * tmp5;
 | 
			
		||||
 | 
			
		||||
@ -234,7 +234,7 @@ class Vectorized<c10::Half> : public Vectorized16<
 | 
			
		||||
        vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
 | 
			
		||||
    return vaddvq_u16(bits_vec);
 | 
			
		||||
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 | 
			
		||||
    // use known working implementation.
 | 
			
		||||
    // use known working implmentation.
 | 
			
		||||
    __at_align__ value_type tmp[size()];
 | 
			
		||||
    store(tmp);
 | 
			
		||||
    int mask = 0;
 | 
			
		||||
@ -569,6 +569,46 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
 | 
			
		||||
  return (*this <= other) & Vectorized<c10::Half>(1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// These are global functions, so the defaults in vec_base.h should
 | 
			
		||||
// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available.
 | 
			
		||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 | 
			
		||||
template <>
 | 
			
		||||
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
 | 
			
		||||
  int64_t i;
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (i = 0; i <= (n - Vectorized<c10::Half>::size());
 | 
			
		||||
       i += Vectorized<c10::Half>::size()) {
 | 
			
		||||
    vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
 | 
			
		||||
  }
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (; i < n; i++) {
 | 
			
		||||
    dst[i] = static_cast<int16_t>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
 | 
			
		||||
  int64_t i;
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (i = 0; i <= (n - Vectorized<c10::Half>::size());
 | 
			
		||||
       i += Vectorized<c10::Half>::size()) {
 | 
			
		||||
    vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
 | 
			
		||||
  }
 | 
			
		||||
#ifndef __msvc_cl__
 | 
			
		||||
#pragma unroll
 | 
			
		||||
#endif
 | 
			
		||||
  for (; i < n; i++) {
 | 
			
		||||
    dst[i] = static_cast<float16_t>(src[i]);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::Half> inline fmadd(
 | 
			
		||||
    const Vectorized<c10::Half>& a,
 | 
			
		||||
 | 
			
		||||
@ -1,378 +0,0 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/intrinsics.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec_base.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
 | 
			
		||||
namespace at::vec {
 | 
			
		||||
// Note [CPU_CAPABILITY namespace]
 | 
			
		||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
// This header, and all of its subheaders, will be compiled with
 | 
			
		||||
// different architecture flags for each supported set of vector
 | 
			
		||||
// intrinsics. So we need to make sure they aren't inadvertently
 | 
			
		||||
// linked together. We do this by declaring objects in an `inline
 | 
			
		||||
// namespace` which changes the name mangling, but can still be
 | 
			
		||||
// accessed as `at::vec`.
 | 
			
		||||
inline namespace CPU_CAPABILITY {
 | 
			
		||||
 | 
			
		||||
#define VEC_UINT_NEON_TEMPLATE(vl, bit)                                       \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
 | 
			
		||||
                                                                              \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  class Vectorized<uint##bit##_t> {                                           \
 | 
			
		||||
    using neon_type = uint##bit##x##vl##_t;                                   \
 | 
			
		||||
                                                                              \
 | 
			
		||||
   private:                                                                   \
 | 
			
		||||
    neon_type values;                                                         \
 | 
			
		||||
                                                                              \
 | 
			
		||||
   public:                                                                    \
 | 
			
		||||
    using value_type = uint##bit##_t;                                         \
 | 
			
		||||
    using size_type = int;                                                    \
 | 
			
		||||
    static constexpr size_type size() {                                       \
 | 
			
		||||
      return vl;                                                              \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized() {                                                            \
 | 
			
		||||
      values = vdupq_n_u##bit(0);                                             \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized(neon_type v) : values(v) {}                                    \
 | 
			
		||||
    Vectorized(uint##bit##_t val);                                            \
 | 
			
		||||
    template <                                                                \
 | 
			
		||||
        typename... Args,                                                     \
 | 
			
		||||
        typename = std::enable_if_t<(sizeof...(Args) == size())>>             \
 | 
			
		||||
    Vectorized(Args... vals) {                                                \
 | 
			
		||||
      __at_align__ uint##bit##_t buffer[size()] = {vals...};                  \
 | 
			
		||||
      values = vld1q_u##bit(buffer);                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    operator neon_type() const {                                              \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    static Vectorized<uint##bit##_t> loadu(                                   \
 | 
			
		||||
        const void* ptr,                                                      \
 | 
			
		||||
        uint64_t count = size());                                             \
 | 
			
		||||
    void store(void* ptr, uint64_t count = size()) const;                     \
 | 
			
		||||
    template <uint64_t mask>                                                  \
 | 
			
		||||
    static Vectorized<uint##bit##_t> blend(                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& a,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& b);                                  \
 | 
			
		||||
    static Vectorized<uint##bit##_t> blendv(                                  \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& a,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& b,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& mask_) {                             \
 | 
			
		||||
      return vbslq_u##bit(mask_.values, b, a);                                \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    template <typename step_t>                                                \
 | 
			
		||||
    static Vectorized<uint##bit##_t> arange(                                  \
 | 
			
		||||
        value_type base = 0,                                                  \
 | 
			
		||||
        step_t step = static_cast<step_t>(1));                                \
 | 
			
		||||
    static Vectorized<uint##bit##_t> set(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& a,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& b,                                   \
 | 
			
		||||
        uint64_t count = size());                                             \
 | 
			
		||||
    const uint##bit##_t& operator[](uint idx) const = delete;                 \
 | 
			
		||||
    uint##bit##_t& operator[](uint idx) = delete;                             \
 | 
			
		||||
    Vectorized<uint##bit##_t> abs() const {                                   \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> real() const {                                  \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> imag() const {                                  \
 | 
			
		||||
      return vdupq_n_u##bit(0);                                               \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> conj() const {                                  \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> neg() const {                                   \
 | 
			
		||||
      return vreinterpretq_u##bit##_s##bit(                                   \
 | 
			
		||||
          vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values)));               \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    uint##bit##_t reduce_add() const {                                        \
 | 
			
		||||
      return vaddvq_u##bit(values);                                           \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    uint##bit##_t reduce_max() const;                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator==(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vceqq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator!=(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator<(                                      \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcltq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator<=(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcleq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator>(                                      \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcgtq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator>=(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcgeq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> eq(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> ne(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> gt(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> ge(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> lt(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> le(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
  };                                                                          \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator+(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vaddq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator-(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vsubq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator&(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vandq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator|(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vorrq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator^(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return veorq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this == other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this != other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this > other) & Vectorized<uint##bit##_t>(1);                    \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this >= other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this < other) & Vectorized<uint##bit##_t>(1);                    \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this <= other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
VEC_UINT_NEON_TEMPLATE(16, 8)
 | 
			
		||||
 | 
			
		||||
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
 | 
			
		||||
  return vmaxvq_u8(values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator*(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  return vmulq_u8(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
 | 
			
		||||
  return vmvnq_u8(a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
 | 
			
		||||
    const Vectorized<uint8_t>& other) const {
 | 
			
		||||
  return ~(*this == other);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline minimum(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  return vminq_u8(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline maximum(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  return vmaxq_u8(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <uint64_t mask>
 | 
			
		||||
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  // Build an array of flags: each bit of element is 1 if the corresponding bit
 | 
			
		||||
  // in 'mask' is set, 0 otherwise.
 | 
			
		||||
  uint8x16_t maskArray = {
 | 
			
		||||
      (mask & 1LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 2LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 4LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 8LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 16LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 32LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 64LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 128LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 256LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 512LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 1024LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 2048LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 4096LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 8192LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 16384LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 32768LL) ? 0xFF : 0};
 | 
			
		||||
  // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
  return vbslq_u8(maskArray, b.values, a.values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define VEC_UINT_NEON_OPS(vl, bit)                                             \
 | 
			
		||||
  inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) {            \
 | 
			
		||||
    values = vdupq_n_u##bit(val);                                              \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu(           \
 | 
			
		||||
      const void* ptr, uint64_t count) {                                       \
 | 
			
		||||
    if (count == size()) {                                                     \
 | 
			
		||||
      return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr));        \
 | 
			
		||||
    } else {                                                                   \
 | 
			
		||||
      __at_align__ uint##bit##_t tmp_values[size()];                           \
 | 
			
		||||
      for (const auto i : c10::irange(size())) {                               \
 | 
			
		||||
        tmp_values[i] = 0;                                                     \
 | 
			
		||||
      }                                                                        \
 | 
			
		||||
      std::memcpy(                                                             \
 | 
			
		||||
          tmp_values,                                                          \
 | 
			
		||||
          reinterpret_cast<const uint##bit##_t*>(ptr),                         \
 | 
			
		||||
          count * sizeof(uint##bit##_t));                                      \
 | 
			
		||||
      return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
 | 
			
		||||
    }                                                                          \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count)      \
 | 
			
		||||
      const {                                                                  \
 | 
			
		||||
    if (count == size()) {                                                     \
 | 
			
		||||
      vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values);             \
 | 
			
		||||
    } else {                                                                   \
 | 
			
		||||
      uint##bit##_t tmp_values[size()];                                        \
 | 
			
		||||
      vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values);      \
 | 
			
		||||
      std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t));             \
 | 
			
		||||
    }                                                                          \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
VEC_UINT_NEON_OPS(16, 8)
 | 
			
		||||
 | 
			
		||||
template <typename step_t>
 | 
			
		||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
 | 
			
		||||
    uint8_t base,
 | 
			
		||||
    step_t step) {
 | 
			
		||||
  const Vectorized<uint8_t> base_vec(base);
 | 
			
		||||
  const Vectorized<uint8_t> step_vec(step);
 | 
			
		||||
  const uint8x16_t step_sizes = {
 | 
			
		||||
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
 | 
			
		||||
  return vmlaq_u8(base_vec, step_sizes, step_vec);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator>>(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  uint8x16_t x = a;
 | 
			
		||||
  uint8x16_t bound = vdupq_n_u8(8);
 | 
			
		||||
  uint8x16_t z = vminq_u8(b, bound);
 | 
			
		||||
  return x >> z;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator<<(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  uint8x16_t bound = vdupq_n_u8(8);
 | 
			
		||||
  uint8x16_t z = vminq_u8(b, bound);
 | 
			
		||||
  return vshlq_u8(a, vreinterpretq_s8_u8(z));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b,
 | 
			
		||||
    uint64_t count) {
 | 
			
		||||
  if (count == 0) {
 | 
			
		||||
    return a;
 | 
			
		||||
  } else if (count >= 16) {
 | 
			
		||||
    return b;
 | 
			
		||||
  } else {
 | 
			
		||||
    // Build an array of flags: each bit of element is 1 if the corresponding
 | 
			
		||||
    // bit in 'mask' is set, 0 otherwise.
 | 
			
		||||
    uint8x16_t maskArray = {
 | 
			
		||||
        static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
 | 
			
		||||
        0};
 | 
			
		||||
 | 
			
		||||
    // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
    return vbslq_u8(maskArray, b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator/(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  uint8x16_t x = a;
 | 
			
		||||
  uint8x16_t y = b;
 | 
			
		||||
  return x / y;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline clamp(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& min,
 | 
			
		||||
    const Vectorized<uint8_t>& max) {
 | 
			
		||||
  return minimum(max, maximum(min, a));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline clamp_max(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& max) {
 | 
			
		||||
  return minimum(max, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline clamp_min(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& min) {
 | 
			
		||||
  return maximum(min, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace CPU_CAPABILITY
 | 
			
		||||
} // namespace at::vec
 | 
			
		||||
@ -1740,7 +1740,7 @@ Vectorized<int16_t> inline shift_256_16(
 | 
			
		||||
 | 
			
		||||
  // Control masks for shuffle operation, treating 256 bits as an
 | 
			
		||||
  // array of 16-bit elements, and considering pairs of neighboring
 | 
			
		||||
  // elements.  Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
 | 
			
		||||
  // elements.  Specifially, a mask named "ctl_M_N" (M,N in [0,1], and
 | 
			
		||||
  // M!=N) is set so that shuffle will move element with index M from
 | 
			
		||||
  // input pair into element with index N in output pair, and element
 | 
			
		||||
  // with index M in output pair will be set to all 0s.
 | 
			
		||||
@ -1875,7 +1875,7 @@ Vectorized<T> inline shift_256_8(
 | 
			
		||||
 | 
			
		||||
  // Control masks for shuffle operation, treating 256 bits as an
 | 
			
		||||
  // array of 8-bit elements, and considering quadruples of
 | 
			
		||||
  // neighboring elements.  Specifically, a mask named "ctl_M_N" (M,N
 | 
			
		||||
  // neighboring elements.  Specifially, a mask named "ctl_M_N" (M,N
 | 
			
		||||
  // in [0,1,2,3], and M!=N) is set so that shuffle will move element
 | 
			
		||||
  // with index M from input quadruple into element with index N in
 | 
			
		||||
  // output quadruple, and other elements in output quadruple will be
 | 
			
		||||
 | 
			
		||||
@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
 | 
			
		||||
 | 
			
		||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
 | 
			
		||||
    at::vec::Vectorized<uint8_t> src) {
 | 
			
		||||
  auto u8x8 = vget_low_u8(src);
 | 
			
		||||
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
 | 
			
		||||
  auto u16x8 = vmovl_u8(u8x8);
 | 
			
		||||
  auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
 | 
			
		||||
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
 | 
			
		||||
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
 | 
			
		||||
 | 
			
		||||
Vectorized<float> inline convert_int8_half_register_to_float(
 | 
			
		||||
    at::vec::Vectorized<uint8_t> src) {
 | 
			
		||||
  auto u8x8 = vget_low_u8(src);
 | 
			
		||||
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
 | 
			
		||||
  auto u16x8 = vmovl_u8(u8x8);
 | 
			
		||||
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -143,7 +143,7 @@ class Vectorized<double> {
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b,
 | 
			
		||||
      const Vectorized<double>& mask) {
 | 
			
		||||
    // the mask used here returned by comparison of vec256
 | 
			
		||||
    // the mask used here returned by comparision of vec256
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        vec_sel(a._vec0, b._vec0, mask._vecb0),
 | 
			
		||||
 | 
			
		||||
@ -142,7 +142,7 @@ class Vectorized<float> {
 | 
			
		||||
      const Vectorized<float>& a,
 | 
			
		||||
      const Vectorized<float>& b,
 | 
			
		||||
      const Vectorized<float>& mask) {
 | 
			
		||||
    // the mask used here returned by comparison of vec256
 | 
			
		||||
    // the mask used here returned by comparision of vec256
 | 
			
		||||
    // assuming this we can use the same mask directly with vec_sel
 | 
			
		||||
    return {
 | 
			
		||||
        vec_sel(a._vec0, b._vec0, mask._vecb0),
 | 
			
		||||
 | 
			
		||||
@ -202,7 +202,7 @@ class Vectorized<int16_t> {
 | 
			
		||||
      const Vectorized<int16_t>& a,
 | 
			
		||||
      const Vectorized<int16_t>& b,
 | 
			
		||||
      const Vectorized<int16_t>& mask) {
 | 
			
		||||
    // the mask used here returned by comparison of vec256
 | 
			
		||||
    // the mask used here returned by comparision of vec256
 | 
			
		||||
    // assuming this we can use the same mask directly with vec_sel
 | 
			
		||||
    // warning intel style mask will not work properly
 | 
			
		||||
    return {
 | 
			
		||||
 | 
			
		||||
@ -155,7 +155,7 @@ class Vectorized<int32_t> {
 | 
			
		||||
      const Vectorized<int32_t>& a,
 | 
			
		||||
      const Vectorized<int32_t>& b,
 | 
			
		||||
      const Vectorized<int32_t>& mask) {
 | 
			
		||||
    // the mask used here returned by comparison of vec256
 | 
			
		||||
    // the mask used here returned by comparision of vec256
 | 
			
		||||
    // assuming this we can use the same mask directly with vec_sel
 | 
			
		||||
    // warning intel style mask will not work properly
 | 
			
		||||
    return {
 | 
			
		||||
 | 
			
		||||
@ -119,7 +119,7 @@ class Vectorized<int64_t> {
 | 
			
		||||
      const Vectorized<int64_t>& a,
 | 
			
		||||
      const Vectorized<int64_t>& b,
 | 
			
		||||
      const Vectorized<int64_t>& mask) {
 | 
			
		||||
    // the mask used here returned by comparison of vec256
 | 
			
		||||
    // the mask used here returned by comparision of vec256
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        vec_sel(a._vec0, b._vec0, mask._vecb0),
 | 
			
		||||
 | 
			
		||||
@ -397,7 +397,7 @@ inline Vectorized<bool> operator&&(
 | 
			
		||||
  const __m512i* other_ = reinterpret_cast<const __m512i*>(other.as_bytes());
 | 
			
		||||
  __m512i out = _mm512_and_si512(*self_, *other_);
 | 
			
		||||
  Vectorized<bool> ret;
 | 
			
		||||
  // We do not have a constructor that takes __m512i, so we need to memcpy
 | 
			
		||||
  // We do not have a constructer that takes __m512i, so we need to memcpy
 | 
			
		||||
  std::memcpy(ret, &out, ret.size() * sizeof(bool));
 | 
			
		||||
  return ret;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1852,7 +1852,7 @@ Vectorized<T> inline shift_512_8(
 | 
			
		||||
 | 
			
		||||
  // Control masks for shuffle operation, treating 512 bits as an
 | 
			
		||||
  // array of 8-bit elements, and considering pairs of neighboring
 | 
			
		||||
  // elements.  Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
 | 
			
		||||
  // elements.  Specifially, a mask named "ctl_M_N" (M,N in [0,1], and
 | 
			
		||||
  // M!=N) is set so that shuffle will move element with index M from
 | 
			
		||||
  // input pair into element with index N in output pair, and element
 | 
			
		||||
  // with index M in output pair will be set to all 0s.
 | 
			
		||||
 | 
			
		||||
@ -634,7 +634,7 @@ struct Vectorized {
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<T> neg() const {
 | 
			
		||||
    // NB: the trailing return type is needed because we need to coerce the
 | 
			
		||||
    // return value back to T in the case of unary operator- incurring a
 | 
			
		||||
    // return value back to T in the case of unary operator- incuring a
 | 
			
		||||
    // promotion
 | 
			
		||||
    return map([](T x) -> T { return -x; });
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -1958,7 +1958,7 @@ void scaled_gemm(
 | 
			
		||||
    ScalarType result_dtype,
 | 
			
		||||
    bool use_fast_accum,
 | 
			
		||||
    const std::optional<Tensor>& alpha) {
 | 
			
		||||
  // Note: see `cublasCommonArgs` for various non-intuitive manipulations
 | 
			
		||||
  // Note: see `cublasCommonArgs` for various non-intuitive manupulations
 | 
			
		||||
  // of input arguments to this function.
 | 
			
		||||
  const auto computeType = CUBLAS_COMPUTE_32F;
 | 
			
		||||
  const auto scaleType = CUDA_R_32F;
 | 
			
		||||
 | 
			
		||||
@ -2,10 +2,10 @@
 | 
			
		||||
 | 
			
		||||
#include <ATen/cuda/ATenCUDAGeneral.h>
 | 
			
		||||
#include <ATen/cuda/CUDAContext.h>
 | 
			
		||||
#include <ATen/cuda/Exceptions.h>
 | 
			
		||||
#include <c10/core/impl/GPUTrace.h>
 | 
			
		||||
#include <c10/cuda/CUDAGuard.h>
 | 
			
		||||
#include <c10/cuda/CUDAStream.h>
 | 
			
		||||
#include <c10/cuda/CUDAGuard.h>
 | 
			
		||||
#include <ATen/cuda/Exceptions.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
 | 
			
		||||
#include <cuda_runtime_api.h>
 | 
			
		||||
@ -246,79 +246,4 @@ private:
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// EventPool - Thread-safe pool of CUDA events to avoid expensive cudaEventCreate
 | 
			
		||||
// calls. cudaEventCreate when concurrently invoked from multiple threads can be
 | 
			
		||||
// very expensive (especially on certain device/driver combinations).
 | 
			
		||||
using CUDAEventPtr =
 | 
			
		||||
    std::unique_ptr<CUDAEvent, std::function<void(CUDAEvent*)>>;
 | 
			
		||||
 | 
			
		||||
class EventPool {
 | 
			
		||||
 public:
 | 
			
		||||
  EventPool() : pools_(at::cuda::device_count()) {}
 | 
			
		||||
 | 
			
		||||
  CUDAEventPtr get(const DeviceIndex device) {
 | 
			
		||||
    // If the device is invalid, return a default event and no pooling
 | 
			
		||||
    if (device < 0 || device >= (DeviceIndex)pools_.size()) {
 | 
			
		||||
      auto deleter = [](CUDAEvent* event) {
 | 
			
		||||
        delete event;
 | 
			
		||||
      };
 | 
			
		||||
      return CUDAEventPtr(
 | 
			
		||||
        std::make_unique<CUDAEvent>(cudaEventDisableTiming).release(), deleter);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auto& pool = pools_[device];
 | 
			
		||||
 | 
			
		||||
    // Create a destructor that returns the event to the appropriate device pool
 | 
			
		||||
    auto destructor = [&pool](CUDAEvent* event) noexcept {
 | 
			
		||||
      if (event != nullptr) {
 | 
			
		||||
        std::lock_guard<std::mutex> lock(pool.mutex_);
 | 
			
		||||
        pool.event_pool_.emplace_back(event);
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    {
 | 
			
		||||
      std::lock_guard<std::mutex> lock(pool.mutex_);
 | 
			
		||||
      if (!pool.event_pool_.empty()) {
 | 
			
		||||
        auto event = std::move(pool.event_pool_.back());
 | 
			
		||||
        pool.event_pool_.pop_back();
 | 
			
		||||
        return CUDAEventPtr(event.release(), destructor);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return CUDAEventPtr(
 | 
			
		||||
        std::make_unique<CUDAEvent>(cudaEventDisableTiming).release(),
 | 
			
		||||
        destructor);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void empty_cache() {
 | 
			
		||||
    for (auto& pool : pools_) {
 | 
			
		||||
      std::lock_guard<std::mutex> lock(pool.mutex_);
 | 
			
		||||
      pool.event_pool_.clear();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void init_num_events(const size_t num_events) {
 | 
			
		||||
    for (DeviceIndex device_idx = 0; device_idx < at::cuda::device_count(); ++device_idx) {
 | 
			
		||||
        CUDAGuard device_guard(device_idx);
 | 
			
		||||
        std::vector<CUDAEventPtr> temp_events;
 | 
			
		||||
        temp_events.reserve(num_events);
 | 
			
		||||
        for (size_t i = 0; i < num_events; ++i) {
 | 
			
		||||
          auto event = get(device_idx);
 | 
			
		||||
          // Record the event to ensure it's properly initialized
 | 
			
		||||
          event->record();
 | 
			
		||||
          temp_events.emplace_back(std::move(event));
 | 
			
		||||
        }
 | 
			
		||||
        // Events will be returned to pool when temp_events is destroyed
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  struct alignas(64) PerDevicePool {
 | 
			
		||||
    alignas(64) std::mutex mutex_;
 | 
			
		||||
    std::vector<std::unique_ptr<CUDAEvent>> event_pool_;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  std::vector<PerDevicePool> pools_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace at::cuda
 | 
			
		||||
 | 
			
		||||
@ -168,9 +168,11 @@ void CUDAGraph::instantiate() {
 | 
			
		||||
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
 | 
			
		||||
  // cudaGraphInstantiateWithFlags
 | 
			
		||||
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
 | 
			
		||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
 | 
			
		||||
  int version = 0;
 | 
			
		||||
  AT_CUDA_CHECK(cudaDriverGetVersion(&version));
 | 
			
		||||
  if (version < 11040) {
 | 
			
		||||
#endif
 | 
			
		||||
    // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
 | 
			
		||||
    // who prefer not to report error message through these arguments moving forward
 | 
			
		||||
    // (they prefer return value, or errors on api calls internal to the capture)
 | 
			
		||||
@ -181,11 +183,13 @@ void CUDAGraph::instantiate() {
 | 
			
		||||
#endif
 | 
			
		||||
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
 | 
			
		||||
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
 | 
			
		||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
 | 
			
		||||
  } else {
 | 
			
		||||
    AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
 | 
			
		||||
                                                graph_,
 | 
			
		||||
                                                cudaGraphInstantiateFlagAutoFreeOnLaunch));
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
  has_graph_exec_ = true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -307,7 +311,7 @@ CUDAGraph::~CUDAGraph() {
 | 
			
		||||
// There are recent HIP changes where hipGraphExecDestroy doesn't immediately free memory.
 | 
			
		||||
// They wait for next sync point in order to free the memory, this is to ensure that all
 | 
			
		||||
// hipGraphLaunch are finished before we release any memory. This feature was enabled in rocm6.2.
 | 
			
		||||
// We need to ensure all async operations finish before deleting the object.
 | 
			
		||||
// We need to ensure all async opreations finish before deleting the object.
 | 
			
		||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200)
 | 
			
		||||
  if (capture_dev_ != UNDEFINED_DEVICE) // check if capture_dev_ contains the real device id
 | 
			
		||||
  {
 | 
			
		||||
 | 
			
		||||
@ -1,192 +0,0 @@
 | 
			
		||||
#include <ATen/cuda/CUDAGreenContext.h>
 | 
			
		||||
 | 
			
		||||
namespace at::cuda {
 | 
			
		||||
  GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    int driver_version;
 | 
			
		||||
    C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        driver_version >= 12080, "cuda driver too old to use green context!");
 | 
			
		||||
    CUcontext pctx = nullptr;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
 | 
			
		||||
    if (C10_UNLIKELY(!pctx)) {
 | 
			
		||||
      TORCH_WARN(
 | 
			
		||||
          "Attempted to create a green context but"
 | 
			
		||||
          " there was no primary context! Creating a primary context...");
 | 
			
		||||
 | 
			
		||||
      cudaFree(0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    CUdevice device;
 | 
			
		||||
    device_id_ = device_id;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
 | 
			
		||||
 | 
			
		||||
    // Get device resources
 | 
			
		||||
    CUdevResource device_resource;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
 | 
			
		||||
        device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
 | 
			
		||||
 | 
			
		||||
    // Split resources
 | 
			
		||||
    std::vector<CUdevResource> result(1);
 | 
			
		||||
    auto result_data = result.data();
 | 
			
		||||
    unsigned int nb_groups = 1;
 | 
			
		||||
    CUdevResource remaining;
 | 
			
		||||
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
 | 
			
		||||
            result_data,
 | 
			
		||||
            &nb_groups,
 | 
			
		||||
            &device_resource,
 | 
			
		||||
            &remaining,
 | 
			
		||||
            0, // default flags
 | 
			
		||||
            num_sms));
 | 
			
		||||
 | 
			
		||||
    TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
 | 
			
		||||
 | 
			
		||||
    // Generate resource descriptor
 | 
			
		||||
    CUdevResourceDesc desc;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
 | 
			
		||||
            &desc, result_data, 1));
 | 
			
		||||
 | 
			
		||||
    // Create green context
 | 
			
		||||
    // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
 | 
			
		||||
    // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
 | 
			
		||||
        &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
 | 
			
		||||
 | 
			
		||||
    // Convert to regular context
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
 | 
			
		||||
    TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<GreenContext> GreenContext::create(
 | 
			
		||||
      uint32_t num_sms,
 | 
			
		||||
      std::optional<uint32_t> device_id) {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    if (!device_id.has_value()) {
 | 
			
		||||
      device_id = at::cuda::current_device();
 | 
			
		||||
    }
 | 
			
		||||
    return std::make_unique<GreenContext>(device_id.value(), num_sms);
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Implement move operations
 | 
			
		||||
  GreenContext::GreenContext(GreenContext&& other) noexcept{
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    device_id_ = std::exchange(other.device_id_, -1);
 | 
			
		||||
    green_ctx_ = std::exchange(other.green_ctx_, nullptr);
 | 
			
		||||
    context_ = std::exchange(other.context_, nullptr);
 | 
			
		||||
    parent_stream_ = std::exchange(other.parent_stream_, nullptr);
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    if (this != &other) {
 | 
			
		||||
      // Clean up current resources
 | 
			
		||||
      if (green_ctx_) {
 | 
			
		||||
        CUcontext current = nullptr;
 | 
			
		||||
        C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
            c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
 | 
			
		||||
        if (current == context_) {
 | 
			
		||||
          TORCH_CHECK(
 | 
			
		||||
              false,
 | 
			
		||||
              "attempting to overwrite current green ctx "
 | 
			
		||||
              "when it is active!");
 | 
			
		||||
        }
 | 
			
		||||
        C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Take ownership of other's resources
 | 
			
		||||
      device_id_ = std::exchange(other.device_id_, -1);
 | 
			
		||||
      green_ctx_ = std::exchange(other.green_ctx_, nullptr);
 | 
			
		||||
      context_ = std::exchange(other.context_, nullptr);
 | 
			
		||||
      parent_stream_ = std::exchange(other.parent_stream_, nullptr);
 | 
			
		||||
    }
 | 
			
		||||
    return *this;
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  GreenContext::~GreenContext() noexcept{
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Get the underlying CUDA context
 | 
			
		||||
  CUcontext GreenContext::getContext() const {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    return context_;
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Get the underlying green context
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
  CUgreenCtx GreenContext::getGreenContext() const {
 | 
			
		||||
    return green_ctx_;
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Make this context current
 | 
			
		||||
  void GreenContext::setContext() {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    auto current_stream = c10::cuda::getCurrentCUDAStream();
 | 
			
		||||
    parent_stream_ = current_stream.stream();
 | 
			
		||||
 | 
			
		||||
    at::cuda::CUDAEvent ev;
 | 
			
		||||
    ev.record(current_stream);
 | 
			
		||||
 | 
			
		||||
    CUcontext current = nullptr;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
 | 
			
		||||
    if (!current) {
 | 
			
		||||
      C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
          c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
 | 
			
		||||
    } else {
 | 
			
		||||
      C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
          c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
 | 
			
		||||
    }
 | 
			
		||||
    // currently hardcodes the new green context to use the default stream
 | 
			
		||||
    // TODO(eqy): consider creating a new stream if e.g., it allows interop
 | 
			
		||||
    // with CUDA Graph captures etc.
 | 
			
		||||
    auto default_stream = c10::cuda::getDefaultCUDAStream();
 | 
			
		||||
    ev.block(default_stream);
 | 
			
		||||
    c10::cuda::setCurrentCUDAStream(default_stream);
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void GreenContext::popContext() {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    // see above note about stream being hardcoded to the default stream
 | 
			
		||||
    at::cuda::CUDAEvent ev;
 | 
			
		||||
    ev.record(c10::cuda::getCurrentCUDAStream());
 | 
			
		||||
    CUcontext popped;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(
 | 
			
		||||
        popped == context_, "expected popped context to be the current ctx");
 | 
			
		||||
    ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
} // namespace at::cuda
 | 
			
		||||
@ -1,53 +0,0 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
#include <ATen/cuda/CUDAEvent.h>
 | 
			
		||||
 | 
			
		||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
 | 
			
		||||
#include <c10/cuda/driver_api.h>
 | 
			
		||||
#include <cuda.h>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <stdexcept>
 | 
			
		||||
#include <vector>
 | 
			
		||||
#define CUDA_HAS_GREEN_CONTEXT 1
 | 
			
		||||
#else
 | 
			
		||||
#define CUDA_HAS_GREEN_CONTEXT 0
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace at::cuda {
 | 
			
		||||
 | 
			
		||||
class TORCH_CUDA_CPP_API GreenContext {
 | 
			
		||||
 public:
 | 
			
		||||
  GreenContext(uint32_t device_id, uint32_t num_sms);
 | 
			
		||||
 | 
			
		||||
  static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
 | 
			
		||||
 | 
			
		||||
  // Delete copy constructor and assignment
 | 
			
		||||
  GreenContext(const GreenContext&) = delete;
 | 
			
		||||
  GreenContext& operator=(const GreenContext&) = delete;
 | 
			
		||||
 | 
			
		||||
  // Implement move operations
 | 
			
		||||
  GreenContext(GreenContext&& other) noexcept;
 | 
			
		||||
  GreenContext& operator=(GreenContext&& other) noexcept;
 | 
			
		||||
  ~GreenContext() noexcept;
 | 
			
		||||
 | 
			
		||||
  // Get the underlying CUDA context
 | 
			
		||||
  CUcontext getContext() const;
 | 
			
		||||
 | 
			
		||||
  // Get the underlying green context
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
  CUgreenCtx getGreenContext() const;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Make this context current
 | 
			
		||||
  void setContext();
 | 
			
		||||
 | 
			
		||||
  void popContext();
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
  int32_t device_id_ = -1;
 | 
			
		||||
  CUgreenCtx green_ctx_ = nullptr;
 | 
			
		||||
  CUcontext context_ = nullptr;
 | 
			
		||||
  cudaStream_t parent_stream_ = nullptr;
 | 
			
		||||
#endif
 | 
			
		||||
};
 | 
			
		||||
} // namespace at::cuda
 | 
			
		||||
@ -1,270 +0,0 @@
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <c10/util/typeid.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/SmallVector.h>
 | 
			
		||||
#include <c10/core/Scalar.h>
 | 
			
		||||
#include <c10/core/ScalarType.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 | 
			
		||||
#include <ATen/core/Tensor.h>
 | 
			
		||||
#include <ATen/core/NamedTensor.h>
 | 
			
		||||
#include <ATen/Dispatch.h>
 | 
			
		||||
#include <ATen/ExpandUtils.h>
 | 
			
		||||
#include <ATen/OpMathType.h>
 | 
			
		||||
#include <ATen/TensorUtils.h>
 | 
			
		||||
#include <ATen/cuda/CUDABlas.h>
 | 
			
		||||
#include <ATen/cuda/tunable/Tunable.h>
 | 
			
		||||
#include <ATen/cuda/tunable/TunableGemm.h>
 | 
			
		||||
#include <ATen/native/Resize.h>
 | 
			
		||||
#include <c10/util/MaybeOwned.h>
 | 
			
		||||
#include <ATen/native/GroupedMMUtils.h>
 | 
			
		||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
 | 
			
		||||
#include <ATen/native/cuda/ScaledGroupMM.h>
 | 
			
		||||
#include <ATen/native/cuda/GroupMM.h>
 | 
			
		||||
#include <ATen/ceil_div.h>
 | 
			
		||||
 | 
			
		||||
#ifdef USE_FBGEMM_GENAI
 | 
			
		||||
#include <fbgemm_gpu/torch_ops.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
#include <ATen/NativeFunctions.h>
 | 
			
		||||
#else
 | 
			
		||||
#include <ATen/ops/_addmm_activation_native.h>
 | 
			
		||||
#include <ATen/ops/_efficientzerotensor.h>
 | 
			
		||||
#include <ATen/ops/_scaled_mm_native.h>
 | 
			
		||||
#include <ATen/ops/_unsafe_view_native.h>
 | 
			
		||||
#include <ATen/ops/abs.h>
 | 
			
		||||
#include <ATen/ops/addmm_native.h>
 | 
			
		||||
#include <ATen/ops/addmv_native.h>
 | 
			
		||||
#include <ATen/ops/baddbmm_native.h>
 | 
			
		||||
#include <ATen/ops/bmm_native.h>
 | 
			
		||||
#include <ATen/ops/copy_native.h>
 | 
			
		||||
#include <ATen/ops/dot_native.h>
 | 
			
		||||
#include <ATen/ops/empty.h>
 | 
			
		||||
#include <ATen/ops/empty_strided.h>
 | 
			
		||||
#include <ATen/ops/gelu.h>
 | 
			
		||||
#include <ATen/ops/max.h>
 | 
			
		||||
#include <ATen/ops/mm_native.h>
 | 
			
		||||
#include <ATen/ops/mul.h>
 | 
			
		||||
#include <ATen/ops/relu.h>
 | 
			
		||||
#include <ATen/ops/ones.h>
 | 
			
		||||
#include <ATen/ops/scalar_tensor_native.h>
 | 
			
		||||
#include <ATen/ops/vdot_native.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
using at::blas::ScalingType;
 | 
			
		||||
using at::blas::SwizzleType;
 | 
			
		||||
 | 
			
		||||
namespace at::cuda::scaled {
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp8,
 | 
			
		||||
 * Each needs a single scale, {Tensorwise (float)}
 | 
			
		||||
 */
 | 
			
		||||
bool check_tensorwise_recipe(c10::ScalarType type_a,
 | 
			
		||||
                             std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                             ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                             c10::ScalarType type_b,
 | 
			
		||||
                             std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                             ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp8
 | 
			
		||||
  if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scale each, {Tensorwise, float}
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  // Need {Blockwise_1x32, e8m0} for A & B
 | 
			
		||||
  if (recipe_a[0] != ScalingType::TensorWise) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::TensorWise) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp8,
 | 
			
		||||
 * Each needs scales, {Rowwise (float)}
 | 
			
		||||
 */
 | 
			
		||||
bool check_rowwise_recipe(c10::ScalarType type_a,
 | 
			
		||||
                             std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                             ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                             c10::ScalarType type_b,
 | 
			
		||||
                             std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                             ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp8
 | 
			
		||||
  if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scale each, {Tensorwise, float}
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {RowWise, dp32} for A & B
 | 
			
		||||
  if (recipe_a[0] != ScalingType::RowWise) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::RowWise) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Two-level scaling, canonical NVFP4
 | 
			
		||||
 * Both inputs must be fp4
 | 
			
		||||
 * A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
 | 
			
		||||
 */
 | 
			
		||||
bool check_nvfp4_recipe(c10::ScalarType type_a,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                        c10::ScalarType type_b,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp4
 | 
			
		||||
  if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 2 scales, 2 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
 | 
			
		||||
  if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Single-level scaling, what PyT currently understands
 | 
			
		||||
 * Both inputs must be fp4
 | 
			
		||||
 * A, B need 1 scale, {Blockwise_1x16 (e4m3)}
 | 
			
		||||
 */
 | 
			
		||||
bool check_nvfp4_recipe_single_scale
 | 
			
		||||
                       (c10::ScalarType type_a,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                        c10::ScalarType type_b,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp4
 | 
			
		||||
  if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 2 scales, 2 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
 | 
			
		||||
  if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp8
 | 
			
		||||
 * A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float)
 | 
			
		||||
 */
 | 
			
		||||
bool check_deepseek_recipe(ScalingType expected_recipe_a,
 | 
			
		||||
                           ScalingType expected_recipe_b,
 | 
			
		||||
                           c10::ScalarType type_a,
 | 
			
		||||
                           std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                           ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                           c10::ScalarType type_b,
 | 
			
		||||
                           std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                           ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp8
 | 
			
		||||
  if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scales, 1 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
 | 
			
		||||
  if (recipe_a[0] != expected_recipe_a) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
  if (recipe_b[0] != expected_recipe_b) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp8
 | 
			
		||||
 * A, B must have 1 scale each, {Blockwise_1x32, e8m0}
 | 
			
		||||
 */
 | 
			
		||||
bool check_mxfp8_recipe(c10::ScalarType type_a,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                        c10::ScalarType type_b,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp8
 | 
			
		||||
  if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scales, 1 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x32, e8m0} for A & B
 | 
			
		||||
  if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Both inputs must be fp4
 | 
			
		||||
 * A, B must have 1 scale each, {Blockwise_1x32, e8m0}
 | 
			
		||||
 */
 | 
			
		||||
bool check_mxfp4_recipe(c10::ScalarType type_a,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_a,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_a,
 | 
			
		||||
                        c10::ScalarType type_b,
 | 
			
		||||
                        std::vector<ScalingType>& recipe_b,
 | 
			
		||||
                        ArrayRef<Tensor>& scales_b) {
 | 
			
		||||
  // both types must be fp4
 | 
			
		||||
  if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // 1 scales, 1 recipes for each input
 | 
			
		||||
  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need {Blockwise_1x32, e8m0} for A & B
 | 
			
		||||
  if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
 | 
			
		||||
  if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
 | 
			
		||||
  if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
 | 
			
		||||
  if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace at::native::cuda::blas::scaled
 | 
			
		||||
@ -1,174 +0,0 @@
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <c10/util/typeid.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/SmallVector.h>
 | 
			
		||||
#include <c10/core/Scalar.h>
 | 
			
		||||
#include <c10/core/ScalarType.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 | 
			
		||||
#include <ATen/core/Tensor.h>
 | 
			
		||||
#include <ATen/core/NamedTensor.h>
 | 
			
		||||
#include <ATen/Dispatch.h>
 | 
			
		||||
#include <ATen/ExpandUtils.h>
 | 
			
		||||
#include <ATen/OpMathType.h>
 | 
			
		||||
#include <ATen/TensorUtils.h>
 | 
			
		||||
#include <ATen/cuda/CUDABlas.h>
 | 
			
		||||
#include <ATen/cuda/tunable/Tunable.h>
 | 
			
		||||
#include <ATen/cuda/tunable/TunableGemm.h>
 | 
			
		||||
#include <ATen/native/Resize.h>
 | 
			
		||||
#include <c10/util/MaybeOwned.h>
 | 
			
		||||
#include <ATen/native/GroupedMMUtils.h>
 | 
			
		||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
 | 
			
		||||
#include <ATen/native/cuda/ScaledGroupMM.h>
 | 
			
		||||
#include <ATen/native/cuda/GroupMM.h>
 | 
			
		||||
#include <ATen/ceil_div.h>
 | 
			
		||||
 | 
			
		||||
#ifdef USE_FBGEMM_GENAI
 | 
			
		||||
#include <fbgemm_gpu/torch_ops.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
#include <ATen/NativeFunctions.h>
 | 
			
		||||
#else
 | 
			
		||||
#include <ATen/ops/_addmm_activation_native.h>
 | 
			
		||||
#include <ATen/ops/_efficientzerotensor.h>
 | 
			
		||||
#include <ATen/ops/_scaled_mm_native.h>
 | 
			
		||||
#include <ATen/ops/_unsafe_view_native.h>
 | 
			
		||||
#include <ATen/ops/abs.h>
 | 
			
		||||
#include <ATen/ops/addmm_native.h>
 | 
			
		||||
#include <ATen/ops/addmv_native.h>
 | 
			
		||||
#include <ATen/ops/baddbmm_native.h>
 | 
			
		||||
#include <ATen/ops/bmm_native.h>
 | 
			
		||||
#include <ATen/ops/copy_native.h>
 | 
			
		||||
#include <ATen/ops/dot_native.h>
 | 
			
		||||
#include <ATen/ops/empty.h>
 | 
			
		||||
#include <ATen/ops/empty_strided.h>
 | 
			
		||||
#include <ATen/ops/gelu.h>
 | 
			
		||||
#include <ATen/ops/max.h>
 | 
			
		||||
#include <ATen/ops/mm_native.h>
 | 
			
		||||
#include <ATen/ops/mul.h>
 | 
			
		||||
#include <ATen/ops/relu.h>
 | 
			
		||||
#include <ATen/ops/ones.h>
 | 
			
		||||
#include <ATen/ops/scalar_tensor_native.h>
 | 
			
		||||
#include <ATen/ops/vdot_native.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
using at::blas::ScalingType;
 | 
			
		||||
using at::blas::SwizzleType;
 | 
			
		||||
 | 
			
		||||
namespace at::cuda::scaled {
 | 
			
		||||
 | 
			
		||||
static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    static const std::vector<std::string> archs = {
 | 
			
		||||
        "gfx942",
 | 
			
		||||
#if ROCM_VERSION >= 60300
 | 
			
		||||
        "gfx1200", "gfx1201",
 | 
			
		||||
#endif
 | 
			
		||||
#if ROCM_VERSION >= 60500
 | 
			
		||||
        "gfx950"
 | 
			
		||||
#endif
 | 
			
		||||
    };
 | 
			
		||||
    return at::detail::getCUDAHooks().isGPUArch(archs);
 | 
			
		||||
#else
 | 
			
		||||
    auto dprops = at::cuda::getCurrentDeviceProperties();
 | 
			
		||||
 | 
			
		||||
    if (sm90_only || sm100_only) {
 | 
			
		||||
      return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
 | 
			
		||||
    } else {
 | 
			
		||||
      return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
static bool _scaled_mm_is_fnuz() {
 | 
			
		||||
    return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
/**
 | 
			
		||||
 * Track concrete implementations available
 | 
			
		||||
 */
 | 
			
		||||
enum class ScaledGemmImplementation {
 | 
			
		||||
  NONE = 0,
 | 
			
		||||
  TENSORWISE_TENSORWISE = 1,
 | 
			
		||||
  ROWWISE_ROWWISE = 2,
 | 
			
		||||
  BLOCK_128x128_1x128 = 3,
 | 
			
		||||
  BLOCK_1x128_128x128 = 4,
 | 
			
		||||
  BLOCK_1x128_1x128 = 5,
 | 
			
		||||
  MXFP8_MXFP8 = 6,
 | 
			
		||||
  NVFP4_NVFP4 = 7,
 | 
			
		||||
  NVFP4_NVFP4_SINGLE_SCALE = 8,
 | 
			
		||||
  MXFP4_MXFP4 = 9,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Convert passed int (enum) from python back into a
 | 
			
		||||
 * strictly-typed enum
 | 
			
		||||
 */
 | 
			
		||||
template <class EnumType, class ArrayType>
 | 
			
		||||
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
 | 
			
		||||
  std::vector<EnumType> converted;
 | 
			
		||||
  converted.reserve(v.size());
 | 
			
		||||
 | 
			
		||||
  for (auto vi : v) {
 | 
			
		||||
    converted.push_back(static_cast<EnumType>(vi));
 | 
			
		||||
  }
 | 
			
		||||
  return converted;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool check_tensorwise_recipe(c10::ScalarType,
 | 
			
		||||
                             std::vector<ScalingType>&,
 | 
			
		||||
                             ArrayRef<Tensor>&,
 | 
			
		||||
                             c10::ScalarType,
 | 
			
		||||
                             std::vector<ScalingType>&,
 | 
			
		||||
                             ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
bool check_rowwise_recipe(c10::ScalarType,
 | 
			
		||||
                             std::vector<ScalingType>&,
 | 
			
		||||
                             ArrayRef<Tensor>&,
 | 
			
		||||
                             c10::ScalarType,
 | 
			
		||||
                             std::vector<ScalingType>&,
 | 
			
		||||
                             ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_nvfp4_recipe(c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&,
 | 
			
		||||
                        c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_nvfp4_recipe_single_scale
 | 
			
		||||
                       (c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&,
 | 
			
		||||
                        c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_deepseek_recipe(ScalingType,
 | 
			
		||||
                           ScalingType,
 | 
			
		||||
                           c10::ScalarType,
 | 
			
		||||
                           std::vector<ScalingType>&,
 | 
			
		||||
                           ArrayRef<Tensor>&,
 | 
			
		||||
                           c10::ScalarType,
 | 
			
		||||
                           std::vector<ScalingType>&,
 | 
			
		||||
                           ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_mxfp8_recipe(c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&,
 | 
			
		||||
                        c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
bool check_mxfp4_recipe(c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&,
 | 
			
		||||
                        c10::ScalarType,
 | 
			
		||||
                        std::vector<ScalingType>&,
 | 
			
		||||
                        ArrayRef<Tensor>&);
 | 
			
		||||
 | 
			
		||||
} // namespace at::native::cuda::blas::scaled
 | 
			
		||||
@ -137,7 +137,7 @@ struct CUDACachingHostAllocatorImpl
 | 
			
		||||
  void free_block_slowpath(Block* block) {
 | 
			
		||||
    auto start = std::chrono::steady_clock::now();
 | 
			
		||||
    // Users may change the allocator config at will. torch unit tests do this.
 | 
			
		||||
    // However, allocations using cudaHostRegister should use corresponding
 | 
			
		||||
    // However, allocations using cudaHostRegister should use corresonding
 | 
			
		||||
    // cudaHostUnregister and similarly for cudaHostAlloc / cudaFreeHost.
 | 
			
		||||
    void* ptr = block->ptr_;
 | 
			
		||||
    bool use_register = false;
 | 
			
		||||
 | 
			
		||||
@ -70,7 +70,11 @@
 | 
			
		||||
#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
 | 
			
		||||
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
namespace at_cuda_detail {
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
 | 
			
		||||
 | 
			
		||||
@ -92,6 +96,10 @@ template <>
 | 
			
		||||
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
 | 
			
		||||
       ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
 | 
			
		||||
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
} // namespace at_cuda_detail
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
@ -113,7 +121,7 @@ struct cuda_type<c10::Half> {
 | 
			
		||||
  using type = __half;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
struct cuda_type<c10::BFloat16> {
 | 
			
		||||
@ -195,6 +203,36 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera
 | 
			
		||||
  *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
 | 
			
		||||
template<typename ValueT, typename InputIteratorT>
 | 
			
		||||
struct chained_iterator {
 | 
			
		||||
  using iterator_category = std::random_access_iterator_tag;
 | 
			
		||||
  using difference_type   = std::ptrdiff_t;
 | 
			
		||||
  using value_type        = ValueT;
 | 
			
		||||
  using pointer           = ValueT*;
 | 
			
		||||
  using reference         = ValueT&;
 | 
			
		||||
 | 
			
		||||
  InputIteratorT iter;
 | 
			
		||||
  ValueT *first;
 | 
			
		||||
  difference_type offset = 0;
 | 
			
		||||
 | 
			
		||||
  __device__ ValueT operator[](difference_type i) {
 | 
			
		||||
    i +=  offset;
 | 
			
		||||
    if (i == 0) {
 | 
			
		||||
      return *first;
 | 
			
		||||
    } else {
 | 
			
		||||
      return ValueT(iter[i - 1]);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  __device__ chained_iterator operator+(difference_type i) {
 | 
			
		||||
    return chained_iterator{iter, first, i};
 | 
			
		||||
  }
 | 
			
		||||
  __device__ ValueT operator*() {
 | 
			
		||||
    return (*this)[0];
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
 | 
			
		||||
// so split at int_max/2
 | 
			
		||||
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
 | 
			
		||||
@ -239,6 +277,25 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 | 
			
		||||
        first_elem_ptr,
 | 
			
		||||
        scan_op);
 | 
			
		||||
    C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
 | 
			
		||||
    using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
 | 
			
		||||
    using tuple = typename ArgIndexInputIterator::value_type;
 | 
			
		||||
    auto input_iter_transform = [=] __device__ (const tuple &x)->input_t  {
 | 
			
		||||
      if (x.key == 0) {
 | 
			
		||||
        return *first_elem_ptr;
 | 
			
		||||
      } else {
 | 
			
		||||
        return x.value;
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
    auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)(
 | 
			
		||||
      ArgIndexInputIterator(input + i), input_iter_transform);
 | 
			
		||||
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
 | 
			
		||||
        input_,
 | 
			
		||||
        output + i,
 | 
			
		||||
        scan_op,
 | 
			
		||||
        size_cub,
 | 
			
		||||
        at::cuda::getCurrentCUDAStream());
 | 
			
		||||
#else
 | 
			
		||||
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
 | 
			
		||||
        input + i + 1,
 | 
			
		||||
        output + i,
 | 
			
		||||
@ -246,6 +303,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 | 
			
		||||
        ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
 | 
			
		||||
        size_cub,
 | 
			
		||||
        at::cuda::getCurrentCUDAStream());
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
@ -497,6 +555,16 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 | 
			
		||||
        first_elem_ptr,
 | 
			
		||||
        scan_op);
 | 
			
		||||
    C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
 | 
			
		||||
    auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
 | 
			
		||||
      input + i, first_elem_ptr};
 | 
			
		||||
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
 | 
			
		||||
        input_,
 | 
			
		||||
        output + i,
 | 
			
		||||
        scan_op,
 | 
			
		||||
        size_cub,
 | 
			
		||||
        at::cuda::getCurrentCUDAStream());
 | 
			
		||||
#else
 | 
			
		||||
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
 | 
			
		||||
        input + i,
 | 
			
		||||
        output + i,
 | 
			
		||||
@ -504,6 +572,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 | 
			
		||||
        ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
 | 
			
		||||
        size_cub,
 | 
			
		||||
        at::cuda::getCurrentCUDAStream());
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@
 | 
			
		||||
#include <ATen/cuda/CUDAConfig.h>
 | 
			
		||||
 | 
			
		||||
// NOTE: These templates are intentionally not defined in this header,
 | 
			
		||||
// which avoids re-compiling them for each translation unit. If you get
 | 
			
		||||
// which aviods re-compiling them for each translation unit. If you get
 | 
			
		||||
// a link error, you need to add an explicit instantiation for your
 | 
			
		||||
// types in cub.cu
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -10,6 +10,14 @@
 | 
			
		||||
#define CUB_VERSION 200001
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
 | 
			
		||||
// https://github.com/NVIDIA/cub/pull/306
 | 
			
		||||
#if CUB_VERSION >= 101300
 | 
			
		||||
#define CUB_SUPPORTS_NV_BFLOAT16() true
 | 
			
		||||
#else
 | 
			
		||||
#define CUB_SUPPORTS_NV_BFLOAT16() false
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
 | 
			
		||||
// https://github.com/NVIDIA/cub/pull/326
 | 
			
		||||
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
 | 
			
		||||
@ -20,6 +28,14 @@
 | 
			
		||||
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// cub support for cub::FutureValue is added to cub 1.15 in:
 | 
			
		||||
// https://github.com/NVIDIA/cub/pull/305
 | 
			
		||||
#if CUB_VERSION >= 101500
 | 
			
		||||
#define CUB_SUPPORTS_FUTURE_VALUE() true
 | 
			
		||||
#else
 | 
			
		||||
#define CUB_SUPPORTS_FUTURE_VALUE() false
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// There were many bc-breaking changes in major version release of CCCL v3.0.0
 | 
			
		||||
// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
 | 
			
		||||
#if CUB_VERSION >= 200800
 | 
			
		||||
 | 
			
		||||
@ -38,7 +38,7 @@ GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
 | 
			
		||||
GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Note the "Validator" lines. If you change a library version, or ROCm version, or PyTorch version, TunableOp will detect
 | 
			
		||||
Note the "Validator" lines. If you change a library verison, or ROCm version, or PyTorch version, TunableOp will detect
 | 
			
		||||
this and reject the tunings file because the prior tunings are likely affected by other software changes.
 | 
			
		||||
 | 
			
		||||
The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of
 | 
			
		||||
 | 
			
		||||
@ -235,7 +235,7 @@ class TunableOp {
 | 
			
		||||
      // numeric check option is controlled by non-static env var, so check it once per tuned operator
 | 
			
		||||
      bool do_numerics_check = ctx->IsNumericsCheckEnabled();
 | 
			
		||||
 | 
			
		||||
      // calculate a reference answer for numerical check
 | 
			
		||||
      // calcaulte a reference answer for numerical check
 | 
			
		||||
      if (do_numerics_check) {
 | 
			
		||||
        reference_params = params->DeepCopy(false);
 | 
			
		||||
        TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,7 @@ namespace at {
 | 
			
		||||
 | 
			
		||||
// AcceleratorHooksInterface is a shared interface provided by all
 | 
			
		||||
// accelerators to allow generic code.
 | 
			
		||||
// This interface is hook-based as it corresponds to all the functions
 | 
			
		||||
// This inferface is hook-based as it corresponds to all the functions
 | 
			
		||||
// that are going to be called in a generic way from the CPU code.
 | 
			
		||||
 | 
			
		||||
struct TORCH_API AcceleratorHooksInterface {
 | 
			
		||||
 | 
			
		||||
@ -38,7 +38,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
 | 
			
		||||
 | 
			
		||||
  Generator getNewGenerator(
 | 
			
		||||
      [[maybe_unused]] DeviceIndex device_index = -1) const override {
 | 
			
		||||
    // TODO(FFFrog): Preserved for BC and will be removed in the future.
 | 
			
		||||
    // TODO(FFFrog): Perserved for BC and will be removed in the future.
 | 
			
		||||
    if (at::GetGeneratorPrivate().has_value())
 | 
			
		||||
      return at::GetGeneratorForPrivateuse1(device_index);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,23 +0,0 @@
 | 
			
		||||
#include <ATen/detail/XLAHooksInterface.h>
 | 
			
		||||
 | 
			
		||||
namespace at {
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
const XLAHooksInterface& getXLAHooks() {
 | 
			
		||||
  auto create_impl = [] {
 | 
			
		||||
    // Create XLA hooks using the registry
 | 
			
		||||
    auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{});
 | 
			
		||||
    if (hooks) {
 | 
			
		||||
      return hooks;
 | 
			
		||||
    }
 | 
			
		||||
    // If hooks creation fails, fall back to default implementation
 | 
			
		||||
    return std::make_unique<XLAHooksInterface>();
 | 
			
		||||
  };
 | 
			
		||||
  static auto hooks = create_impl();
 | 
			
		||||
  return *hooks;
 | 
			
		||||
}
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs)
 | 
			
		||||
 | 
			
		||||
} // namespace at
 | 
			
		||||
@ -1,79 +0,0 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <c10/core/Device.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/Registry.h>
 | 
			
		||||
 | 
			
		||||
#include <ATen/detail/AcceleratorHooksInterface.h>
 | 
			
		||||
 | 
			
		||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
 | 
			
		||||
 | 
			
		||||
namespace at {
 | 
			
		||||
 | 
			
		||||
constexpr const char* XLA_HELP =
 | 
			
		||||
  "This error has occurred because you are trying "
 | 
			
		||||
  "to use some XLA functionality, but the XLA library has not been "
 | 
			
		||||
  "loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
 | 
			
		||||
 | 
			
		||||
struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
 | 
			
		||||
  ~XLAHooksInterface() override = default;
 | 
			
		||||
 | 
			
		||||
  void init() const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual bool hasXLA() const {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual std::string showConfig() const {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        false,
 | 
			
		||||
        "Cannot query detailed XLA version without torch_xla library. ",
 | 
			
		||||
        XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const Generator& getDefaultGenerator(
 | 
			
		||||
      [[maybe_unused]] DeviceIndex device_index = -1) const override {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Generator getNewGenerator(
 | 
			
		||||
      [[maybe_unused]] DeviceIndex device_index = -1) const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual DeviceIndex getCurrentDevice() const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Device getDeviceFromPtr(void* /*data*/) const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Allocator* getPinnedMemoryAllocator() const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool isPinnedPtr(const void* data) const override {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool hasPrimaryContext(DeviceIndex device_index) const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct TORCH_API XLAHooksArgs {};
 | 
			
		||||
 | 
			
		||||
TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
 | 
			
		||||
#define REGISTER_XLA_HOOKS(clsname) \
 | 
			
		||||
  C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
TORCH_API const XLAHooksInterface& getXLAHooks();
 | 
			
		||||
} // namespace detail
 | 
			
		||||
} // namespace at
 | 
			
		||||
C10_DIAGNOSTIC_POP()
 | 
			
		||||
@ -283,7 +283,7 @@ inline void boxed_existing_bdim_all_batch_rule(
 | 
			
		||||
// Use when all tensors arguments accept one (normal) batch dim.
 | 
			
		||||
// This batching rule expands the batch dim on all Tensors, reshapes it into
 | 
			
		||||
// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
 | 
			
		||||
// This is not the most efficient thing; if there are alternatives, please try
 | 
			
		||||
// This is not the most efficient thing; if there are alternatives, plese try
 | 
			
		||||
// to use them. Use this only as a last resort.
 | 
			
		||||
#define EXISTING_BDIM_ALL_BOXED(op) \
 | 
			
		||||
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
 | 
			
		||||
 | 
			
		||||
@ -384,7 +384,7 @@ fourOutputs solve_ex_batch_rule(
 | 
			
		||||
 | 
			
		||||
  // NOTE [ solve_ex Batch Rule Contiguity ]
 | 
			
		||||
  // A determines whether or not linalg_solve takes an optimized path. We need the check on A_ to match the one run on
 | 
			
		||||
  // A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behavior
 | 
			
		||||
  // A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behvaior
 | 
			
		||||
  // differs based on whether or not the optimized path was taken
 | 
			
		||||
  const auto batched_A_was_contiguous = A_bdim.has_value() ? at::select(A, *A_bdim, 0).is_contiguous() : A.is_contiguous();
 | 
			
		||||
  if (batched_A_was_contiguous && !A.is_complex()) {
 | 
			
		||||
 | 
			
		||||
@ -282,7 +282,7 @@ static std::tuple<Tensor, std::optional<int64_t>> _softmax_backward_batch_rule(
 | 
			
		||||
 | 
			
		||||
  dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim);
 | 
			
		||||
 | 
			
		||||
  // Not sure why output_ needs to be marked as .contiguous(). Something must
 | 
			
		||||
  // Not sure why output_ needs to be marked as .contiguous(). Someting must
 | 
			
		||||
  // have changed in PyTorch (and output of softmax is probably always contiguous)
 | 
			
		||||
  return std::make_tuple(at::_softmax_backward_data(grad_output_, output_.contiguous(), dim, input_dtype), 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -224,7 +224,7 @@ static Tensor safeStack(TensorList tensors) {
 | 
			
		||||
  // is possible for the backward function to return an undefined grad for some
 | 
			
		||||
  // grad_input for each example. In that case, we return an undefined grad.
 | 
			
		||||
  //
 | 
			
		||||
  // It is theoretically possible for *some* of the examples to produce an
 | 
			
		||||
  // It is theoretically posssible for *some* of the examples to produce an
 | 
			
		||||
  // undefined grad (a kernel could peek at the gradient values and return an
 | 
			
		||||
  // undefined tensor if it determines the gradient is full of zeros). We
 | 
			
		||||
  // could handle this by treating the undefined grad as a zero-filled tensor
 | 
			
		||||
 | 
			
		||||
@ -113,7 +113,7 @@ SymIntArrayRef BatchedTensorImpl::sym_sizes_custom() const {
 | 
			
		||||
  return sym_sizes_default();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// The following are publicly exposed as methods of Tensor
 | 
			
		||||
// The following are publically exposed as methods of Tensor
 | 
			
		||||
 | 
			
		||||
IntArrayRef BatchedTensorImpl::strides_custom() const {
 | 
			
		||||
  return strides_default();
 | 
			
		||||
 | 
			
		||||
@ -37,7 +37,7 @@ namespace at::functorch  {
 | 
			
		||||
// how to perform the transform.
 | 
			
		||||
//
 | 
			
		||||
// TODO: we can excise DynamicLayer in favor of Interpreter,
 | 
			
		||||
// But I am going to leave it for now as a compatibility shim to avoid
 | 
			
		||||
// But I am going to leave it for now as a compatiblity shim to avoid
 | 
			
		||||
// needing to refactor a lot of callsites...
 | 
			
		||||
struct TORCH_API DynamicLayer {
 | 
			
		||||
  explicit DynamicLayer(
 | 
			
		||||
 | 
			
		||||
@ -88,7 +88,7 @@ std::ostream& operator<<(std::ostream& os, const TransformType& t);
 | 
			
		||||
// >>> VmapInterpreterPtr(&interpreter).batchSize()
 | 
			
		||||
//
 | 
			
		||||
// Finally, Interpreter::process switches on the type of the interpreter
 | 
			
		||||
// and calls one of {Transform}Interpreter::processImpl under the hood.
 | 
			
		||||
// and calls one of {Transform}Intepreter::processImpl under the hood.
 | 
			
		||||
// Same for Interpreter::sendToNextInterpreter :)
 | 
			
		||||
 | 
			
		||||
struct VmapInterpreterMeta {
 | 
			
		||||
 | 
			
		||||
@ -733,7 +733,7 @@ TORCH_LIBRARY_IMPL(_, FuncTorchBatched, m) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
 | 
			
		||||
  // still legacy b/c returns multiple tensors
 | 
			
		||||
  // still legacy b/c teturns multiple tensors
 | 
			
		||||
  m.impl("split.Tensor", split_batching_rule);
 | 
			
		||||
  m.impl("split_with_sizes", split_with_sizes_batching_rule);
 | 
			
		||||
  m.impl("split_with_sizes_copy", split_with_sizes_copy_batching_rule);
 | 
			
		||||
 | 
			
		||||
@ -158,7 +158,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t
 | 
			
		||||
      endKernelCoalescing();
 | 
			
		||||
      id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
 | 
			
		||||
 | 
			
		||||
      // For some reason fillBufferfor stopped working for length > 4Gb on MacOS 26
 | 
			
		||||
      // For some reason fillBufferfor stopped working for lengh > 4Gb on MacOS 26
 | 
			
		||||
      // See https://github.com/pytorch/pytorch/issues/163962
 | 
			
		||||
      // Workaround by batching copy commands into 4Gb chunks
 | 
			
		||||
      constexpr size_t max_copy_size = 0x100000000; // 4GB
 | 
			
		||||
 | 
			
		||||
@ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result)
 | 
			
		||||
    try {
 | 
			
		||||
      mkldnn_matmul_i8i8i32(self, mat2, result);
 | 
			
		||||
      dispatched = true;
 | 
			
		||||
    } catch ([[maybe_unused]] const std::exception& e) {
 | 
			
		||||
    } catch (const std::exception& e) {
 | 
			
		||||
      TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -148,7 +148,7 @@ inline void checkInputsSolver(const Tensor& A,
 | 
			
		||||
 | 
			
		||||
inline bool is_row_or_column_contiguous(const Tensor& t) {
 | 
			
		||||
  // This could be made more general, similar to how it's checked in matmul, which would allow to
 | 
			
		||||
  // elide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
 | 
			
		||||
  // ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
 | 
			
		||||
  // We choose to be conservative for simplicity
 | 
			
		||||
  return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -11,8 +11,6 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
 | 
			
		||||
              "pixel_shuffle expects a positive upscale_factor, but got ",
 | 
			
		||||
              upscale_factor);
 | 
			
		||||
  int64_t c = self.size(-3);
 | 
			
		||||
  TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
 | 
			
		||||
        "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
 | 
			
		||||
  int64_t upscale_factor_squared = upscale_factor * upscale_factor;
 | 
			
		||||
  TORCH_CHECK(c % upscale_factor_squared == 0,
 | 
			
		||||
              "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
 | 
			
		||||
 | 
			
		||||
@ -21,7 +21,7 @@ enum class fft_norm_mode {
 | 
			
		||||
// NOTE [ Fourier Transform Conjugate Symmetry ]
 | 
			
		||||
//
 | 
			
		||||
// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is,
 | 
			
		||||
// assuming X is the transformed K-dimensional signal, we have
 | 
			
		||||
// assuming X is the transformed K-dimensionsal signal, we have
 | 
			
		||||
//
 | 
			
		||||
//     X[i_1, ..., i_K] = X[j_i, ..., j_K]*,
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user