mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 17:54:55 +08:00
Compare commits
3 Commits
gh/ezyang/
...
cpp-docs-d
| Author | SHA1 | Date | |
|---|---|---|---|
| 5b6cc8215f | |||
| 1c43c9cfd0 | |||
| 102e0d5437 |
@ -19,7 +19,7 @@ pip_install \
|
||||
transformers==4.36.2
|
||||
|
||||
pip_install coloredlogs packaging
|
||||
pip_install onnxruntime==1.23.1
|
||||
pip_install onnxruntime==1.23.0
|
||||
pip_install onnxscript==0.5.4
|
||||
|
||||
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
|
||||
|
||||
@ -334,12 +334,12 @@ sympy==1.13.3
|
||||
#Pinned versions:
|
||||
#test that import:
|
||||
|
||||
onnx==1.19.1
|
||||
onnx==1.18.0
|
||||
#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
|
||||
#Pinned versions:
|
||||
#test that import:
|
||||
|
||||
onnxscript==0.5.4
|
||||
onnxscript==0.5.3
|
||||
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
|
||||
#Pinned versions:
|
||||
#test that import:
|
||||
|
||||
@ -1,15 +1,11 @@
|
||||
sphinx==5.3.0
|
||||
sphinx==7.2.6
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 5.3.0
|
||||
#Pinned versions: 7.2.6
|
||||
|
||||
standard-imghdr==3.13.0; python_version >= "3.13"
|
||||
#Description: This is needed by Sphinx, so it needs to be added here.
|
||||
# The reasons are as follows:
|
||||
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
|
||||
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
|
||||
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
|
||||
pytorch_sphinx_theme2==0.1.0
|
||||
#Description: This is needed to generate PyTorch docs
|
||||
#Pinned versions: 0.1.0
|
||||
|
||||
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
|
||||
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
|
||||
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
|
||||
# something related to Docker setup. We can investigate this later.
|
||||
@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 2.13.0
|
||||
|
||||
breathe==4.34.0
|
||||
breathe==4.36.0
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 4.34.0
|
||||
#Pinned versions: 4.36.0
|
||||
|
||||
exhale==0.2.3
|
||||
exhale==0.3.7
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.2.3
|
||||
#Pinned versions: 0.3.7
|
||||
|
||||
docutils==0.16
|
||||
docutils==0.20
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.16
|
||||
#Pinned versions: 0.20
|
||||
|
||||
bs4==0.0.1
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
@ -56,13 +52,13 @@ IPython==8.12.0
|
||||
#Description: This is used to generate PyTorch functorch docs
|
||||
#Pinned versions: 8.12.0
|
||||
|
||||
myst-nb==0.17.2
|
||||
myst-nb==1.3.0
|
||||
#Description: This is used to generate PyTorch functorch and torch.compile docs.
|
||||
#Pinned versions: 0.17.2
|
||||
#Pinned versions: 1.3.0
|
||||
|
||||
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
|
||||
python-etcd==0.4.5
|
||||
sphinx-copybutton==0.5.0
|
||||
sphinx-design==0.4.0
|
||||
sphinx-design==0.6.1
|
||||
sphinxcontrib-mermaid==1.0.0
|
||||
myst-parser==0.18.1
|
||||
myst-parser==4.0.1
|
||||
|
||||
@ -6,7 +6,7 @@ dependencies = [
|
||||
"GitPython==3.1.45",
|
||||
"docker==7.1.0",
|
||||
"pytest==7.3.2",
|
||||
"uv==0.9.5"
|
||||
"uv==0.8.6"
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
|
||||
@ -102,8 +102,18 @@ if [ "$is_main_doc" = true ]; then
|
||||
echo coverage output not found
|
||||
exit 1
|
||||
elif [ $undocumented -gt 0 ]; then
|
||||
echo undocumented objects found:
|
||||
echo "======================================"
|
||||
echo "ERROR: $undocumented undocumented objects found!"
|
||||
echo "======================================"
|
||||
echo ""
|
||||
echo "Full coverage report:"
|
||||
cat build/coverage/python.txt
|
||||
echo ""
|
||||
echo "======================================"
|
||||
echo "Undocumented modules/objects (lines after TOTAL):"
|
||||
tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
|
||||
echo "======================================"
|
||||
echo ""
|
||||
echo "Make sure you've updated relevant .rsts in docs/source!"
|
||||
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
|
||||
exit 1
|
||||
|
||||
@ -1,359 +0,0 @@
|
||||
---
|
||||
name: docstring
|
||||
description: Write docstrings for PyTorch functions and methods following PyTorch conventions. Use when writing or updating docstrings in PyTorch code.
|
||||
---
|
||||
|
||||
# PyTorch Docstring Writing Guide
|
||||
|
||||
This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.
|
||||
|
||||
## General Principles
|
||||
|
||||
- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
|
||||
- Follow **Sphinx/reStructuredText** (reST) format for documentation
|
||||
- Be **concise but complete** - include all essential information
|
||||
- Always include **examples** when possible
|
||||
- Use **cross-references** to related functions/classes
|
||||
|
||||
## Docstring Structure
|
||||
|
||||
### 1. Function Signature (First Line)
|
||||
|
||||
Start with the function signature showing all parameters:
|
||||
|
||||
```python
|
||||
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
|
||||
```
|
||||
|
||||
**Notes:**
|
||||
- Include the function name
|
||||
- Show positional and keyword-only arguments (use `*` separator)
|
||||
- Include default values
|
||||
- Show return type annotation
|
||||
- This line should NOT end with a period
|
||||
|
||||
### 2. Brief Description
|
||||
|
||||
Provide a one-line description of what the function does:
|
||||
|
||||
```python
|
||||
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
|
||||
|
||||
Applies a 2D convolution over an input image composed of several input
|
||||
planes.
|
||||
```
|
||||
|
||||
### 3. Mathematical Formulas (if applicable)
|
||||
|
||||
Use Sphinx math directives for mathematical expressions:
|
||||
|
||||
```python
|
||||
.. math::
|
||||
\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
|
||||
```
|
||||
|
||||
Or inline math: `:math:\`x^2\``
|
||||
|
||||
### 4. Cross-References
|
||||
|
||||
Link to related classes and functions using Sphinx roles:
|
||||
|
||||
- `:class:\`~torch.nn.ModuleName\`` - Link to a class
|
||||
- `:func:\`torch.function_name\`` - Link to a function
|
||||
- `:meth:\`~Tensor.method_name\`` - Link to a method
|
||||
- `:attr:\`attribute_name\`` - Reference an attribute
|
||||
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
See :class:`~torch.nn.Conv2d` for details and output shape.
|
||||
```
|
||||
|
||||
### 5. Notes and Warnings
|
||||
|
||||
Use admonitions for important information:
|
||||
|
||||
```python
|
||||
.. note::
|
||||
This function doesn't work directly with NLLLoss,
|
||||
which expects the Log to be computed between the Softmax and itself.
|
||||
Use log_softmax instead (it's faster and has better numerical properties).
|
||||
|
||||
.. warning::
|
||||
:func:`new_tensor` always copies :attr:`data`. If you have a Tensor
|
||||
``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
|
||||
or :func:`torch.Tensor.detach`.
|
||||
```
|
||||
|
||||
### 6. Args Section
|
||||
|
||||
Document all parameters with type annotations and descriptions:
|
||||
|
||||
```python
|
||||
Args:
|
||||
input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
|
||||
weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)`
|
||||
bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
|
||||
stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
|
||||
tuple `(sH, sW)`. Default: 1
|
||||
```
|
||||
|
||||
**Formatting rules:**
|
||||
- Parameter name in **lowercase**
|
||||
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
|
||||
- Description follows the type
|
||||
- For optional parameters, include "Default: ``value``" at the end
|
||||
- Use double backticks for inline code: ``` ``None`` ```
|
||||
- Indent continuation lines by 2 spaces
|
||||
|
||||
### 7. Keyword Args Section (if applicable)
|
||||
|
||||
Sometimes keyword arguments are documented separately:
|
||||
|
||||
```python
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
|
||||
Default: if None, same :class:`torch.dtype` as this tensor.
|
||||
device (:class:`torch.device`, optional): the desired device of returned tensor.
|
||||
Default: if None, same :class:`torch.device` as this tensor.
|
||||
requires_grad (bool, optional): If autograd should record operations on the
|
||||
returned tensor. Default: ``False``.
|
||||
```
|
||||
|
||||
### 8. Returns Section (if needed)
|
||||
|
||||
Document the return value:
|
||||
|
||||
```python
|
||||
Returns:
|
||||
Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
|
||||
If ``hard=True``, the returned samples will be one-hot, otherwise they will
|
||||
be probability distributions that sum to 1 across `dim`.
|
||||
```
|
||||
|
||||
Or simply include it in the function signature line if obvious from context.
|
||||
|
||||
### 9. Examples Section
|
||||
|
||||
Always include examples when possible:
|
||||
|
||||
```python
|
||||
Examples::
|
||||
|
||||
>>> inputs = torch.randn(33, 16, 30)
|
||||
>>> filters = torch.randn(20, 16, 5)
|
||||
>>> F.conv1d(inputs, filters)
|
||||
|
||||
>>> # With square kernels and equal stride
|
||||
>>> filters = torch.randn(8, 4, 3, 3)
|
||||
>>> inputs = torch.randn(1, 4, 5, 5)
|
||||
>>> F.conv2d(inputs, filters, padding=1)
|
||||
```
|
||||
|
||||
**Formatting rules:**
|
||||
- Use `Examples::` with double colon
|
||||
- Use `>>>` prompt for Python code
|
||||
- Include comments with `#` when helpful
|
||||
- Show actual output when it helps understanding (indent without `>>>`)
|
||||
|
||||
### 10. External References
|
||||
|
||||
Link to papers or external documentation:
|
||||
|
||||
```python
|
||||
.. _Link Name:
|
||||
https://arxiv.org/abs/1611.00712
|
||||
```
|
||||
|
||||
Reference them in text: ```See `Link Name`_```
|
||||
|
||||
## Method Types
|
||||
|
||||
### Native Python Functions
|
||||
|
||||
For regular Python functions, use a standard docstring:
|
||||
|
||||
```python
|
||||
def relu(input: Tensor, inplace: bool = False) -> Tensor:
|
||||
r"""relu(input, inplace=False) -> Tensor
|
||||
|
||||
Applies the rectified linear unit function element-wise. See
|
||||
:class:`~torch.nn.ReLU` for more details.
|
||||
"""
|
||||
# implementation
|
||||
```
|
||||
|
||||
### C-Bound Functions (using add_docstr)
|
||||
|
||||
For C-bound functions, use `_add_docstr`:
|
||||
|
||||
```python
|
||||
conv1d = _add_docstr(
|
||||
torch.conv1d,
|
||||
r"""
|
||||
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
|
||||
|
||||
Applies a 1D convolution over an input signal composed of several input
|
||||
planes.
|
||||
|
||||
See :class:`~torch.nn.Conv1d` for details and output shape.
|
||||
|
||||
Args:
|
||||
input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
|
||||
weight: filters of shape :math:`(\text{out\_channels} , kW)`
|
||||
...
|
||||
""",
|
||||
)
|
||||
```
|
||||
|
||||
### In-Place Variants
|
||||
|
||||
For in-place operations (ending with `_`), reference the original:
|
||||
|
||||
```python
|
||||
add_docstr_all(
|
||||
"abs_",
|
||||
r"""
|
||||
abs_() -> Tensor
|
||||
|
||||
In-place version of :meth:`~Tensor.abs`
|
||||
""",
|
||||
)
|
||||
```
|
||||
|
||||
### Alias Functions
|
||||
|
||||
For aliases, simply reference the original:
|
||||
|
||||
```python
|
||||
add_docstr_all(
|
||||
"absolute",
|
||||
r"""
|
||||
absolute() -> Tensor
|
||||
|
||||
Alias for :func:`abs`
|
||||
""",
|
||||
)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Shape Documentation
|
||||
|
||||
Use LaTeX math notation for tensor shapes:
|
||||
|
||||
```python
|
||||
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
|
||||
```
|
||||
|
||||
### Reusable Argument Definitions
|
||||
|
||||
For commonly used arguments, define them once and reuse:
|
||||
|
||||
```python
|
||||
common_args = parse_kwargs(
|
||||
"""
|
||||
dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
|
||||
Default: if None, same as this tensor.
|
||||
"""
|
||||
)
|
||||
|
||||
# Then use with .format():
|
||||
r"""
|
||||
...
|
||||
|
||||
Keyword args:
|
||||
{dtype}
|
||||
{device}
|
||||
""".format(**common_args)
|
||||
```
|
||||
|
||||
### Template Insertion
|
||||
|
||||
Insert reproducibility notes or other common text:
|
||||
|
||||
```python
|
||||
r"""
|
||||
{tf32_note}
|
||||
|
||||
{cudnn_reproducibility_note}
|
||||
""".format(**reproducibility_notes, **tf32_notes)
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
Here's a complete example showing all elements:
|
||||
|
||||
```python
|
||||
def gumbel_softmax(
|
||||
logits: Tensor,
|
||||
tau: float = 1,
|
||||
hard: bool = False,
|
||||
eps: float = 1e-10,
|
||||
dim: int = -1,
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Sample from the Gumbel-Softmax distribution and optionally discretize.
|
||||
|
||||
Args:
|
||||
logits (Tensor): `[..., num_features]` unnormalized log probabilities
|
||||
tau (float): non-negative scalar temperature
|
||||
hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
|
||||
but will be differentiated as if it is the soft sample in autograd. Default: ``False``
|
||||
dim (int): A dimension along which softmax will be computed. Default: -1
|
||||
|
||||
Returns:
|
||||
Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
|
||||
If ``hard=True``, the returned samples will be one-hot, otherwise they will
|
||||
be probability distributions that sum to 1 across `dim`.
|
||||
|
||||
.. note::
|
||||
This function is here for legacy reasons, may be removed from nn.Functional in the future.
|
||||
|
||||
Examples::
|
||||
>>> logits = torch.randn(20, 32)
|
||||
>>> # Sample soft categorical using reparametrization trick:
|
||||
>>> F.gumbel_softmax(logits, tau=1, hard=False)
|
||||
>>> # Sample hard categorical using "Straight-through" trick:
|
||||
>>> F.gumbel_softmax(logits, tau=1, hard=True)
|
||||
|
||||
.. _Link 1:
|
||||
https://arxiv.org/abs/1611.00712
|
||||
"""
|
||||
# implementation
|
||||
```
|
||||
|
||||
## Quick Checklist
|
||||
|
||||
When writing a PyTorch docstring, ensure:
|
||||
|
||||
- [ ] Use raw string (`r"""`)
|
||||
- [ ] Include function signature on first line
|
||||
- [ ] Provide brief description
|
||||
- [ ] Document all parameters in Args section with types
|
||||
- [ ] Include default values for optional parameters
|
||||
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
|
||||
- [ ] Add mathematical formulas if applicable
|
||||
- [ ] Include at least one example in Examples section
|
||||
- [ ] Add warnings/notes for important caveats
|
||||
- [ ] Link to related module class with `:class:`
|
||||
- [ ] Use proper math notation for tensor shapes
|
||||
- [ ] Follow consistent formatting and indentation
|
||||
|
||||
## Common Sphinx Roles Reference
|
||||
|
||||
- `:class:\`~torch.nn.Module\`` - Class reference
|
||||
- `:func:\`torch.function\`` - Function reference
|
||||
- `:meth:\`~Tensor.method\`` - Method reference
|
||||
- `:attr:\`attribute\`` - Attribute reference
|
||||
- `:math:\`equation\`` - Inline math
|
||||
- `:ref:\`label\`` - Internal reference
|
||||
- ``` ``code`` ``` - Inline code (use double backticks)
|
||||
|
||||
## Additional Notes
|
||||
|
||||
- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
|
||||
- **Line length**: Try to keep lines under 100 characters when possible
|
||||
- **Periods**: End sentences with periods, but not the signature line
|
||||
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
|
||||
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.
|
||||
@ -1,385 +0,0 @@
|
||||
---
|
||||
name: skill-writer
|
||||
description: Guide users through creating Agent Skills for Claude Code. Use when the user wants to create, write, author, or design a new Skill, or needs help with SKILL.md files, frontmatter, or skill structure.
|
||||
---
|
||||
|
||||
# Skill Writer
|
||||
|
||||
This Skill helps you create well-structured Agent Skills for Claude Code that follow best practices and validation requirements.
|
||||
|
||||
## When to use this Skill
|
||||
|
||||
Use this Skill when:
|
||||
- Creating a new Agent Skill
|
||||
- Writing or updating SKILL.md files
|
||||
- Designing skill structure and frontmatter
|
||||
- Troubleshooting skill discovery issues
|
||||
- Converting existing prompts or workflows into Skills
|
||||
|
||||
## Instructions
|
||||
|
||||
### Step 1: Determine Skill scope
|
||||
|
||||
First, understand what the Skill should do:
|
||||
|
||||
1. **Ask clarifying questions**:
|
||||
- What specific capability should this Skill provide?
|
||||
- When should Claude use this Skill?
|
||||
- What tools or resources does it need?
|
||||
- Is this for personal use or team sharing?
|
||||
|
||||
2. **Keep it focused**: One Skill = one capability
|
||||
- Good: "PDF form filling", "Excel data analysis"
|
||||
- Too broad: "Document processing", "Data tools"
|
||||
|
||||
### Step 2: Choose Skill location
|
||||
|
||||
Determine where to create the Skill:
|
||||
|
||||
**Personal Skills** (`~/.claude/skills/`):
|
||||
- Individual workflows and preferences
|
||||
- Experimental Skills
|
||||
- Personal productivity tools
|
||||
|
||||
**Project Skills** (`.claude/skills/`):
|
||||
- Team workflows and conventions
|
||||
- Project-specific expertise
|
||||
- Shared utilities (committed to git)
|
||||
|
||||
### Step 3: Create Skill structure
|
||||
|
||||
Create the directory and files:
|
||||
|
||||
```bash
|
||||
# Personal
|
||||
mkdir -p ~/.claude/skills/skill-name
|
||||
|
||||
# Project
|
||||
mkdir -p .claude/skills/skill-name
|
||||
```
|
||||
|
||||
For multi-file Skills:
|
||||
```
|
||||
skill-name/
|
||||
├── SKILL.md (required)
|
||||
├── reference.md (optional)
|
||||
├── examples.md (optional)
|
||||
├── scripts/
|
||||
│ └── helper.py (optional)
|
||||
└── templates/
|
||||
└── template.txt (optional)
|
||||
```
|
||||
|
||||
### Step 4: Write SKILL.md frontmatter
|
||||
|
||||
Create YAML frontmatter with required fields:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: skill-name
|
||||
description: Brief description of what this does and when to use it
|
||||
---
|
||||
```
|
||||
|
||||
**Field requirements**:
|
||||
|
||||
- **name**:
|
||||
- Lowercase letters, numbers, hyphens only
|
||||
- Max 64 characters
|
||||
- Must match directory name
|
||||
- Good: `pdf-processor`, `git-commit-helper`
|
||||
- Bad: `PDF_Processor`, `Git Commits!`
|
||||
|
||||
- **description**:
|
||||
- Max 1024 characters
|
||||
- Include BOTH what it does AND when to use it
|
||||
- Use specific trigger words users would say
|
||||
- Mention file types, operations, and context
|
||||
|
||||
**Optional frontmatter fields**:
|
||||
|
||||
- **allowed-tools**: Restrict tool access (comma-separated list)
|
||||
```yaml
|
||||
allowed-tools: Read, Grep, Glob
|
||||
```
|
||||
Use for:
|
||||
- Read-only Skills
|
||||
- Security-sensitive workflows
|
||||
- Limited-scope operations
|
||||
|
||||
### Step 5: Write effective descriptions
|
||||
|
||||
The description is critical for Claude to discover your Skill.
|
||||
|
||||
**Formula**: `[What it does] + [When to use it] + [Key triggers]`
|
||||
|
||||
**Examples**:
|
||||
|
||||
✅ **Good**:
|
||||
```yaml
|
||||
description: Extract text and tables from PDF files, fill forms, merge documents. Use when working with PDF files or when the user mentions PDFs, forms, or document extraction.
|
||||
```
|
||||
|
||||
✅ **Good**:
|
||||
```yaml
|
||||
description: Analyze Excel spreadsheets, create pivot tables, and generate charts. Use when working with Excel files, spreadsheets, or analyzing tabular data in .xlsx format.
|
||||
```
|
||||
|
||||
❌ **Too vague**:
|
||||
```yaml
|
||||
description: Helps with documents
|
||||
description: For data analysis
|
||||
```
|
||||
|
||||
**Tips**:
|
||||
- Include specific file extensions (.pdf, .xlsx, .json)
|
||||
- Mention common user phrases ("analyze", "extract", "generate")
|
||||
- List concrete operations (not generic verbs)
|
||||
- Add context clues ("Use when...", "For...")
|
||||
|
||||
### Step 6: Structure the Skill content
|
||||
|
||||
Use clear Markdown sections:
|
||||
|
||||
```markdown
|
||||
# Skill Name
|
||||
|
||||
Brief overview of what this Skill does.
|
||||
|
||||
## Quick start
|
||||
|
||||
Provide a simple example to get started immediately.
|
||||
|
||||
## Instructions
|
||||
|
||||
Step-by-step guidance for Claude:
|
||||
1. First step with clear action
|
||||
2. Second step with expected outcome
|
||||
3. Handle edge cases
|
||||
|
||||
## Examples
|
||||
|
||||
Show concrete usage examples with code or commands.
|
||||
|
||||
## Best practices
|
||||
|
||||
- Key conventions to follow
|
||||
- Common pitfalls to avoid
|
||||
- When to use vs. not use
|
||||
|
||||
## Requirements
|
||||
|
||||
List any dependencies or prerequisites:
|
||||
```bash
|
||||
pip install package-name
|
||||
```
|
||||
|
||||
## Advanced usage
|
||||
|
||||
For complex scenarios, see [reference.md](reference.md).
|
||||
```
|
||||
|
||||
### Step 7: Add supporting files (optional)
|
||||
|
||||
Create additional files for progressive disclosure:
|
||||
|
||||
**reference.md**: Detailed API docs, advanced options
|
||||
**examples.md**: Extended examples and use cases
|
||||
**scripts/**: Helper scripts and utilities
|
||||
**templates/**: File templates or boilerplate
|
||||
|
||||
Reference them from SKILL.md:
|
||||
```markdown
|
||||
For advanced usage, see [reference.md](reference.md).
|
||||
|
||||
Run the helper script:
|
||||
\`\`\`bash
|
||||
python scripts/helper.py input.txt
|
||||
\`\`\`
|
||||
```
|
||||
|
||||
### Step 8: Validate the Skill
|
||||
|
||||
Check these requirements:
|
||||
|
||||
✅ **File structure**:
|
||||
- [ ] SKILL.md exists in correct location
|
||||
- [ ] Directory name matches frontmatter `name`
|
||||
|
||||
✅ **YAML frontmatter**:
|
||||
- [ ] Opening `---` on line 1
|
||||
- [ ] Closing `---` before content
|
||||
- [ ] Valid YAML (no tabs, correct indentation)
|
||||
- [ ] `name` follows naming rules
|
||||
- [ ] `description` is specific and < 1024 chars
|
||||
|
||||
✅ **Content quality**:
|
||||
- [ ] Clear instructions for Claude
|
||||
- [ ] Concrete examples provided
|
||||
- [ ] Edge cases handled
|
||||
- [ ] Dependencies listed (if any)
|
||||
|
||||
✅ **Testing**:
|
||||
- [ ] Description matches user questions
|
||||
- [ ] Skill activates on relevant queries
|
||||
- [ ] Instructions are clear and actionable
|
||||
|
||||
### Step 9: Test the Skill
|
||||
|
||||
1. **Restart Claude Code** (if running) to load the Skill
|
||||
|
||||
2. **Ask relevant questions** that match the description:
|
||||
```
|
||||
Can you help me extract text from this PDF?
|
||||
```
|
||||
|
||||
3. **Verify activation**: Claude should use the Skill automatically
|
||||
|
||||
4. **Check behavior**: Confirm Claude follows the instructions correctly
|
||||
|
||||
### Step 10: Debug if needed
|
||||
|
||||
If Claude doesn't use the Skill:
|
||||
|
||||
1. **Make description more specific**:
|
||||
- Add trigger words
|
||||
- Include file types
|
||||
- Mention common user phrases
|
||||
|
||||
2. **Check file location**:
|
||||
```bash
|
||||
ls ~/.claude/skills/skill-name/SKILL.md
|
||||
ls .claude/skills/skill-name/SKILL.md
|
||||
```
|
||||
|
||||
3. **Validate YAML**:
|
||||
```bash
|
||||
cat SKILL.md | head -n 10
|
||||
```
|
||||
|
||||
4. **Run debug mode**:
|
||||
```bash
|
||||
claude --debug
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Read-only Skill
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: code-reader
|
||||
description: Read and analyze code without making changes. Use for code review, understanding codebases, or documentation.
|
||||
allowed-tools: Read, Grep, Glob
|
||||
---
|
||||
```
|
||||
|
||||
### Script-based Skill
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: data-processor
|
||||
description: Process CSV and JSON data files with Python scripts. Use when analyzing data files or transforming datasets.
|
||||
---
|
||||
|
||||
# Data Processor
|
||||
|
||||
## Instructions
|
||||
|
||||
1. Use the processing script:
|
||||
\`\`\`bash
|
||||
python scripts/process.py input.csv --output results.json
|
||||
\`\`\`
|
||||
|
||||
2. Validate output with:
|
||||
\`\`\`bash
|
||||
python scripts/validate.py results.json
|
||||
\`\`\`
|
||||
```
|
||||
|
||||
### Multi-file Skill with progressive disclosure
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: api-designer
|
||||
description: Design REST APIs following best practices. Use when creating API endpoints, designing routes, or planning API architecture.
|
||||
---
|
||||
|
||||
# API Designer
|
||||
|
||||
Quick start: See [examples.md](examples.md)
|
||||
|
||||
Detailed reference: See [reference.md](reference.md)
|
||||
|
||||
## Instructions
|
||||
|
||||
1. Gather requirements
|
||||
2. Design endpoints (see examples.md)
|
||||
3. Document with OpenAPI spec
|
||||
4. Review against best practices (see reference.md)
|
||||
```
|
||||
|
||||
## Best practices for Skill authors
|
||||
|
||||
1. **One Skill, one purpose**: Don't create mega-Skills
|
||||
2. **Specific descriptions**: Include trigger words users will say
|
||||
3. **Clear instructions**: Write for Claude, not humans
|
||||
4. **Concrete examples**: Show real code, not pseudocode
|
||||
5. **List dependencies**: Mention required packages in description
|
||||
6. **Test with teammates**: Verify activation and clarity
|
||||
7. **Version your Skills**: Document changes in content
|
||||
8. **Use progressive disclosure**: Put advanced details in separate files
|
||||
|
||||
## Validation checklist
|
||||
|
||||
Before finalizing a Skill, verify:
|
||||
|
||||
- [ ] Name is lowercase, hyphens only, max 64 chars
|
||||
- [ ] Description is specific and < 1024 chars
|
||||
- [ ] Description includes "what" and "when"
|
||||
- [ ] YAML frontmatter is valid
|
||||
- [ ] Instructions are step-by-step
|
||||
- [ ] Examples are concrete and realistic
|
||||
- [ ] Dependencies are documented
|
||||
- [ ] File paths use forward slashes
|
||||
- [ ] Skill activates on relevant queries
|
||||
- [ ] Claude follows instructions correctly
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Skill doesn't activate**:
|
||||
- Make description more specific with trigger words
|
||||
- Include file types and operations in description
|
||||
- Add "Use when..." clause with user phrases
|
||||
|
||||
**Multiple Skills conflict**:
|
||||
- Make descriptions more distinct
|
||||
- Use different trigger words
|
||||
- Narrow the scope of each Skill
|
||||
|
||||
**Skill has errors**:
|
||||
- Check YAML syntax (no tabs, proper indentation)
|
||||
- Verify file paths (use forward slashes)
|
||||
- Ensure scripts have execute permissions
|
||||
- List all dependencies
|
||||
|
||||
## Examples
|
||||
|
||||
See the documentation for complete examples:
|
||||
- Simple single-file Skill (commit-helper)
|
||||
- Skill with tool permissions (code-reviewer)
|
||||
- Multi-file Skill (pdf-processing)
|
||||
|
||||
## Output format
|
||||
|
||||
When creating a Skill, I will:
|
||||
|
||||
1. Ask clarifying questions about scope and requirements
|
||||
2. Suggest a Skill name and location
|
||||
3. Create the SKILL.md file with proper frontmatter
|
||||
4. Include clear instructions and examples
|
||||
5. Add supporting files if needed
|
||||
6. Provide testing instructions
|
||||
7. Validate against all requirements
|
||||
|
||||
The result will be a complete, working Skill that follows all best practices and validation rules.
|
||||
7
.github/actions/setup-rocm/action.yml
vendored
7
.github/actions/setup-rocm/action.yml
vendored
@ -124,10 +124,3 @@ runs:
|
||||
id: login-ecr
|
||||
continue-on-error: true
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
|
||||
- name: Preserve github env variables for use in docker
|
||||
shell: bash
|
||||
run: |
|
||||
env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
|
||||
env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
|
||||
env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
|
||||
|
||||
2
.github/ci_commit_pins/vision.txt
vendored
2
.github/ci_commit_pins/vision.txt
vendored
@ -1 +1 @@
|
||||
1752fe6809b74921644866275ab80244b96e80bc
|
||||
faffd5cf673615583da6517275e361cb3dbc77e6
|
||||
|
||||
5
.github/ci_configs/vllm/Dockerfile
vendored
5
.github/ci_configs/vllm/Dockerfile
vendored
@ -283,9 +283,6 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
|
||||
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
|
||||
fi
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --pre apache-tvm-ffi==0.1.0b15
|
||||
|
||||
# Install the vllm wheel from previous stage
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system /wheels/vllm/*.whl --verbose
|
||||
@ -298,8 +295,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
|
||||
# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
|
||||
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.14.post1"
|
||||
|
||||
|
||||
9
.github/label_to_label.yml
vendored
9
.github/label_to_label.yml
vendored
@ -15,11 +15,6 @@
|
||||
- "module: reinplacing"
|
||||
then:
|
||||
- "module: pt2-dispatcher"
|
||||
- any:
|
||||
- "vllm-compile"
|
||||
then:
|
||||
- "module: vllm"
|
||||
- "oncall: pt2"
|
||||
- any:
|
||||
- "module: vmap"
|
||||
then:
|
||||
@ -32,6 +27,10 @@
|
||||
- "module: pt2 optimizer"
|
||||
then:
|
||||
- "module: dynamo"
|
||||
- any:
|
||||
- "module: flex attention"
|
||||
then:
|
||||
- "module: higher order operators"
|
||||
- any:
|
||||
- "module: aotinductor"
|
||||
then:
|
||||
|
||||
1
.github/workflows/inductor-periodic.yml
vendored
1
.github/workflows/inductor-periodic.yml
vendored
@ -88,6 +88,7 @@ jobs:
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3_10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
|
||||
15
.github/workflows/periodic.yml
vendored
15
.github/workflows/periodic.yml
vendored
@ -147,16 +147,15 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
|
||||
cuda-arch-list: 8.9
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
3
.github/workflows/pull.yml
vendored
3
.github/workflows/pull.yml
vendored
@ -347,8 +347,7 @@ jobs:
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
# This should sync with the build in xpu.yml but xpu uses a larger runner
|
||||
# sync-tag: linux-xpu-n-build
|
||||
sync-tag: linux-xpu-n-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
||||
|
||||
1
.github/workflows/rocm-mi300.yml
vendored
1
.github/workflows/rocm-mi300.yml
vendored
@ -45,6 +45,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-noble-rocm-py3.12-mi300
|
||||
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
|
||||
1
.github/workflows/rocm-mi355.yml
vendored
1
.github/workflows/rocm-mi355.yml
vendored
@ -42,6 +42,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-noble-rocm-py3.12-mi355
|
||||
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
|
||||
|
||||
12
.github/workflows/rocm-navi31.yml
vendored
12
.github/workflows/rocm-navi31.yml
vendored
@ -26,23 +26,11 @@ jobs:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
|
||||
12
.github/workflows/rocm.yml
vendored
12
.github/workflows/rocm.yml
vendored
@ -26,23 +26,11 @@ jobs:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
|
||||
149
.github/workflows/trunk-tagging.yml
vendored
149
.github/workflows/trunk-tagging.yml
vendored
@ -58,10 +58,8 @@ jobs:
|
||||
else
|
||||
COMMIT_SHA="${{ github.sha }}"
|
||||
fi
|
||||
{
|
||||
echo "sha=${COMMIT_SHA}"
|
||||
echo "tag_name=trunk/${COMMIT_SHA}"
|
||||
} >> "${GITHUB_OUTPUT}"
|
||||
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
|
||||
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Validate commit SHA
|
||||
run: |
|
||||
@ -89,7 +87,7 @@ jobs:
|
||||
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
|
||||
fi
|
||||
|
||||
- name: Create and push tag(s) with retry
|
||||
- name: Create and push tag with retry
|
||||
id: check_tag
|
||||
env:
|
||||
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
|
||||
@ -114,23 +112,14 @@ jobs:
|
||||
return 1
|
||||
}
|
||||
|
||||
# Counters for summary reporting
|
||||
created_count=0
|
||||
skipped_count=0
|
||||
failed_count=0
|
||||
# Exit early if tag already exists
|
||||
if check_tag_exists; then
|
||||
echo "✅ Tag already exists - no action needed"
|
||||
echo "exists=true" >> "${GITHUB_OUTPUT}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Always write outputs once on exit
|
||||
finish() {
|
||||
set +e
|
||||
if [ -n "${GITHUB_OUTPUT:-}" ]; then
|
||||
{
|
||||
echo "created_count=${created_count}"
|
||||
echo "skipped_count=${skipped_count}"
|
||||
echo "failed_count=${failed_count}"
|
||||
} >> "${GITHUB_OUTPUT}"
|
||||
fi
|
||||
}
|
||||
trap finish EXIT
|
||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||
|
||||
# Retry configuration
|
||||
MAX_RETRIES=5
|
||||
@ -205,111 +194,31 @@ jobs:
|
||||
}
|
||||
}
|
||||
|
||||
# New behavior for push events: enumerate commits in the push and tag each one.
|
||||
# For workflow_dispatch, retain existing single-SHA behavior.
|
||||
|
||||
# Always fetch tags once up front to improve idempotency in loops
|
||||
git fetch origin --tags --quiet || true
|
||||
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
BEFORE_SHA="${{ github.event.before }}"
|
||||
AFTER_SHA="${{ github.sha }}" # same as event.after
|
||||
|
||||
# List commits introduced by this push (old..new), oldest first for stable ordering
|
||||
commits_file="$(mktemp)"
|
||||
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
|
||||
|
||||
if [ ! -s "${commits_file}" ]; then
|
||||
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
|
||||
rm -f "${commits_file}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
|
||||
echo "Found ${commit_count} commit(s) to tag for push:"
|
||||
while IFS= read -r sha; do
|
||||
printf ' %s\n' "${sha}"
|
||||
done < "${commits_file}"
|
||||
|
||||
while IFS= read -r sha; do
|
||||
TAG_NAME="trunk/${sha}"
|
||||
COMMIT_SHA="${sha}"
|
||||
|
||||
# If tag already exists locally or remotely, skip (idempotent)
|
||||
if check_tag_exists; then
|
||||
echo "✅ Tag ${TAG_NAME} already exists - skipping"
|
||||
skipped_count=$((skipped_count + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||
|
||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||
created_count=$((created_count + 1))
|
||||
else
|
||||
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
|
||||
failed_count=$((failed_count + 1))
|
||||
fi
|
||||
done < "${commits_file}"
|
||||
|
||||
rm -f "${commits_file}"
|
||||
|
||||
if [ "${failed_count}" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
# Execute with retry
|
||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||
echo "exists=false" >> "${GITHUB_OUTPUT}"
|
||||
exit 0
|
||||
else
|
||||
# workflow_dispatch path (single SHA tagging preserved)
|
||||
|
||||
# Exit early if tag already exists
|
||||
if check_tag_exists; then
|
||||
echo "✅ Tag already exists - no action needed"
|
||||
skipped_count=1
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||
|
||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||
created_count=1
|
||||
exit 0
|
||||
else
|
||||
echo "Tag creation failed after all retry attempts"
|
||||
failed_count=1
|
||||
exit 1
|
||||
fi
|
||||
echo "Tag creation failed after all retry attempts"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Tag creation summary
|
||||
if: always()
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
echo "Trigger: push on main"
|
||||
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
|
||||
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
|
||||
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
|
||||
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
|
||||
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
|
||||
else
|
||||
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
|
||||
fi
|
||||
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
|
||||
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
|
||||
elif [ "${{ job.status }}" = "success" ]; then
|
||||
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
else
|
||||
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
|
||||
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
|
||||
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
|
||||
else
|
||||
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
fi
|
||||
else
|
||||
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Tag details:"
|
||||
echo " Name: ${{ steps.commit.outputs.tag_name }}"
|
||||
echo " Commit: ${{ steps.commit.outputs.sha }}"
|
||||
echo " Trigger: ${{ github.event_name }}"
|
||||
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
|
||||
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
|
||||
fi
|
||||
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Tag details:"
|
||||
echo " Name: ${{ steps.commit.outputs.tag_name }}"
|
||||
echo " Commit: ${{ steps.commit.outputs.sha }}"
|
||||
echo " Trigger: ${{ github.event_name }}"
|
||||
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
|
||||
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
|
||||
fi
|
||||
|
||||
@ -833,7 +833,8 @@ exclude_patterns = [
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/grep_linter.py',
|
||||
'--pattern=(cudaSetDevice|cudaGetDevice)\\(',
|
||||
'--pattern=cudaSetDevice(',
|
||||
'--pattern=cudaGetDevice(',
|
||||
'--linter-name=RAWCUDADEVICE',
|
||||
'--error-name=raw CUDA API usage',
|
||||
"""--error-description=\
|
||||
@ -1137,8 +1138,11 @@ command = [
|
||||
[[linter]]
|
||||
code = 'WORKFLOWSYNC'
|
||||
include_patterns = [
|
||||
'.github/workflows/*.yml',
|
||||
'.github/workflows/*.yaml',
|
||||
'.github/workflows/pull.yml',
|
||||
'.github/workflows/trunk.yml',
|
||||
'.github/workflows/periodic.yml',
|
||||
'.github/workflows/mac-mps.yml',
|
||||
'.github/workflows/slow.yml',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
||||
@ -289,15 +289,14 @@ IF(USE_FBGEMM_GENAI)
|
||||
|
||||
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
set(fbgemm_genai_cuh
|
||||
set(fbgemm_genai_mx8mx8bf16_grouped
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
|
||||
"${FBGEMM_GENAI_SRCS}/"
|
||||
)
|
||||
|
||||
target_include_directories(fbgemm_genai PRIVATE
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||
${fbgemm_genai_cuh}
|
||||
${fbgemm_genai_mx8mx8bf16_grouped}
|
||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||
)
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
#include <ATen/detail/MPSHooksInterface.h>
|
||||
#include <ATen/detail/MTIAHooksInterface.h>
|
||||
#include <ATen/detail/PrivateUse1HooksInterface.h>
|
||||
#include <ATen/detail/XLAHooksInterface.h>
|
||||
#include <ATen/detail/XPUHooksInterface.h>
|
||||
#include <c10/core/QEngine.h>
|
||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
||||
@ -89,8 +88,6 @@ class TORCH_API Context {
|
||||
return at::detail::getHIPHooks();
|
||||
} else if (opt_device_type == at::kHPU) {
|
||||
return at::detail::getHPUHooks();
|
||||
} else if (opt_device_type == at::kXLA) {
|
||||
return at::detail::getXLAHooks();
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
@ -199,7 +196,7 @@ class TORCH_API Context {
|
||||
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
|
||||
}
|
||||
static bool hasXLA() {
|
||||
return detail::getXLAHooks().hasXLA();
|
||||
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
|
||||
}
|
||||
static bool hasXPU() {
|
||||
return detail::getXPUHooks().hasXPU();
|
||||
|
||||
@ -59,7 +59,9 @@ struct TORCH_API Generator {
|
||||
|
||||
explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
|
||||
: impl_(std::move(gen_impl)) {
|
||||
TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
|
||||
if (impl_.get() == nullptr) {
|
||||
throw std::runtime_error("GeneratorImpl with nullptr is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(const Generator& rhs) const {
|
||||
|
||||
@ -111,7 +111,9 @@ class TORCH_API TensorBase {
|
||||
explicit TensorBase(
|
||||
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
|
||||
: impl_(std::move(tensor_impl)) {
|
||||
TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
|
||||
if (impl_.get() == nullptr) {
|
||||
throw std::runtime_error("TensorImpl with nullptr is not supported");
|
||||
}
|
||||
}
|
||||
TensorBase(const TensorBase&) = default;
|
||||
TensorBase(TensorBase&&) noexcept = default;
|
||||
|
||||
@ -109,10 +109,6 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
|
||||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
|
||||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
#undef AUTOGRAD_FALLBACK
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -442,17 +442,11 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
|
||||
|
||||
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
|
||||
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
|
||||
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
|
||||
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
|
||||
TORCH_CHECK(
|
||||
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ",
|
||||
dispatchKey,
|
||||
"; previous registration ",
|
||||
backendFallbackKernels_[idx].debug,
|
||||
", new registration ",
|
||||
debug);
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
|
||||
backendFallbackKernels_[idx].debug, ", new registration ", debug
|
||||
);
|
||||
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
|
||||
// cannot be unboxed
|
||||
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
|
||||
|
||||
@ -68,7 +68,11 @@ Symbol InternedStrings::_symbol(const std::string& s) {
|
||||
return it->second;
|
||||
|
||||
auto pos = s.find("::");
|
||||
TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
|
||||
if (pos == std::string::npos) {
|
||||
std::stringstream ss;
|
||||
ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
|
||||
|
||||
Symbol sym(sym_to_info_.size());
|
||||
@ -117,7 +121,12 @@ std::string Symbol::domainString() const {
|
||||
}
|
||||
|
||||
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
|
||||
TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
|
||||
if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
|
||||
std::ostringstream ss;
|
||||
ss << "Symbol: domain string is expected to be prefixed with '"
|
||||
<< domain_prefix() << "', e.g. 'org.pytorch.aten'";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
|
||||
return fromQualString(qualString);
|
||||
}
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <c10/util/hash.h>
|
||||
#include <c10/util/irange.h>
|
||||
@ -413,7 +412,7 @@ size_t IValue::hash(const IValue& v) {
|
||||
case Tag::Enum:
|
||||
case Tag::Stream:
|
||||
case Tag::Uninitialized:
|
||||
TORCH_CHECK(false,
|
||||
throw std::runtime_error(
|
||||
"unhashable type: '" + v.type()->repr_str() + "'");
|
||||
}
|
||||
// the above switch should be exhaustive
|
||||
|
||||
@ -8,7 +8,6 @@
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/util/TypeList.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <optional>
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/core/SymBool.h>
|
||||
@ -117,8 +116,10 @@ struct SingleElementType : public SharedType {
|
||||
|
||||
protected:
|
||||
SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
|
||||
TORCH_CHECK(this->elem, c10::str(
|
||||
if (!this->elem) {
|
||||
throw std::runtime_error(c10::str(
|
||||
"Can not create ", typeKindToString(Kind), " with None type"));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -415,12 +416,16 @@ struct TORCH_API SymbolicShape {
|
||||
}
|
||||
|
||||
ShapeSymbol operator[](size_t i) const {
|
||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
||||
if (!dims_) {
|
||||
throw std::runtime_error("Rank isn't fixed");
|
||||
}
|
||||
return (*dims_).at(i);
|
||||
}
|
||||
|
||||
ShapeSymbol at(size_t i) const {
|
||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
||||
if (!dims_) {
|
||||
throw std::runtime_error("Rank isn't fixed");
|
||||
}
|
||||
return (*dims_).at(i);
|
||||
}
|
||||
|
||||
@ -515,7 +520,9 @@ struct VaryingShape {
|
||||
}
|
||||
|
||||
const std::optional<T> &operator[](size_t i) const {
|
||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
||||
if (!dims_) {
|
||||
throw std::runtime_error("Rank isn't fixed");
|
||||
}
|
||||
return (*dims_).at(i);
|
||||
}
|
||||
|
||||
@ -950,7 +957,9 @@ struct TORCH_API DictType : public SharedType {
|
||||
|
||||
TypePtr createWithContained(
|
||||
std::vector<TypePtr> contained_types) const override {
|
||||
TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
|
||||
if (contained_types.size() != 2) {
|
||||
throw std::runtime_error("Expected 2 contained types");
|
||||
}
|
||||
return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
|
||||
}
|
||||
|
||||
|
||||
@ -185,11 +185,11 @@ struct TORCH_API Type {
|
||||
: repr_(nullptr) {}
|
||||
|
||||
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
|
||||
: repr_(makeSingletonSharedPtr(p.get())) {}
|
||||
: repr_(p) {}
|
||||
|
||||
template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
|
||||
: repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {}
|
||||
: repr_(SingletonTypePtr<T>(p.get())) {}
|
||||
|
||||
|
||||
// We need to support construction from T* for pybind. The problem
|
||||
@ -202,8 +202,8 @@ struct TORCH_API Type {
|
||||
// Case 2: if T is exactly Type, we need to do a dynamic_cast to
|
||||
// check if it's a SharedType and do the right thing.
|
||||
//
|
||||
// Case 3: Otherwise, T is not a SharedType. Use a singleton
|
||||
// pointer.
|
||||
// Case 3: Otherwise, T is not a SharedType. (debug-check this
|
||||
// assumption!) Use a singleton pointer.
|
||||
|
||||
template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
|
||||
@ -211,15 +211,15 @@ struct TORCH_API Type {
|
||||
template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(T* p) {
|
||||
if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
|
||||
repr_ = shared_p->shared_from_this();
|
||||
repr_ = Repr(shared_p->shared_from_this());
|
||||
} else {
|
||||
repr_ = makeSingletonSharedPtr(p);
|
||||
repr_ = Repr(p);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(T* p)
|
||||
: repr_(makeSingletonSharedPtr(p)) {
|
||||
: repr_(p) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
|
||||
}
|
||||
|
||||
@ -230,19 +230,19 @@ struct TORCH_API Type {
|
||||
~SingletonOrSharedTypePtr() = default;
|
||||
|
||||
T* get() const {
|
||||
return repr_.get();
|
||||
return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
|
||||
}
|
||||
|
||||
operator bool() const {
|
||||
return repr_ != nullptr;
|
||||
return repr_.isNonNull();
|
||||
}
|
||||
|
||||
bool operator==(std::nullptr_t) const {
|
||||
return repr_ == nullptr;
|
||||
return !repr_.isNonNull();
|
||||
}
|
||||
|
||||
bool operator!=(std::nullptr_t) const {
|
||||
return repr_ != nullptr;
|
||||
return repr_.isNonNull();
|
||||
}
|
||||
|
||||
template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
|
||||
@ -255,14 +255,138 @@ struct TORCH_API Type {
|
||||
}
|
||||
|
||||
private:
|
||||
// Use shared_ptr's aliasing constructor to create a non-owning pointer
|
||||
// to a singleton. The lifetime is tied to the null shared_ptr, so there's
|
||||
// no reference counting overhead for the singleton itself.
|
||||
static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) {
|
||||
return std::shared_ptr<T>(std::shared_ptr<T>(), ptr);
|
||||
}
|
||||
// NOTE: SharedPtrWrapper exists to work around a baffling bug in
|
||||
// nvcc; see comment in destroy() below.
|
||||
struct SharedPtrWrapper {
|
||||
SharedPtrWrapper(std::shared_ptr<T> &&x)
|
||||
: repr_(std::move(x)) {}
|
||||
std::shared_ptr<T> repr_;
|
||||
};
|
||||
union Repr {
|
||||
Repr() : Repr(nullptr) {}
|
||||
|
||||
std::shared_ptr<T> repr_;
|
||||
explicit Repr(std::shared_ptr<T> x)
|
||||
: shared_(std::move(x)) {}
|
||||
|
||||
explicit Repr(std::nullptr_t)
|
||||
: singletonRepr_(nullptr) {}
|
||||
|
||||
explicit Repr(SingletonTypePtr<T> p)
|
||||
: singletonRepr_(p.get()) {}
|
||||
|
||||
~Repr() {
|
||||
destroy();
|
||||
}
|
||||
|
||||
// NOTE: the only non-UB way to access our null state is through
|
||||
// rawRepr(), because our copy operation doesn't preserve which
|
||||
// union member is active for null pointers.
|
||||
Repr(const Repr& rhs) {
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
new (&shared_) SharedPtrWrapper(rhs.shared_);
|
||||
} else {
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Repr(Repr&& rhs) noexcept {
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
|
||||
} else {
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Repr& operator=(const Repr& rhs) {
|
||||
if (&rhs == this) {
|
||||
return *this;
|
||||
}
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
if (isSharedAndNonNull()) {
|
||||
shared_ = rhs.shared_;
|
||||
} else {
|
||||
new (&shared_) SharedPtrWrapper(rhs.shared_);
|
||||
}
|
||||
} else {
|
||||
if (isSharedAndNonNull()) {
|
||||
destroy();
|
||||
}
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Repr& operator=(Repr&& rhs) noexcept {
|
||||
if (&rhs == this) {
|
||||
return *this;
|
||||
}
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
if (isSharedAndNonNull()) {
|
||||
shared_ = std::move(rhs.shared_);
|
||||
} else {
|
||||
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
|
||||
}
|
||||
} else {
|
||||
if (isSharedAndNonNull()) {
|
||||
destroy();
|
||||
}
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
SharedPtrWrapper shared_;
|
||||
|
||||
struct SingletonRepr {
|
||||
explicit SingletonRepr(T* s) : singleton_(s) {}
|
||||
T* singleton_;
|
||||
void* unused_ = nullptr;
|
||||
} singletonRepr_;
|
||||
struct RawRepr {
|
||||
void* first;
|
||||
void* nullIfSingleton_;
|
||||
};
|
||||
|
||||
// It is UB to read the singleton part of Repr if it was
|
||||
// constructed as a shared_ptr and vice versa, but memcpying out
|
||||
// the representation is always OK, so here's an accessor to obey
|
||||
// the letter of the law.
|
||||
RawRepr rawRepr() const {
|
||||
RawRepr repr{};
|
||||
memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
|
||||
return repr;
|
||||
}
|
||||
|
||||
bool isNonNull() const {
|
||||
auto repr = rawRepr();
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
|
||||
return repr.first != nullptr;
|
||||
}
|
||||
|
||||
bool isSharedAndNonNull() const {
|
||||
return rawRepr().nullIfSingleton_ != nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
void destroy() {
|
||||
if (isSharedAndNonNull()) {
|
||||
// Without SharedPtrWrapper, this line would read
|
||||
// `shared_.~shared_ptr()` and nvcc would complain with
|
||||
// "error: expected primary-expression before '>' token"
|
||||
// referring to the "t" in "shared_ptr". SharedPtrWrapper
|
||||
// exists to work around this compiler bug.
|
||||
shared_.~SharedPtrWrapper();
|
||||
}
|
||||
}
|
||||
} repr_;
|
||||
};
|
||||
|
||||
using TypePtr = SingletonOrSharedTypePtr<Type>;
|
||||
|
||||
@ -8,7 +8,6 @@
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/env.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <array>
|
||||
@ -827,7 +826,9 @@ TupleType::TupleType(
|
||||
: NamedType(TypeKind::TupleType, std::move(name)),
|
||||
elements_(std::move(elements)),
|
||||
has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
|
||||
TORCH_CHECK(v, "Can not create tuple with None type");
|
||||
if (!v) {
|
||||
throw std::runtime_error("Can not create tuple with None type");
|
||||
}
|
||||
return v->hasFreeVariables();
|
||||
})), schema_(std::move(schema)) {
|
||||
|
||||
|
||||
@ -104,6 +104,71 @@ class Vectorized<float> {
|
||||
}
|
||||
return b;
|
||||
}
|
||||
// Implementation is picked from
|
||||
// https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
|
||||
inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
|
||||
const auto c1 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
|
||||
const auto c2 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
|
||||
const auto c3 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
|
||||
const auto c4 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
|
||||
const auto c5 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
|
||||
const auto shift = svreinterpret_f32_u32(
|
||||
svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
|
||||
const auto inv_ln2 = svreinterpret_f32_u32(
|
||||
svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
|
||||
const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
|
||||
0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
|
||||
const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
|
||||
0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
|
||||
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
|
||||
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
|
||||
const auto zero = svdup_n_f32(0.f);
|
||||
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
|
||||
// Range reduction:
|
||||
// e^x = 2^n * e^r
|
||||
// where:
|
||||
// n = floor(x / ln(2))
|
||||
// r = x - n * ln(2)
|
||||
//
|
||||
// By adding x / ln(2) with 2^23 + 127 (shift):
|
||||
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
|
||||
// forces decimal part
|
||||
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e.
|
||||
// n) + 127 will occupy the whole fraction part of z in FP32 format.
|
||||
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
|
||||
// of x / ln(2) (i.e. n) because the decimal part has been pushed out
|
||||
// and lost.
|
||||
// * The addition of 127 makes the FP32 fraction part of z ready to be
|
||||
// used as the exponent
|
||||
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
|
||||
const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
|
||||
const auto n = svsub_f32_z(pg, z, shift);
|
||||
const auto scale = svreinterpret_f32_u32(
|
||||
svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
|
||||
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
|
||||
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in
|
||||
// term of accuracy and performance.
|
||||
const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
|
||||
const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
|
||||
// Compute the truncated Taylor series of e^r.
|
||||
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
|
||||
const auto r2 = svmul_f32_z(pg, r, r);
|
||||
const auto p1 = svmul_f32_z(pg, c1, r);
|
||||
const auto p23 = svmla_f32_z(pg, c2, c3, r);
|
||||
const auto p45 = svmla_f32_z(pg, c4, c5, r);
|
||||
const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
|
||||
const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
|
||||
auto poly = svmla_f32_z(pg, scale, p12345, scale);
|
||||
// Handle underflow and overflow.
|
||||
poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
|
||||
poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
|
||||
return poly;
|
||||
}
|
||||
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
|
||||
if (count == size())
|
||||
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
|
||||
@ -248,41 +313,11 @@ class Vectorized<float> {
|
||||
return USE_SLEEF(
|
||||
Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
|
||||
}
|
||||
// Implementation copied from Arm Optimized Routines:
|
||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
|
||||
Vectorized<float> exp_u20() const {
|
||||
// special case to handle special inputs that are too large or too small
|
||||
// i.e. where there's at least one element x, s.t. |x| >= 87.3...
|
||||
svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f);
|
||||
if (svptest_any(svptrue_b32(), is_special_case)) {
|
||||
return exp();
|
||||
}
|
||||
const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
|
||||
const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
|
||||
const svfloat32_t c1 = svdup_n_f32(0.5f);
|
||||
const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
|
||||
|
||||
const float shift = 0x1.803f8p17f;
|
||||
|
||||
/* n = round(x/(ln2/N)). */
|
||||
svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift);
|
||||
svfloat32_t n = svsub_x(svptrue_b32(), z, shift);
|
||||
|
||||
/* r = x - n*ln2/N. */
|
||||
svfloat32_t r = values;
|
||||
r = svmls_x(svptrue_b32(), r, n, ln2_hi);
|
||||
r = svmls_x(svptrue_b32(), r, n, ln2_lo);
|
||||
|
||||
/* scale = 2^(n/N). */
|
||||
svfloat32_t scale = svexpa(svreinterpret_u32(z));
|
||||
|
||||
/* poly(r) = exp(r) - 1 ~= r + 0.5 r^2. */
|
||||
svfloat32_t r2 = svmul_x(svptrue_b32(), r, r);
|
||||
svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
|
||||
return svmla_x(svptrue_b32(), scale, scale, poly);
|
||||
return exp();
|
||||
}
|
||||
Vectorized<float> fexp_u20() const {
|
||||
return exp_u20();
|
||||
return exp();
|
||||
}
|
||||
Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
|
||||
{ return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
|
||||
@ -418,11 +453,9 @@ class Vectorized<float> {
|
||||
ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
|
||||
|
||||
// Step 2: Calculate exp(2 * x), where x is the clamped value.
|
||||
// svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of
|
||||
// the result (via Vectorized<float>, then auto-converts back to
|
||||
// svfloat32_t).
|
||||
svfloat32_t exp2x =
|
||||
Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20();
|
||||
// svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
|
||||
// the result.
|
||||
svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
|
||||
|
||||
// Step 3: Calculate the numerator of the tanh function, which is exp(2x)
|
||||
// - 1.
|
||||
|
||||
@ -6,11 +6,9 @@
|
||||
#ifdef __aarch64__
|
||||
#if !defined(CPU_CAPABILITY_SVE)
|
||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
|
||||
|
||||
@ -354,47 +354,9 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
|
||||
Vectorized frac() const;
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
Vectorized<c10::BFloat16> neg() const {
|
||||
return -values;
|
||||
}
|
||||
Vectorized<c10::BFloat16> reciprocal() const {
|
||||
return 1.0f / values;
|
||||
}
|
||||
Vectorized<c10::BFloat16> operator==(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values == other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator!=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values != other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator<(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values < other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator<=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values <= other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator>(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values > other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator>=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values >= other.values;
|
||||
}
|
||||
#else
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
|
||||
@ -402,7 +364,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
|
||||
#endif
|
||||
|
||||
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||
@ -451,52 +412,28 @@ template <>
|
||||
Vectorized<c10::BFloat16> inline operator+(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x + y;
|
||||
#else
|
||||
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator-(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x - y;
|
||||
#else
|
||||
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator*(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x * y;
|
||||
#else
|
||||
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator/(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x / y;
|
||||
#else
|
||||
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
// frac. Implement this here so we can use subtraction
|
||||
@ -607,19 +544,12 @@ Vectorized<c10::BFloat16> inline fmadd(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return x * y + z;
|
||||
#else
|
||||
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
|
||||
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
|
||||
// elements, not the bottom and top half, so they don't seem
|
||||
// particularly useful here. Ideally we would include dot product in
|
||||
// the Vectorized interface...
|
||||
return a * b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -627,15 +557,8 @@ Vectorized<c10::BFloat16> inline fnmadd(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return (-x) * y + z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return -a * b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -643,15 +566,8 @@ Vectorized<c10::BFloat16> inline fmsub(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return x * y - z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return a * b - c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -659,15 +575,8 @@ Vectorized<c10::BFloat16> inline fnmsub(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return (-x) * y - z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return -a * b - c;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
|
||||
|
||||
@ -5,114 +5,6 @@
|
||||
namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||
|
||||
// Enable auto-vectorization for GCC-13+ and clang-17+
|
||||
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
|
||||
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
|
||||
|
||||
template <typename from_type, typename to_type>
|
||||
inline void convertImpl(
|
||||
const from_type* __restrict src,
|
||||
to_type* __restrict dst,
|
||||
int64_t n) {
|
||||
uint64_t len = static_cast<uint64_t>(n);
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
dst[i] = static_cast<to_type>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#define CONVERT_TEMPLATE(from_type, to_type) \
|
||||
template <> \
|
||||
inline void convert(const from_type* src, to_type* dst, int64_t n) { \
|
||||
return convertImpl<from_type, to_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_TEMPLATE(uint8_t, uint8_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int8_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int16_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int32_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int64_t)
|
||||
CONVERT_TEMPLATE(uint8_t, float)
|
||||
CONVERT_TEMPLATE(uint8_t, double)
|
||||
CONVERT_TEMPLATE(int8_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int8_t, int8_t)
|
||||
CONVERT_TEMPLATE(int8_t, int16_t)
|
||||
CONVERT_TEMPLATE(int8_t, int32_t)
|
||||
CONVERT_TEMPLATE(int8_t, int64_t)
|
||||
CONVERT_TEMPLATE(int8_t, float)
|
||||
CONVERT_TEMPLATE(int8_t, double)
|
||||
CONVERT_TEMPLATE(int16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int16_t, int8_t)
|
||||
CONVERT_TEMPLATE(int16_t, int16_t)
|
||||
CONVERT_TEMPLATE(int16_t, int32_t)
|
||||
CONVERT_TEMPLATE(int16_t, int64_t)
|
||||
CONVERT_TEMPLATE(int16_t, float)
|
||||
CONVERT_TEMPLATE(int16_t, double)
|
||||
CONVERT_TEMPLATE(int32_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int32_t, int8_t)
|
||||
CONVERT_TEMPLATE(int32_t, int16_t)
|
||||
CONVERT_TEMPLATE(int32_t, int32_t)
|
||||
CONVERT_TEMPLATE(int32_t, int64_t)
|
||||
CONVERT_TEMPLATE(int32_t, float)
|
||||
CONVERT_TEMPLATE(int32_t, double)
|
||||
CONVERT_TEMPLATE(int64_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int64_t, int8_t)
|
||||
CONVERT_TEMPLATE(int64_t, int16_t)
|
||||
CONVERT_TEMPLATE(int64_t, int32_t)
|
||||
CONVERT_TEMPLATE(int64_t, int64_t)
|
||||
CONVERT_TEMPLATE(int64_t, float)
|
||||
CONVERT_TEMPLATE(int64_t, double)
|
||||
CONVERT_TEMPLATE(float, uint8_t)
|
||||
CONVERT_TEMPLATE(float, int8_t)
|
||||
CONVERT_TEMPLATE(float, int16_t)
|
||||
CONVERT_TEMPLATE(float, int32_t)
|
||||
CONVERT_TEMPLATE(float, int64_t)
|
||||
CONVERT_TEMPLATE(float, float)
|
||||
CONVERT_TEMPLATE(float, double)
|
||||
CONVERT_TEMPLATE(double, uint8_t)
|
||||
CONVERT_TEMPLATE(double, int8_t)
|
||||
CONVERT_TEMPLATE(double, int16_t)
|
||||
CONVERT_TEMPLATE(double, int32_t)
|
||||
CONVERT_TEMPLATE(double, int64_t)
|
||||
CONVERT_TEMPLATE(double, float)
|
||||
CONVERT_TEMPLATE(double, double)
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
CONVERT_TEMPLATE(float16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(float16_t, int8_t)
|
||||
CONVERT_TEMPLATE(float16_t, int16_t)
|
||||
CONVERT_TEMPLATE(float16_t, int32_t)
|
||||
CONVERT_TEMPLATE(float16_t, int64_t)
|
||||
CONVERT_TEMPLATE(float16_t, float16_t)
|
||||
CONVERT_TEMPLATE(float16_t, float)
|
||||
CONVERT_TEMPLATE(float16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, float16_t)
|
||||
CONVERT_TEMPLATE(int8_t, float16_t)
|
||||
CONVERT_TEMPLATE(int16_t, float16_t)
|
||||
CONVERT_TEMPLATE(int32_t, float16_t)
|
||||
CONVERT_TEMPLATE(int64_t, float16_t)
|
||||
CONVERT_TEMPLATE(float, float16_t)
|
||||
CONVERT_TEMPLATE(double, float16_t)
|
||||
#endif
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, float)
|
||||
CONVERT_TEMPLATE(bfloat16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(float, bfloat16_t)
|
||||
CONVERT_TEMPLATE(double, bfloat16_t)
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
||||
@ -1,586 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <cmath>
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
template <>
|
||||
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
|
||||
|
||||
template <>
|
||||
class Vectorized<double> {
|
||||
private:
|
||||
float64x2_t values;
|
||||
|
||||
public:
|
||||
using value_type = double;
|
||||
using size_type = int;
|
||||
static constexpr size_type size() {
|
||||
return 2;
|
||||
}
|
||||
Vectorized() {
|
||||
values = vdupq_n_f64(0.0);
|
||||
}
|
||||
Vectorized(float64x2_t v) : values(v) {}
|
||||
Vectorized(double val) {
|
||||
values = vdupq_n_f64(val);
|
||||
}
|
||||
template <
|
||||
typename... Args,
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>>
|
||||
Vectorized(Args... vals) {
|
||||
__at_align__ double buffer[size()] = {vals...};
|
||||
values = vld1q_f64(buffer);
|
||||
}
|
||||
operator float64x2_t() const {
|
||||
return values;
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<double> blend(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint64x2_t maskArray = {
|
||||
(mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0,
|
||||
(mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_f64(maskArray, b.values, a.values);
|
||||
}
|
||||
static Vectorized<double> blendv(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& mask_) {
|
||||
return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values);
|
||||
}
|
||||
template <typename step_t>
|
||||
static Vectorized<double> arange(
|
||||
double base = 0.,
|
||||
step_t step = static_cast<step_t>(1)) {
|
||||
return {base, base + static_cast<double>(step)};
|
||||
}
|
||||
static inline Vectorized<double> set(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
int64_t count = size()) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 2) {
|
||||
return b;
|
||||
} else {
|
||||
float64x2_t c = {b.values[0], a.values[1]};
|
||||
return c;
|
||||
}
|
||||
}
|
||||
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
|
||||
if (count == size()) {
|
||||
return vld1q_f64(reinterpret_cast<const double*>(ptr));
|
||||
} else if (count == 1) {
|
||||
float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr));
|
||||
float64x1_t z = {0.0};
|
||||
return vcombine_f64(x, z);
|
||||
} else {
|
||||
return vdupq_n_f64(0.0);
|
||||
}
|
||||
}
|
||||
void store(void* ptr, int64_t count = size()) const {
|
||||
if (count == size()) {
|
||||
vst1q_f64(reinterpret_cast<double*>(ptr), values);
|
||||
} else if (count == 1) {
|
||||
vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values));
|
||||
}
|
||||
}
|
||||
const double& operator[](int idx) const = delete;
|
||||
double& operator[](int idx) = delete;
|
||||
int64_t zero_mask() const {
|
||||
// returns an integer mask where all zero elements are translated to 1-bit
|
||||
// and others are translated to 0-bit
|
||||
uint64x2_t cmpReg = vceqzq_f64(values);
|
||||
uint64x2_t mask = {1, 2};
|
||||
uint64x2_t res = vandq_u64(cmpReg, mask);
|
||||
return res[0] | res[1];
|
||||
}
|
||||
Vectorized<double> isnan() const {
|
||||
// NaN check
|
||||
return vreinterpretq_f64_u32(
|
||||
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values))));
|
||||
}
|
||||
bool has_inf_nan() const {
|
||||
Vectorized<double> x = vsubq_f64(values, values);
|
||||
float64x2_t r = x.isnan();
|
||||
uint64x2_t u = vreinterpretq_u64_f64(r);
|
||||
return u[0] | u[1];
|
||||
}
|
||||
Vectorized<double> map(double (*f)(double)) const {
|
||||
float64x2_t result;
|
||||
result[0] = f(values[0]);
|
||||
result[1] = f(values[1]);
|
||||
return result;
|
||||
}
|
||||
Vectorized<double> map2(
|
||||
const Vectorized<double>& second,
|
||||
double (*const f)(double, double)) const {
|
||||
float64x2_t result;
|
||||
result[0] = f(values[0], second.values[0]);
|
||||
result[1] = f(values[1], second.values[1]);
|
||||
return result;
|
||||
}
|
||||
Vectorized<double> abs() const {
|
||||
return vabsq_f64(values);
|
||||
}
|
||||
Vectorized<double> angle() const {
|
||||
auto zero = Vectorized<double>(0.0);
|
||||
auto pi = Vectorized<double>(c10::pi<double>);
|
||||
auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values)));
|
||||
return blendv(tmp, *this, isnan());
|
||||
}
|
||||
Vectorized<double> real() const {
|
||||
return *this;
|
||||
}
|
||||
Vectorized<double> imag() const {
|
||||
return Vectorized<double>(0.0);
|
||||
}
|
||||
Vectorized<double> conj() const {
|
||||
return *this;
|
||||
}
|
||||
Vectorized<double> acos() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos));
|
||||
}
|
||||
Vectorized<double> acosh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh));
|
||||
}
|
||||
Vectorized<double> asin() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin));
|
||||
}
|
||||
Vectorized<double> asinh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh));
|
||||
}
|
||||
Vectorized<double> atan() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan));
|
||||
}
|
||||
Vectorized<double> atanh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh));
|
||||
}
|
||||
Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_atan2d2_u10(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> copysign(const Vectorized<double>& sign) const {
|
||||
USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_copysignd2(values, sign)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_sign[size()];
|
||||
store(tmp);
|
||||
sign.store(tmp_sign);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> erf() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf));
|
||||
}
|
||||
Vectorized<double> erfc() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc));
|
||||
}
|
||||
Vectorized<double> exp() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp));
|
||||
}
|
||||
Vectorized<double> exp2() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2));
|
||||
}
|
||||
Vectorized<double> expm1() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1));
|
||||
}
|
||||
Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_fmodd2(values, q)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_q[size()];
|
||||
store(tmp);
|
||||
q.store(tmp_q);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> hypot(const Vectorized<double>& b) const {
|
||||
USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_hypotd2_u05(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> i0() const {
|
||||
return map(calc_i0);
|
||||
}
|
||||
Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_nextafterd2(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); ++i) {
|
||||
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> log() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_logd2_u10(values)), map(std::log));
|
||||
}
|
||||
Vectorized<double> log2() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2));
|
||||
}
|
||||
Vectorized<double> log10() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10));
|
||||
}
|
||||
Vectorized<double> log1p() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p));
|
||||
}
|
||||
Vectorized<double> frac() const;
|
||||
Vectorized<double> sin() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin));
|
||||
}
|
||||
Vectorized<double> sinh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh));
|
||||
}
|
||||
Vectorized<double> cos() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos));
|
||||
}
|
||||
Vectorized<double> cosh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh));
|
||||
}
|
||||
Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_powd2_u10(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::pow(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} // Comparison using the _CMP_**_OQ predicate.
|
||||
// `O`: get false if an operand is NaN
|
||||
// `Q`: do not raise if an operand is NaN
|
||||
Vectorized<double> tan() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan));
|
||||
}
|
||||
Vectorized<double> tanh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh));
|
||||
}
|
||||
Vectorized<double> lgamma() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma));
|
||||
}
|
||||
Vectorized<double> erfinv() const {
|
||||
return map(calc_erfinv);
|
||||
}
|
||||
Vectorized<double> exp_u20() const {
|
||||
return exp();
|
||||
}
|
||||
Vectorized<double> fexp_u20() const {
|
||||
return exp();
|
||||
}
|
||||
Vectorized<double> i0e() const {
|
||||
return map(calc_i0e);
|
||||
}
|
||||
Vectorized<double> digamma() const {
|
||||
return map(calc_digamma);
|
||||
}
|
||||
Vectorized<double> igamma(const Vectorized<double>& x) const {
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_x[size()];
|
||||
store(tmp);
|
||||
x.store(tmp_x);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
}
|
||||
Vectorized<double> igammac(const Vectorized<double>& x) const {
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_x[size()];
|
||||
store(tmp);
|
||||
x.store(tmp_x);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
}
|
||||
Vectorized<double> ceil() const {
|
||||
return vrndpq_f64(values);
|
||||
}
|
||||
Vectorized<double> floor() const {
|
||||
return vrndmq_f64(values);
|
||||
}
|
||||
Vectorized<double> neg() const {
|
||||
return vnegq_f64(values);
|
||||
}
|
||||
Vectorized<double> round() const {
|
||||
return vrndiq_f64(values);
|
||||
}
|
||||
Vectorized<double> trunc() const {
|
||||
return vrndq_f64(values);
|
||||
}
|
||||
Vectorized<double> sqrt() const {
|
||||
return vsqrtq_f64(values);
|
||||
}
|
||||
Vectorized<double> reciprocal() const {
|
||||
return vdivq_f64(vdupq_n_f64(1.0), values);
|
||||
}
|
||||
Vectorized<double> rsqrt() const {
|
||||
return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values));
|
||||
}
|
||||
double reduce_add() const {
|
||||
return vaddvq_f64(values);
|
||||
}
|
||||
double reduce_max() const {
|
||||
return vmaxvq_f64(values);
|
||||
}
|
||||
Vectorized<double> operator==(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vceqq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator!=(const Vectorized<double>& other) const {
|
||||
float64x2_t r0 = vreinterpretq_f64_u32(
|
||||
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values))));
|
||||
return Vectorized<double>(r0);
|
||||
}
|
||||
|
||||
Vectorized<double> operator<(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcltq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator<=(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcleq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator>(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcgtq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator>=(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcgeq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> eq(const Vectorized<double>& other) const;
|
||||
Vectorized<double> ne(const Vectorized<double>& other) const;
|
||||
Vectorized<double> gt(const Vectorized<double>& other) const;
|
||||
Vectorized<double> ge(const Vectorized<double>& other) const;
|
||||
Vectorized<double> lt(const Vectorized<double>& other) const;
|
||||
Vectorized<double> le(const Vectorized<double>& other) const;
|
||||
};
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator+(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vaddq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator-(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vsubq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator*(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vmulq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator/(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vdivq_f64(a, b);
|
||||
}
|
||||
|
||||
// frac. Implement this here so we can use subtraction
|
||||
Vectorized<double> inline Vectorized<double>::frac() const {
|
||||
return *this - this->trunc();
|
||||
}
|
||||
|
||||
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
|
||||
// either input is a NaN.
|
||||
template <>
|
||||
Vectorized<double> inline maximum(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vmaxq_f64(a, b);
|
||||
}
|
||||
|
||||
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
|
||||
// either input is a NaN.
|
||||
template <>
|
||||
Vectorized<double> inline minimum(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vminq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline clamp(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& min,
|
||||
const Vectorized<double>& max) {
|
||||
return vminq_f64(max, vmaxq_f64(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline clamp_max(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& max) {
|
||||
return vminq_f64(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline clamp_min(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& min) {
|
||||
return vmaxq_f64(min, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator&(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vreinterpretq_f64_u64(
|
||||
vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator|(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vreinterpretq_f64_u64(
|
||||
vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator^(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vreinterpretq_f64_u64(
|
||||
veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::eq(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this == other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::ne(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this != other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::gt(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this > other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::ge(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this >= other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::lt(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this < other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::le(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this <= other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fmadd(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmaq_f64(c, a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fnmadd(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmsq_f64(c, a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fmsub(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmaq_f64(vnegq_f64(c), a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fnmsub(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmsq_f64(vnegq_f64(c), a, b);
|
||||
}
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
@ -307,49 +307,11 @@ class Vectorized<float> {
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
|
||||
// Implementation copied from Arm Optimized Routine
|
||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
|
||||
Vectorized<float> exp_u20() const {
|
||||
// bail out to sleef if it's a special case:
|
||||
// i.e. there's an input s.t. |input| > 87.3....
|
||||
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
|
||||
uint32x4_t cmp = vcagtq_f32(values, special_bound);
|
||||
if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) {
|
||||
return exp();
|
||||
}
|
||||
|
||||
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
|
||||
const float ln2_hi = 0x1.62e4p-1f;
|
||||
const float ln2_lo = 0x1.7f7d1cp-20f;
|
||||
const float c0 = 0x1.0e4020p-7f;
|
||||
const float c2 = 0x1.555e66p-3f;
|
||||
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
|
||||
|
||||
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
|
||||
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
|
||||
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
|
||||
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
|
||||
|
||||
/* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
|
||||
x = ln2*n + r, with r in [-ln2/2, ln2/2]. */
|
||||
|
||||
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
|
||||
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0);
|
||||
r = vfmsq_laneq_f32(r, n, ln2_c02, 1);
|
||||
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
|
||||
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias));
|
||||
|
||||
float32x4_t r2 = vmulq_f32(r, r);
|
||||
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2);
|
||||
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3);
|
||||
q = vfmaq_f32(q, p, r2);
|
||||
p = vmulq_f32(c4, r);
|
||||
float32x4_t poly = vfmaq_f32(p, q, r2);
|
||||
|
||||
return vfmaq_f32(scale, poly, scale);
|
||||
return exp();
|
||||
}
|
||||
Vectorized<float> fexp_u20() const {
|
||||
return exp_u20();
|
||||
return exp();
|
||||
}
|
||||
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
|
||||
fmod,
|
||||
@ -578,6 +540,42 @@ inline Vectorized<float> Vectorized<float>::le(
|
||||
return (*this <= other) & Vectorized<float>(1.0f);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const float* src, int32_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||
i += Vectorized<float>::size()) {
|
||||
vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<int32_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const int32_t* src, float* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||
i += Vectorized<float>::size()) {
|
||||
vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<float>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<float> inline fmadd(
|
||||
const Vectorized<float>& a,
|
||||
|
||||
@ -569,6 +569,46 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
|
||||
return (*this <= other) & Vectorized<c10::Half>(1);
|
||||
}
|
||||
|
||||
// These are global functions, so the defaults in vec_base.h should
|
||||
// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available.
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
template <>
|
||||
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||
i += Vectorized<c10::Half>::size()) {
|
||||
vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<int16_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||
i += Vectorized<c10::Half>::size()) {
|
||||
vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<float16_t>(src[i]);
|
||||
}
|
||||
}
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template <>
|
||||
Vectorized<c10::Half> inline fmadd(
|
||||
const Vectorized<c10::Half>& a,
|
||||
|
||||
@ -1,378 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
|
||||
template <> \
|
||||
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
|
||||
\
|
||||
template <> \
|
||||
class Vectorized<uint##bit##_t> { \
|
||||
using neon_type = uint##bit##x##vl##_t; \
|
||||
\
|
||||
private: \
|
||||
neon_type values; \
|
||||
\
|
||||
public: \
|
||||
using value_type = uint##bit##_t; \
|
||||
using size_type = int; \
|
||||
static constexpr size_type size() { \
|
||||
return vl; \
|
||||
} \
|
||||
Vectorized() { \
|
||||
values = vdupq_n_u##bit(0); \
|
||||
} \
|
||||
Vectorized(neon_type v) : values(v) {} \
|
||||
Vectorized(uint##bit##_t val); \
|
||||
template < \
|
||||
typename... Args, \
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
|
||||
Vectorized(Args... vals) { \
|
||||
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
|
||||
values = vld1q_u##bit(buffer); \
|
||||
} \
|
||||
operator neon_type() const { \
|
||||
return values; \
|
||||
} \
|
||||
static Vectorized<uint##bit##_t> loadu( \
|
||||
const void* ptr, \
|
||||
uint64_t count = size()); \
|
||||
void store(void* ptr, uint64_t count = size()) const; \
|
||||
template <uint64_t mask> \
|
||||
static Vectorized<uint##bit##_t> blend( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b); \
|
||||
static Vectorized<uint##bit##_t> blendv( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b, \
|
||||
const Vectorized<uint##bit##_t>& mask_) { \
|
||||
return vbslq_u##bit(mask_.values, b, a); \
|
||||
} \
|
||||
template <typename step_t> \
|
||||
static Vectorized<uint##bit##_t> arange( \
|
||||
value_type base = 0, \
|
||||
step_t step = static_cast<step_t>(1)); \
|
||||
static Vectorized<uint##bit##_t> set( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b, \
|
||||
uint64_t count = size()); \
|
||||
const uint##bit##_t& operator[](uint idx) const = delete; \
|
||||
uint##bit##_t& operator[](uint idx) = delete; \
|
||||
Vectorized<uint##bit##_t> abs() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> real() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> imag() const { \
|
||||
return vdupq_n_u##bit(0); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> conj() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> neg() const { \
|
||||
return vreinterpretq_u##bit##_s##bit( \
|
||||
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
|
||||
} \
|
||||
uint##bit##_t reduce_add() const { \
|
||||
return vaddvq_u##bit(values); \
|
||||
} \
|
||||
uint##bit##_t reduce_max() const; \
|
||||
Vectorized<uint##bit##_t> operator==( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator!=( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> operator<( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator<=( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator>( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator>=( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> eq( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> ne( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> gt( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> ge( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> lt( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> le( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
}; \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator+( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vaddq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator-( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vsubq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator&( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vandq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator|( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vorrq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator^( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return veorq_u##bit(a, b); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this == other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this != other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this > other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this < other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
|
||||
}
|
||||
|
||||
VEC_UINT_NEON_TEMPLATE(16, 8)
|
||||
|
||||
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
|
||||
return vmaxvq_u8(values);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator*(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vmulq_u8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
|
||||
return vmvnq_u8(a);
|
||||
}
|
||||
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
|
||||
const Vectorized<uint8_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline minimum(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vminq_u8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline maximum(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vmaxq_u8(a, b);
|
||||
}
|
||||
|
||||
template <uint64_t mask>
|
||||
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
(mask & 1LL) ? 0xFF : 0,
|
||||
(mask & 2LL) ? 0xFF : 0,
|
||||
(mask & 4LL) ? 0xFF : 0,
|
||||
(mask & 8LL) ? 0xFF : 0,
|
||||
(mask & 16LL) ? 0xFF : 0,
|
||||
(mask & 32LL) ? 0xFF : 0,
|
||||
(mask & 64LL) ? 0xFF : 0,
|
||||
(mask & 128LL) ? 0xFF : 0,
|
||||
(mask & 256LL) ? 0xFF : 0,
|
||||
(mask & 512LL) ? 0xFF : 0,
|
||||
(mask & 1024LL) ? 0xFF : 0,
|
||||
(mask & 2048LL) ? 0xFF : 0,
|
||||
(mask & 4096LL) ? 0xFF : 0,
|
||||
(mask & 8192LL) ? 0xFF : 0,
|
||||
(mask & 16384LL) ? 0xFF : 0,
|
||||
(mask & 32768LL) ? 0xFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_u8(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
#define VEC_UINT_NEON_OPS(vl, bit) \
|
||||
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
|
||||
values = vdupq_n_u##bit(val); \
|
||||
} \
|
||||
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
|
||||
const void* ptr, uint64_t count) { \
|
||||
if (count == size()) { \
|
||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
|
||||
} else { \
|
||||
__at_align__ uint##bit##_t tmp_values[size()]; \
|
||||
for (const auto i : c10::irange(size())) { \
|
||||
tmp_values[i] = 0; \
|
||||
} \
|
||||
std::memcpy( \
|
||||
tmp_values, \
|
||||
reinterpret_cast<const uint##bit##_t*>(ptr), \
|
||||
count * sizeof(uint##bit##_t)); \
|
||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
|
||||
} \
|
||||
} \
|
||||
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
|
||||
const { \
|
||||
if (count == size()) { \
|
||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
|
||||
} else { \
|
||||
uint##bit##_t tmp_values[size()]; \
|
||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
|
||||
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
|
||||
} \
|
||||
}
|
||||
|
||||
VEC_UINT_NEON_OPS(16, 8)
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
|
||||
uint8_t base,
|
||||
step_t step) {
|
||||
const Vectorized<uint8_t> base_vec(base);
|
||||
const Vectorized<uint8_t> step_vec(step);
|
||||
const uint8x16_t step_sizes = {
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
||||
return vmlaq_u8(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator>>(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t x = a;
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
uint8x16_t z = vminq_u8(b, bound);
|
||||
return x >> z;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator<<(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
uint8x16_t z = vminq_u8(b, bound);
|
||||
return vshlq_u8(a, vreinterpretq_s8_u8(z));
|
||||
}
|
||||
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b,
|
||||
uint64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 16) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
|
||||
0};
|
||||
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_u8(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator/(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t x = a;
|
||||
uint8x16_t y = b;
|
||||
return x / y;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& min,
|
||||
const Vectorized<uint8_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp_max(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp_min(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||
|
||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||
at::vec::Vectorized<uint8_t> src) {
|
||||
auto u8x8 = vget_low_u8(src);
|
||||
auto u8x8 = vld1_u8(src.operator const uint8_t*());
|
||||
auto u16x8 = vmovl_u8(u8x8);
|
||||
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
|
||||
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
||||
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
|
||||
|
||||
Vectorized<float> inline convert_int8_half_register_to_float(
|
||||
at::vec::Vectorized<uint8_t> src) {
|
||||
auto u8x8 = vget_low_u8(src);
|
||||
auto u8x8 = vld1_u8(src.operator const uint8_t*());
|
||||
auto u16x8 = vmovl_u8(u8x8);
|
||||
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
||||
|
||||
|
||||
@ -1,192 +0,0 @@
|
||||
#include <ATen/cuda/CUDAGreenContext.h>
|
||||
|
||||
namespace at::cuda {
|
||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int driver_version;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||
TORCH_CHECK(
|
||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||
CUcontext pctx = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
TORCH_WARN(
|
||||
"Attempted to create a green context but"
|
||||
" there was no primary context! Creating a primary context...");
|
||||
|
||||
cudaFree(0);
|
||||
}
|
||||
|
||||
CUdevice device;
|
||||
device_id_ = device_id;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||
|
||||
// Get device resources
|
||||
CUdevResource device_resource;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||
|
||||
// Split resources
|
||||
std::vector<CUdevResource> result(1);
|
||||
auto result_data = result.data();
|
||||
unsigned int nb_groups = 1;
|
||||
CUdevResource remaining;
|
||||
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||
result_data,
|
||||
&nb_groups,
|
||||
&device_resource,
|
||||
&remaining,
|
||||
0, // default flags
|
||||
num_sms));
|
||||
|
||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||
|
||||
// Generate resource descriptor
|
||||
CUdevResourceDesc desc;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||
&desc, result_data, 1));
|
||||
|
||||
// Create green context
|
||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||
|
||||
// Convert to regular context
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::unique_ptr<GreenContext> GreenContext::create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
if (!device_id.has_value()) {
|
||||
device_id = at::cuda::current_device();
|
||||
}
|
||||
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Implement move operations
|
||||
GreenContext::GreenContext(GreenContext&& other) noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
if (this != &other) {
|
||||
// Clean up current resources
|
||||
if (green_ctx_) {
|
||||
CUcontext current = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
||||
if (current == context_) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"attempting to overwrite current green ctx "
|
||||
"when it is active!");
|
||||
}
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||
}
|
||||
|
||||
// Take ownership of other's resources
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
||||
}
|
||||
return *this;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
GreenContext::~GreenContext() noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext GreenContext::getContext() const {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
return context_;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx GreenContext::getGreenContext() const {
|
||||
return green_ctx_;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void GreenContext::setContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
||||
parent_stream_ = current_stream.stream();
|
||||
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(current_stream);
|
||||
|
||||
CUcontext current = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
||||
if (!current) {
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
|
||||
} else {
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
|
||||
}
|
||||
// currently hardcodes the new green context to use the default stream
|
||||
// TODO(eqy): consider creating a new stream if e.g., it allows interop
|
||||
// with CUDA Graph captures etc.
|
||||
auto default_stream = c10::cuda::getDefaultCUDAStream();
|
||||
ev.block(default_stream);
|
||||
c10::cuda::setCurrentCUDAStream(default_stream);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
void GreenContext::popContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
// see above note about stream being hardcoded to the default stream
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(c10::cuda::getCurrentCUDAStream());
|
||||
CUcontext popped;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
popped == context_, "expected popped context to be the current ctx");
|
||||
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
} // namespace at::cuda
|
||||
@ -1,53 +0,0 @@
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define CUDA_HAS_GREEN_CONTEXT 1
|
||||
#else
|
||||
#define CUDA_HAS_GREEN_CONTEXT 0
|
||||
#endif
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
class TORCH_CUDA_CPP_API GreenContext {
|
||||
public:
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
|
||||
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
||||
|
||||
// Delete copy constructor and assignment
|
||||
GreenContext(const GreenContext&) = delete;
|
||||
GreenContext& operator=(const GreenContext&) = delete;
|
||||
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
~GreenContext() noexcept;
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext getContext() const;
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx getGreenContext() const;
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void setContext();
|
||||
|
||||
void popContext();
|
||||
|
||||
private:
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int32_t device_id_ = -1;
|
||||
CUgreenCtx green_ctx_ = nullptr;
|
||||
CUcontext context_ = nullptr;
|
||||
cudaStream_t parent_stream_ = nullptr;
|
||||
#endif
|
||||
};
|
||||
} // namespace at::cuda
|
||||
@ -1,270 +0,0 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace at::cuda::scaled {
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs a single scale, {Tensorwise (float)}
|
||||
*/
|
||||
bool check_tensorwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs scales, {Rowwise (float)}
|
||||
*/
|
||||
bool check_rowwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {RowWise, dp32} for A & B
|
||||
if (recipe_a[0] != ScalingType::RowWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::RowWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Two-level scaling, canonical NVFP4
|
||||
* Both inputs must be fp4
|
||||
* A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
|
||||
*/
|
||||
bool check_nvfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Single-level scaling, what PyT currently understands
|
||||
* Both inputs must be fp4
|
||||
* A, B need 1 scale, {Blockwise_1x16 (e4m3)}
|
||||
*/
|
||||
bool check_nvfp4_recipe_single_scale
|
||||
(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float)
|
||||
*/
|
||||
bool check_deepseek_recipe(ScalingType expected_recipe_a,
|
||||
ScalingType expected_recipe_b,
|
||||
c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
|
||||
if (recipe_a[0] != expected_recipe_a) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != expected_recipe_b) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp8_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp4
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace at::native::cuda::blas::scaled
|
||||
@ -1,174 +0,0 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace at::cuda::scaled {
|
||||
|
||||
static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
|
||||
#ifdef USE_ROCM
|
||||
static const std::vector<std::string> archs = {
|
||||
"gfx942",
|
||||
#if ROCM_VERSION >= 60300
|
||||
"gfx1200", "gfx1201",
|
||||
#endif
|
||||
#if ROCM_VERSION >= 60500
|
||||
"gfx950"
|
||||
#endif
|
||||
};
|
||||
return at::detail::getCUDAHooks().isGPUArch(archs);
|
||||
#else
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
|
||||
if (sm90_only || sm100_only) {
|
||||
return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
|
||||
} else {
|
||||
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool _scaled_mm_is_fnuz() {
|
||||
return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
|
||||
}
|
||||
#endif
|
||||
/**
|
||||
* Track concrete implementations available
|
||||
*/
|
||||
enum class ScaledGemmImplementation {
|
||||
NONE = 0,
|
||||
TENSORWISE_TENSORWISE = 1,
|
||||
ROWWISE_ROWWISE = 2,
|
||||
BLOCK_128x128_1x128 = 3,
|
||||
BLOCK_1x128_128x128 = 4,
|
||||
BLOCK_1x128_1x128 = 5,
|
||||
MXFP8_MXFP8 = 6,
|
||||
NVFP4_NVFP4 = 7,
|
||||
NVFP4_NVFP4_SINGLE_SCALE = 8,
|
||||
MXFP4_MXFP4 = 9,
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert passed int (enum) from python back into a
|
||||
* strictly-typed enum
|
||||
*/
|
||||
template <class EnumType, class ArrayType>
|
||||
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
|
||||
std::vector<EnumType> converted;
|
||||
converted.reserve(v.size());
|
||||
|
||||
for (auto vi : v) {
|
||||
converted.push_back(static_cast<EnumType>(vi));
|
||||
}
|
||||
return converted;
|
||||
}
|
||||
|
||||
bool check_tensorwise_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
|
||||
bool check_rowwise_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_nvfp4_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_nvfp4_recipe_single_scale
|
||||
(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_deepseek_recipe(ScalingType,
|
||||
ScalingType,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_mxfp8_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_mxfp4_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
} // namespace at::native::cuda::blas::scaled
|
||||
@ -70,7 +70,11 @@
|
||||
#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
namespace at_cuda_detail {
|
||||
#endif
|
||||
|
||||
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
|
||||
|
||||
@ -92,6 +96,10 @@ template <>
|
||||
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
|
||||
ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
} // namespace at_cuda_detail
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
@ -113,7 +121,7 @@ struct cuda_type<c10::Half> {
|
||||
using type = __half;
|
||||
};
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
|
||||
|
||||
template<>
|
||||
struct cuda_type<c10::BFloat16> {
|
||||
@ -195,6 +203,36 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera
|
||||
*out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
|
||||
}
|
||||
|
||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
|
||||
template<typename ValueT, typename InputIteratorT>
|
||||
struct chained_iterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using value_type = ValueT;
|
||||
using pointer = ValueT*;
|
||||
using reference = ValueT&;
|
||||
|
||||
InputIteratorT iter;
|
||||
ValueT *first;
|
||||
difference_type offset = 0;
|
||||
|
||||
__device__ ValueT operator[](difference_type i) {
|
||||
i += offset;
|
||||
if (i == 0) {
|
||||
return *first;
|
||||
} else {
|
||||
return ValueT(iter[i - 1]);
|
||||
}
|
||||
}
|
||||
__device__ chained_iterator operator+(difference_type i) {
|
||||
return chained_iterator{iter, first, i};
|
||||
}
|
||||
__device__ ValueT operator*() {
|
||||
return (*this)[0];
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
|
||||
// so split at int_max/2
|
||||
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
|
||||
@ -239,6 +277,25 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
first_elem_ptr,
|
||||
scan_op);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
|
||||
using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
|
||||
using tuple = typename ArgIndexInputIterator::value_type;
|
||||
auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
|
||||
if (x.key == 0) {
|
||||
return *first_elem_ptr;
|
||||
} else {
|
||||
return x.value;
|
||||
}
|
||||
};
|
||||
auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)(
|
||||
ArgIndexInputIterator(input + i), input_iter_transform);
|
||||
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
|
||||
input_,
|
||||
output + i,
|
||||
scan_op,
|
||||
size_cub,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#else
|
||||
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
|
||||
input + i + 1,
|
||||
output + i,
|
||||
@ -246,6 +303,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
|
||||
size_cub,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -497,6 +555,16 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
first_elem_ptr,
|
||||
scan_op);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
|
||||
auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
|
||||
input + i, first_elem_ptr};
|
||||
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
|
||||
input_,
|
||||
output + i,
|
||||
scan_op,
|
||||
size_cub,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#else
|
||||
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
|
||||
input + i,
|
||||
output + i,
|
||||
@ -504,6 +572,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
|
||||
size_cub,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -10,6 +10,14 @@
|
||||
#define CUB_VERSION 200001
|
||||
#endif
|
||||
|
||||
// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
|
||||
// https://github.com/NVIDIA/cub/pull/306
|
||||
#if CUB_VERSION >= 101300
|
||||
#define CUB_SUPPORTS_NV_BFLOAT16() true
|
||||
#else
|
||||
#define CUB_SUPPORTS_NV_BFLOAT16() false
|
||||
#endif
|
||||
|
||||
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
|
||||
// https://github.com/NVIDIA/cub/pull/326
|
||||
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
|
||||
@ -20,6 +28,14 @@
|
||||
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
|
||||
#endif
|
||||
|
||||
// cub support for cub::FutureValue is added to cub 1.15 in:
|
||||
// https://github.com/NVIDIA/cub/pull/305
|
||||
#if CUB_VERSION >= 101500
|
||||
#define CUB_SUPPORTS_FUTURE_VALUE() true
|
||||
#else
|
||||
#define CUB_SUPPORTS_FUTURE_VALUE() false
|
||||
#endif
|
||||
|
||||
// There were many bc-breaking changes in major version release of CCCL v3.0.0
|
||||
// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
|
||||
#if CUB_VERSION >= 200800
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
#include <ATen/detail/XLAHooksInterface.h>
|
||||
|
||||
namespace at {
|
||||
namespace detail {
|
||||
|
||||
const XLAHooksInterface& getXLAHooks() {
|
||||
auto create_impl = [] {
|
||||
// Create XLA hooks using the registry
|
||||
auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{});
|
||||
if (hooks) {
|
||||
return hooks;
|
||||
}
|
||||
// If hooks creation fails, fall back to default implementation
|
||||
return std::make_unique<XLAHooksInterface>();
|
||||
};
|
||||
static auto hooks = create_impl();
|
||||
return *hooks;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs)
|
||||
|
||||
} // namespace at
|
||||
@ -1,79 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Registry.h>
|
||||
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
|
||||
|
||||
namespace at {
|
||||
|
||||
constexpr const char* XLA_HELP =
|
||||
"This error has occurred because you are trying "
|
||||
"to use some XLA functionality, but the XLA library has not been "
|
||||
"loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
|
||||
|
||||
struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
|
||||
~XLAHooksInterface() override = default;
|
||||
|
||||
void init() const override {
|
||||
TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
virtual bool hasXLA() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual std::string showConfig() const {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Cannot query detailed XLA version without torch_xla library. ",
|
||||
XLA_HELP);
|
||||
}
|
||||
|
||||
const Generator& getDefaultGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(
|
||||
false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
Generator getNewGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
virtual DeviceIndex getCurrentDevice() const override {
|
||||
TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
Device getDeviceFromPtr(void* /*data*/) const override {
|
||||
TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
Allocator* getPinnedMemoryAllocator() const override {
|
||||
TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
bool isPinnedPtr(const void* data) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool hasPrimaryContext(DeviceIndex device_index) const override {
|
||||
TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct TORCH_API XLAHooksArgs {};
|
||||
|
||||
TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
|
||||
#define REGISTER_XLA_HOOKS(clsname) \
|
||||
C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
|
||||
|
||||
namespace detail {
|
||||
TORCH_API const XLAHooksInterface& getXLAHooks();
|
||||
} // namespace detail
|
||||
} // namespace at
|
||||
C10_DIAGNOSTIC_POP()
|
||||
@ -11,8 +11,6 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
|
||||
"pixel_shuffle expects a positive upscale_factor, but got ",
|
||||
upscale_factor);
|
||||
int64_t c = self.size(-3);
|
||||
TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
|
||||
"upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
|
||||
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
|
||||
TORCH_CHECK(c % upscale_factor_squared == 0,
|
||||
"pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
|
||||
|
||||
@ -259,20 +259,11 @@ inline void winograd_f2k3_input_transform_inplace__rvv(
|
||||
const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
|
||||
const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
|
||||
const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);
|
||||
/* GCC 14.2 (RISC-V RVV) ICE workaround:
|
||||
* Avoid single-statement read-modify-write on MEM_REF like:
|
||||
* *input_tile_val =
|
||||
* __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
|
||||
* This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin)
|
||||
* with -march=rv64gcv. Use a temporary then write back.
|
||||
* Do NOT refactor into the single-statement form. Clang is unaffected.
|
||||
*/
|
||||
vfloat32m1x4_t tmp_input_tile_val = *input_tile_val;
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0);
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1);
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2);
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3);
|
||||
*input_tile_val = tmp_input_tile_val;
|
||||
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
|
||||
}
|
||||
|
||||
inline void winograd_f2k3_output_transform_inplace__rvv(
|
||||
@ -286,15 +277,9 @@ inline void winograd_f2k3_output_transform_inplace__rvv(
|
||||
const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
|
||||
const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
|
||||
const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);
|
||||
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
|
||||
* Keep the temporary + write-back pattern to avoid ICE.
|
||||
* Do NOT rewrite into:
|
||||
* *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
|
||||
*/
|
||||
vfloat32m1x4_t tmp_output_tile_val = *input_tile_val;
|
||||
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0);
|
||||
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1);
|
||||
*input_tile_val = tmp_output_tile_val;
|
||||
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
|
||||
}
|
||||
|
||||
inline vfloat32m1_t
|
||||
@ -315,17 +300,11 @@ inline void winograd_f2k3_kernel_transform__rvv(
|
||||
const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
|
||||
const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
|
||||
vfloat32m1_t half_g0_plus_g2 = __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);
|
||||
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
|
||||
* Keep the temporary + write-back pattern to avoid ICE.
|
||||
* Do NOT rewrite into:
|
||||
* *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val);
|
||||
*/
|
||||
vfloat32m1x4_t tmp_transform = *transform;
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0);
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2);
|
||||
*transform = tmp_transform;
|
||||
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
|
||||
}
|
||||
|
||||
inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,574 +0,0 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/CUDAScaledBlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace scaled_blas = at::cuda::scaled;
|
||||
using scaled_blas::ScaledGemmImplementation;
|
||||
using scaled_blas::convert_int_to_enum;
|
||||
using scaled_blas::_scaled_mm_allowed_device;
|
||||
|
||||
namespace at::native {
|
||||
|
||||
namespace {
|
||||
|
||||
// 2d-2d and 2d-3d
|
||||
// scaling=MXFP8
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
bool b_is_3d = mat_b.dim() == 3;
|
||||
bool is_2d_2d = a_is_2d && b_is_2d;
|
||||
bool is_2d_3d = a_is_2d && b_is_3d;
|
||||
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
|
||||
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
|
||||
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
|
||||
#else
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
|
||||
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
|
||||
#endif
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
|
||||
fbgemm_gpu::mx8mx8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs.value(),
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI");
|
||||
#endif
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
|
||||
|
||||
at::cuda::detail::f8f8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// only being called for rocm
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM)
|
||||
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
// FBGEMM expects B matrix shape to be (.., N, K)
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
|
||||
#endif
|
||||
return out;
|
||||
|
||||
}
|
||||
|
||||
// Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"For grouped FP8 rowwise, both scales must be float32 tensors");
|
||||
#ifndef USE_ROCM
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
#else
|
||||
// NOTE: ignore use_fast_accum
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias")
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#endif
|
||||
}
|
||||
|
||||
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 1,
|
||||
"scale must be a 1D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D, arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(dim) * scale_multiplier,
|
||||
"scale must have the same length as mat for arg ",
|
||||
arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 2,
|
||||
"scale must be a 2D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.stride(1) == 1,
|
||||
"scale must be contiguous in the last dimension for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(0),
|
||||
"scale must have the same batch dimension as mat for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(1) == mat.size(1 + dim),
|
||||
"scale must have the same first dimension as mat for arg ",
|
||||
arg_idx);
|
||||
}
|
||||
}
|
||||
|
||||
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
|
||||
// that are converted to blocked padded format individually,
|
||||
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim(),
|
||||
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
|
||||
|
||||
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
|
||||
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
|
||||
// * weight is transposed prior to the call, scale stays non-transposed.
|
||||
bool LHS = arg_idx == 0;
|
||||
int scale_dim_to_check = 0;
|
||||
int mat_dim_to_check = LHS ? 0 : 1;
|
||||
TORCH_CHECK(
|
||||
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
|
||||
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
|
||||
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
|
||||
} else {
|
||||
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
|
||||
// so we can check the exact expected scale sizes here without a d2h sync.
|
||||
auto round_up = [](auto x, auto y) {
|
||||
return ((x + y - 1) / y) * y;
|
||||
};
|
||||
|
||||
// TODO: this is for 3d tensor in 2d-3d case specifically.
|
||||
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
|
||||
int64_t G = mat.size(0);
|
||||
int64_t K = mat.size(1);
|
||||
int64_t N = mat.size(2);
|
||||
int64_t blocked_scale_K = round_up(K/32, 4);
|
||||
int64_t blocked_scale_N = round_up(N, 128);
|
||||
|
||||
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim() - 1,
|
||||
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
|
||||
);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
|
||||
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
|
||||
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
|
||||
if (using_fp8_rowwise) {
|
||||
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
|
||||
} else if (using_mxfp8) {
|
||||
_check_scales_mxfp8(mat, scale, dim, arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
std::optional<c10::ScalarType> out_dtype,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(
|
||||
(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
|
||||
(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
|
||||
"For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
|
||||
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
|
||||
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
|
||||
// MXFP8 grouped GEMM dispatching
|
||||
bool is_mx8mx8bf16 = (
|
||||
mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
|
||||
scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
|
||||
);
|
||||
#else
|
||||
bool is_mx8mx8bf16 = false;
|
||||
#endif
|
||||
|
||||
if (is_mx8mx8bf16) {
|
||||
// Note: Passing implied SwizzleType here, correctness of scale previously checked
|
||||
// in `check_scale` call
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
scale_b,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
|
||||
// If we're not MXFP8, then we're row-wise scaling.
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda_v2(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
if (contraction_dim.size() > 0) {
|
||||
const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]);
|
||||
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
|
||||
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
|
||||
mat_b.size(dim_b));
|
||||
// Note: only (-1, -2) is currently supported
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
|
||||
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
|
||||
"No gemm implementation was found");
|
||||
|
||||
switch (gemm_impl) {
|
||||
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
|
||||
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
swizzle_a_enum[0],
|
||||
scale_b[0],
|
||||
swizzle_b_enum[0],
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
std::optional<c10::ScalarType> out_dtype) {
|
||||
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
|
||||
bool a_b_and_out_are_bf16 = (
|
||||
mat_a.dtype() == at::kBFloat16 &&
|
||||
mat_b.dtype() == at::kBFloat16 &&
|
||||
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
|
||||
);
|
||||
#ifndef USE_ROCM
|
||||
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
|
||||
#else
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
if (use_fast_path) {
|
||||
// fast path, no d2h sync needed
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
@ -1,17 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/cuda/detail/OffsetCalculator.cuh>
|
||||
#include <ATen/detail/FunctionTraits.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/TensorIteratorDynamicCasting.h>
|
||||
#include <ATen/cuda/detail/OffsetCalculator.cuh>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/native/cuda/thread_constants.h>
|
||||
|
||||
#include <thrust/tuple.h>
|
||||
|
||||
#include <ATen/native/cuda/MemoryAccess.cuh>
|
||||
|
||||
#include <tuple>
|
||||
|
||||
|
||||
|
||||
namespace at::native {
|
||||
|
||||
template<int N>
|
||||
@ -61,11 +62,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < elems_per_thread; i++) {
|
||||
if (policy.check_inbounds(i)) {
|
||||
#if defined(__HIP__)
|
||||
results[i] = c10::guts::apply(f, args[i]);
|
||||
#else
|
||||
results[i] = std::apply(f, args[i]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ namespace at::native {
|
||||
|
||||
// The maximum number of threads in a block
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int MAX_BLOCK_SIZE = 1024;
|
||||
constexpr int MAX_BLOCK_SIZE = 256;
|
||||
#else
|
||||
constexpr int MAX_BLOCK_SIZE = 512;
|
||||
#endif
|
||||
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
|
||||
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
|
||||
static int getNumThreads(int nElem) {
|
||||
#if defined(USE_ROCM)
|
||||
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
|
||||
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
|
||||
#else
|
||||
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
|
||||
#endif
|
||||
@ -115,23 +115,9 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
|
||||
// first the reductions each thread does separately
|
||||
scalar_t sum = static_cast<scalar_t>(0);
|
||||
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int UNRL = 4; // load deserilize factor
|
||||
scalar_t tmp[UNRL];
|
||||
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
tmp[u] = op(batch, plane, std::min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
if (x+u*blockDim.x < tensor.size(2))
|
||||
sum += tmp[u];
|
||||
}
|
||||
#else
|
||||
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
|
||||
sum += op(batch, plane, x);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
__shared__ scalar_t shared[C10_WARP_SIZE];
|
||||
SumReduceOp<scalar_t> reduce_op;
|
||||
@ -306,22 +292,6 @@ __global__ void batch_norm_collect_statistics_kernel(
|
||||
stat_accscalar_t var_n = 0;
|
||||
int n = 0;
|
||||
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int UNRL = 4;
|
||||
stat_accscalar_t v_[UNRL];
|
||||
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
|
||||
for (int u = 0; u < UNRL; u++) {
|
||||
if (x+u*blockDim.x < input.size(2)) {
|
||||
stat_accscalar_t d1 = v_[u] - avg;
|
||||
n++;
|
||||
avg += d1 / n;
|
||||
var_n += d1 * (v_[u] - avg);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
|
||||
stat_accscalar_t v = input[batch][plane][x];
|
||||
stat_accscalar_t d1 = v - avg;
|
||||
@ -329,7 +299,6 @@ __global__ void batch_norm_collect_statistics_kernel(
|
||||
avg += d1 / n;
|
||||
var_n += d1 * (v - avg);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// first warpSum to get one value per thread to
|
||||
|
||||
@ -92,16 +92,6 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
|
||||
output_offset + output_y * output_dim_x + output_x);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
|
||||
const int64_t two = (len - 1) * 2;
|
||||
if (two <= 0) {
|
||||
return 0;
|
||||
}
|
||||
int64_t m = x % two;
|
||||
if (m < 0) m += two;
|
||||
return (m < len) ? m : (two - m);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void reflection_pad1d_out_kernel(
|
||||
const scalar_t * input, scalar_t * output,
|
||||
@ -116,28 +106,6 @@ __global__ void reflection_pad1d_out_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void reflection_pad1d_flat(
|
||||
const scalar_t* __restrict__ input,
|
||||
scalar_t* __restrict__ output,
|
||||
int64_t input_w, int64_t pad_l, int64_t pad_r,
|
||||
int64_t out_w, int64_t plane_count) {
|
||||
|
||||
const int64_t bx = blockDim.x;
|
||||
const int64_t tx = threadIdx.x;
|
||||
|
||||
const int64_t total = plane_count * out_w;
|
||||
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
|
||||
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
|
||||
|
||||
for (; linear < total; linear += grid_stride) {
|
||||
const int64_t plane = linear / out_w;
|
||||
const int64_t x = linear - plane * out_w;
|
||||
const int64_t j = reflect_index(x - pad_l, input_w);
|
||||
output[plane * out_w + x] = input[plane * input_w + j];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void reflection_pad1d_backward_out_kernel(
|
||||
scalar_t * grad_input, const scalar_t * grad_output,
|
||||
@ -742,44 +710,25 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
|
||||
int64_t input_w = input_.size(dim_w);
|
||||
int64_t output_w = input_w + pad_l + pad_r;
|
||||
|
||||
dim3 block_size(output_w > 256 ? 256 : output_w);
|
||||
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
|
||||
|
||||
Tensor input = input_.contiguous();
|
||||
|
||||
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
|
||||
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
const int max_x = prop->maxGridSize[0];
|
||||
const int max_y = prop->maxGridSize[1];
|
||||
const int max_z = prop->maxGridSize[2];
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
|
||||
|
||||
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
|
||||
|
||||
if (fits3d) {
|
||||
dim3 block(block_x, 1, 1);
|
||||
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
|
||||
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
output.mutable_data_ptr<scalar_t>(),
|
||||
input_w, pad_l, pad_r);
|
||||
} else {
|
||||
dim3 block(block_x, 1, 1);
|
||||
const int64_t plane_count = nplane * nbatch;
|
||||
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
|
||||
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
|
||||
dim3 grid(grid_x, 1, 1);
|
||||
|
||||
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
output.mutable_data_ptr<scalar_t>(),
|
||||
input_w, pad_l, pad_r, output_w, plane_count);
|
||||
}
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
|
||||
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
|
||||
reflection_pad1d_out_kernel<<<
|
||||
grid_size,
|
||||
block_size,
|
||||
0,
|
||||
at::cuda::getCurrentCUDAStream()>>>(
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
output.mutable_data_ptr<scalar_t>(),
|
||||
input_w,
|
||||
pad_l,
|
||||
pad_r);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
|
||||
|
||||
@ -43,12 +43,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
|
||||
TORCH_CHECK(k >= 1 && k <= slicesize,
|
||||
"kthvalue(): selected number k out of range for dimension ", dim);
|
||||
|
||||
TORCH_CHECK(
|
||||
slicesize <= std::numeric_limits<int32_t>::max(),
|
||||
"kthvalue(): dimension ", dim, " is too large (", slicesize,
|
||||
"). The current CUDA implementation supports dimension sizes up to ",
|
||||
std::numeric_limits<int32_t>::max());
|
||||
|
||||
at::assert_no_overlap(self, values);
|
||||
|
||||
_reduction_with_indices_allocate_or_resize_output(
|
||||
@ -169,6 +163,10 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
|
||||
bool keepdim,
|
||||
Tensor& values,
|
||||
Tensor& indices) {
|
||||
// See note [Writing Nondeterministic Operations]
|
||||
// If there are duplicate elements of the kth value, the procedure for choosing which
|
||||
// of the duplicates to use for the indices output is nondeterministic.
|
||||
at::globalContext().alertNotDeterministic("kthvalue CUDA");
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
// `kthvalue_out_impl_cuda` expects contiguous in input `self`.
|
||||
|
||||
@ -65,34 +65,25 @@ __global__ void gatherKthValue(
|
||||
&kValue);
|
||||
|
||||
// Find the index of the k-th highest element
|
||||
__shared__ int32_t minIndexFound;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
minIndexFound = static_cast<int32_t>(inputSliceSize);
|
||||
}
|
||||
__syncthreads();
|
||||
index_t kValueIndex = 0;
|
||||
bool foundKValue = false;
|
||||
|
||||
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
|
||||
// Early exit based on best-so-far
|
||||
if (i >= minIndexFound) {
|
||||
break;
|
||||
}
|
||||
|
||||
scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
|
||||
bool isKValue =
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
|
||||
if (isKValue) {
|
||||
atomicMin(&minIndexFound, static_cast<int32_t>(i));
|
||||
break;
|
||||
}
|
||||
bool inRange = (i < inputSliceSize);
|
||||
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
|
||||
: static_cast<scalar_t>(0);
|
||||
bool isKValue = inRange &&
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
if (isKValue) {
|
||||
kValueIndex = i;
|
||||
foundKValue = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
indicesSliceStart[0] = static_cast<index_t>(minIndexFound);
|
||||
kthValueSliceStart[0] = kValue;
|
||||
if (foundKValue) {
|
||||
kthValueSliceStart[0] = kValue;
|
||||
indicesSliceStart[0] = kValueIndex;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// Helper function to compute output pixel range that can contribute to input pixel
|
||||
template <typename accscalar_t>
|
||||
__device__ __forceinline__ void compute_output_range(
|
||||
int input_pos,
|
||||
accscalar_t scale,
|
||||
int output_size,
|
||||
bool align_corners,
|
||||
int& min_output,
|
||||
int& max_output) {
|
||||
accscalar_t lo, hi;
|
||||
if (align_corners) {
|
||||
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
|
||||
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
|
||||
} else {
|
||||
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
}
|
||||
min_output = max(0, static_cast<int>(std::ceil(lo)));
|
||||
max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
C10_LAUNCH_BOUNDS_1(1024)
|
||||
@ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
const bool align_corners,
|
||||
scalar_t* __restrict__ idata,
|
||||
const scalar_t* __restrict__ odata) {
|
||||
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
|
||||
const size_t i_numel = nc * width1 * height1;
|
||||
#ifdef USE_ROCM
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
// Decode input pixel coordinates
|
||||
size_t index_temp = index;
|
||||
const int w1 = index_temp % width1;
|
||||
index_temp /= width1;
|
||||
const int h1 = index_temp % height1;
|
||||
const size_t nc_idx = index_temp / height1;
|
||||
|
||||
accscalar_t grad_sum = 0;
|
||||
|
||||
// Find range of output pixels that could interpolate from this input pixel
|
||||
int h2_min, h2_max, w2_min, w2_max;
|
||||
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
|
||||
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
|
||||
|
||||
// Iterate over potential output pixels
|
||||
for (int h2 = h2_min; h2 <= h2_max; h2++) {
|
||||
for (int w2 = w2_min; w2 <= w2_max; w2++) {
|
||||
// Compute source coordinates for this output pixel
|
||||
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rheight, h2, align_corners, /*cubic=*/false);
|
||||
const int h1_base = (int)h1r;
|
||||
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
|
||||
const accscalar_t h1lambda = h1r - h1_base;
|
||||
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
|
||||
|
||||
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rwidth, w2, align_corners, /*cubic=*/false);
|
||||
const int w1_base = (int)w1r;
|
||||
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
|
||||
const accscalar_t w1lambda = w1r - w1_base;
|
||||
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
|
||||
|
||||
// Check if our input pixel participates in this interpolation and accumulate all weights
|
||||
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
|
||||
// to the same pixel, so we need to accumulate weights from all matching positions
|
||||
accscalar_t weight = 0;
|
||||
|
||||
// Check all four interpolation positions and accumulate weights
|
||||
if (h1 == h1_base && w1 == w1_base) {
|
||||
weight += h0lambda * w0lambda; // top-left
|
||||
}
|
||||
if (h1 == h1_base && w1 == w1_base + w1p) {
|
||||
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base) {
|
||||
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
|
||||
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
|
||||
}
|
||||
|
||||
if (weight > 0) {
|
||||
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
|
||||
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write accumulated gradient (no atomics needed)
|
||||
idata[index] = static_cast<scalar_t>(grad_sum);
|
||||
}
|
||||
#else
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
const size_t i_numel = nc * width1 * height1;
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
size_t index_temp = index;
|
||||
@ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
|
||||
true);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
@ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
// threads are not covering the whole input tensor.
|
||||
grad_input.zero_();
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
const int num_threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
constexpr bool use_input = true;
|
||||
#else
|
||||
constexpr bool use_input = false;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half, at::ScalarType::BFloat16,
|
||||
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
|
||||
@ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
|
||||
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
|
||||
input_height,
|
||||
@ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
|
||||
|
||||
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
|
||||
num_threads,
|
||||
|
||||
@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
|
||||
C10_DEVICE __forceinline__ void operator()(
|
||||
int64_t chunk_size,
|
||||
int chunk_size,
|
||||
FusedOptimizerTensorListMetadata<3>& tl,
|
||||
const float* lr_ptr,
|
||||
const double& lr,
|
||||
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace at::native
|
||||
} // namespace at::native
|
||||
@ -2,7 +2,6 @@
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
#include <ATen/native/xpu/Blas.h>
|
||||
#include <torch/library.h>
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
|
||||
@ -51,13 +50,9 @@ Tensor& addmm_out(
|
||||
mat1.dtype(),
|
||||
" != ",
|
||||
mat2.dtype())
|
||||
|
||||
// complex case
|
||||
if (self.is_complex()) {
|
||||
at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
|
||||
|
||||
std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
|
||||
result.resize_(result_shape);
|
||||
@ -172,11 +167,8 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
if (self.is_complex()) {
|
||||
at::native::mm_complex_out_xpu(self, mat2, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
|
||||
|
||||
onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
|
||||
return result;
|
||||
@ -216,12 +208,9 @@ Tensor& baddbmm_out(
|
||||
input.sizes());
|
||||
|
||||
// complex case
|
||||
if (input.is_complex()) {
|
||||
at::native::baddbmm_complex_out_xpu(
|
||||
input, batch1, batch2, beta, alpha, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!batch1.is_complex(),
|
||||
"Complex datatype matmul is not supported in oneDNN");
|
||||
|
||||
// general case
|
||||
onednn::Attr attr;
|
||||
@ -268,13 +257,8 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// complex case
|
||||
if (self.is_complex()) {
|
||||
at::native::bmm_complex_out_xpu(self, batch2, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
|
||||
onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -222,13 +222,6 @@ struct nextafter_functor {
|
||||
}
|
||||
};
|
||||
|
||||
struct hypot_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b) {
|
||||
return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b));
|
||||
}
|
||||
};
|
||||
|
||||
// Complex binary functors
|
||||
struct polar_functor {
|
||||
template <typename U>
|
||||
@ -369,7 +362,6 @@ struct igammac_functor {
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
|
||||
|
||||
REGISTER_FLOAT_BINARY_OP(hypot);
|
||||
REGISTER_FLOAT_BINARY_OP(copysign);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(copysign);
|
||||
REGISTER_FLOAT_BINARY_OP(fmax);
|
||||
|
||||
@ -1,16 +0,0 @@
|
||||
#pragma onces
|
||||
#include <c10/metal/common.h>
|
||||
|
||||
template <unsigned N = c10::metal::max_ndim>
|
||||
struct OrgqrParams {
|
||||
int32_t num_batch_dims;
|
||||
|
||||
uint32_t m;
|
||||
uint32_t n;
|
||||
uint32_t k;
|
||||
|
||||
::c10::metal::array<uint32_t, N> A_strides;
|
||||
::c10::metal::array<uint32_t, N> tau_strides;
|
||||
::c10::metal::array<uint32_t, N> H_strides;
|
||||
::c10::metal::array<uint32_t, N> H_sizes;
|
||||
};
|
||||
@ -1,4 +1,3 @@
|
||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
|
||||
#include <c10/metal/utils.h>
|
||||
#include <metal_array>
|
||||
#include <metal_simdgroup>
|
||||
@ -641,164 +640,6 @@ kernel void applyPivots(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T bool_to_float(bool b) {
|
||||
return static_cast<T>(b);
|
||||
}
|
||||
|
||||
template <>
|
||||
half2 bool_to_float(bool b) {
|
||||
return half2(b ? 1 : 0, 0);
|
||||
}
|
||||
|
||||
template <>
|
||||
float2 bool_to_float(bool b) {
|
||||
return float2(b ? 1 : 0, 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T calc_H_irc(
|
||||
device T* A,
|
||||
uint32_t A_stride_r,
|
||||
uint32_t A_stride_c,
|
||||
constant T* tau,
|
||||
uint32_t tau_stride,
|
||||
uint32_t r,
|
||||
uint32_t c,
|
||||
uint32_t i) {
|
||||
T I_val = bool_to_float<T>(r == c);
|
||||
T tau_val = tau[i * tau_stride];
|
||||
|
||||
T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]);
|
||||
T A_ri = A[r * A_stride_r + i * A_stride_c];
|
||||
|
||||
T c_eq_i = bool_to_float<T>(c == i);
|
||||
T r_eq_i = bool_to_float<T>(r == i);
|
||||
|
||||
T A_ci_ = (c > i) ? A_ci : c_eq_i;
|
||||
T A_ri_ = (r > i) ? A_ri : r_eq_i;
|
||||
|
||||
return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_));
|
||||
}
|
||||
|
||||
// Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the
|
||||
// result of matrix multiplying A and B together. A and B must be size m-by-m
|
||||
// and have the same strides. The formula for this operation, written in Python
|
||||
// syntax, is:
|
||||
// (A @ B)[r, c] = A[r, :].dot(B[:, c])
|
||||
template <typename T>
|
||||
static T calc_matmul_rc(
|
||||
device T* A,
|
||||
device T* B,
|
||||
uint32_t stride_r,
|
||||
uint32_t stride_c,
|
||||
uint32_t m,
|
||||
uint32_t r,
|
||||
uint32_t c) {
|
||||
T AB_rc = 0;
|
||||
auto A_row_offset = r * stride_r;
|
||||
auto B_col_offset = c * stride_c;
|
||||
|
||||
uint32_t A_col_offset = 0;
|
||||
uint32_t B_row_offset = 0;
|
||||
|
||||
for (uint32_t j = 0; j < m;
|
||||
j++, A_col_offset += stride_c, B_row_offset += stride_r) {
|
||||
AB_rc += c10::metal::mul(
|
||||
A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]);
|
||||
}
|
||||
return AB_rc;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
kernel void orgqr(
|
||||
device T* A [[buffer(0)]],
|
||||
constant T* tau [[buffer(1)]],
|
||||
device T* H [[buffer(2)]],
|
||||
device T* H_prod [[buffer(3)]],
|
||||
constant OrgqrParams<>& params [[buffer(4)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
constant auto& A_strides = params.A_strides;
|
||||
constant auto& tau_strides = params.tau_strides;
|
||||
constant auto& H_strides = params.H_strides;
|
||||
constant auto& H_sizes = params.H_sizes;
|
||||
|
||||
auto num_batch_dims = params.num_batch_dims;
|
||||
auto m = params.m;
|
||||
auto n = params.n;
|
||||
auto k = params.k;
|
||||
|
||||
auto m2 = m * m;
|
||||
auto batch_idx = tid / m2;
|
||||
|
||||
// Find the matrices for this thread's batch index
|
||||
uint32_t A_offset = 0;
|
||||
uint32_t tau_offset = 0;
|
||||
uint32_t H_offset = 0;
|
||||
|
||||
for (auto dim = num_batch_dims - 1; dim >= 0; dim--) {
|
||||
auto dim_size = H_sizes[dim];
|
||||
auto dim_idx = batch_idx % dim_size;
|
||||
|
||||
A_offset += dim_idx * A_strides[dim];
|
||||
tau_offset += dim_idx * tau_strides[dim];
|
||||
H_offset += dim_idx * H_strides[dim];
|
||||
|
||||
batch_idx /= dim_size;
|
||||
}
|
||||
|
||||
A += A_offset;
|
||||
tau += tau_offset;
|
||||
H += H_offset;
|
||||
H_prod += H_offset;
|
||||
|
||||
auto matrix_idx = tid % m2;
|
||||
auto r = matrix_idx / m;
|
||||
auto c = matrix_idx % m;
|
||||
auto A_stride_r = A_strides[num_batch_dims];
|
||||
auto A_stride_c = A_strides[num_batch_dims + 1];
|
||||
auto tau_stride = tau_strides[num_batch_dims];
|
||||
auto H_stride_r = H_strides[num_batch_dims];
|
||||
auto H_stride_c = H_strides[num_batch_dims + 1];
|
||||
|
||||
// Find the element of H and H_prod that this thread will calculate
|
||||
device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c);
|
||||
device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c);
|
||||
|
||||
for (uint32_t i = 0; i < k; i++) {
|
||||
// Calculate and write H_i
|
||||
|
||||
T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i);
|
||||
|
||||
// Calculate element [r, c] of prod(H_0, ..., H_i)
|
||||
if (i == 0) {
|
||||
*H_prod_elem_ptr = H_irc;
|
||||
} else {
|
||||
*H_elem_ptr = H_irc;
|
||||
|
||||
// Need this sync because the below matmul requires all threads to finish
|
||||
// writing their entries to `H_prod` and `H`.
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
T H_prod_0_to_i_rc =
|
||||
calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);
|
||||
|
||||
// Need this sync because the above matmul uses the current values in
|
||||
// `H_prod`, and we don't want to overwrite those until all threads are
|
||||
// finished using them.
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
*H_prod_elem_ptr = H_prod_0_to_i_rc;
|
||||
}
|
||||
}
|
||||
|
||||
device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c);
|
||||
|
||||
if (c < n) {
|
||||
*A_elem_ptr = *H_prod_elem_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MM_OPS(DTYPE) \
|
||||
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
|
||||
constant DTYPE * mat1Data [[buffer(0)]], \
|
||||
@ -838,19 +679,3 @@ INSTANTIATE_MM_OPS(int);
|
||||
INSTANTIATE_MM_OPS(short);
|
||||
INSTANTIATE_MM_OPS(char);
|
||||
INSTANTIATE_MM_OPS(uchar);
|
||||
|
||||
#define REGISTER_ORGQR(T) \
|
||||
template [[host_name("orgqr_" #T)]] \
|
||||
kernel void orgqr<T>( \
|
||||
device T * A [[buffer(0)]], \
|
||||
constant T * tau [[buffer(1)]], \
|
||||
device T * H [[buffer(2)]], \
|
||||
device T * H_prod [[buffer(3)]], \
|
||||
constant OrgqrParams<> & params [[buffer(4)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
REGISTER_ORGQR(float);
|
||||
REGISTER_ORGQR(half);
|
||||
REGISTER_ORGQR(bfloat);
|
||||
REGISTER_ORGQR(float2);
|
||||
REGISTER_ORGQR(half2);
|
||||
|
||||
@ -5,21 +5,6 @@
|
||||
using namespace metal;
|
||||
using namespace c10::metal;
|
||||
|
||||
struct angle_functor {
|
||||
template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
|
||||
inline T operator()(const T x) {
|
||||
return T(atan2(x.y, x.x), 0);
|
||||
}
|
||||
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
|
||||
inline T operator()(const T x) {
|
||||
return T(isnan(x) ? x : x < 0 ? M_PI_F : 0.0);
|
||||
}
|
||||
template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
|
||||
inline float operator()(const T x) {
|
||||
return x < 0 ? M_PI_F : 0.0;
|
||||
}
|
||||
};
|
||||
|
||||
// Implement exp wrapper for both real and complex types
|
||||
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
|
||||
inline T exp_(const T x) {
|
||||
@ -560,7 +545,6 @@ REGISTER_UNARY_OP(abs, float, float);
|
||||
REGISTER_UNARY_OP(abs, half, half);
|
||||
|
||||
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
|
||||
REGISTER_UNARY_OP(angle, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0); \
|
||||
@ -599,7 +583,6 @@ INSTANTIATE_UNARY_KERNELS2(float, int);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, long);
|
||||
|
||||
#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE) \
|
||||
REGISTER_UNARY_OP(angle, DTYPE##2, DTYPE##2); \
|
||||
REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2); \
|
||||
REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2); \
|
||||
REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2); \
|
||||
|
||||
@ -92,8 +92,13 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
}
|
||||
|
||||
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
|
||||
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
|
||||
if ([maskedMM dataType] != MPSDataTypeFloat32) {
|
||||
maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
|
||||
}
|
||||
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
|
||||
if ([maskedMM dataType] != qTensor.dataType) {
|
||||
maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
|
||||
}
|
||||
|
||||
if (is_causal) {
|
||||
auto causalMask = [mpsGraph constantWithScalar:1.0f
|
||||
@ -107,9 +112,7 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
name:nil];
|
||||
} else if (attn_mask) {
|
||||
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
|
||||
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
|
||||
name:nil];
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
|
||||
}
|
||||
|
||||
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
|
||||
@ -130,8 +133,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
graph->qTensor = qTensor;
|
||||
graph->kTensor = kTensor;
|
||||
graph->vTensor = vTensor;
|
||||
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
|
||||
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
|
||||
graph->outputTensor = output;
|
||||
graph->attnTensor = sm;
|
||||
});
|
||||
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
|
||||
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
|
||||
|
||||
@ -202,10 +202,6 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "igammac");
|
||||
}
|
||||
|
||||
static void hypot_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "hypot");
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
|
||||
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
|
||||
REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel)
|
||||
@ -233,5 +229,4 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
|
||||
REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
|
||||
REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel)
|
||||
REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel)
|
||||
REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel)
|
||||
} // namespace at::native
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
#include <ATen/ops/eq_native.h>
|
||||
#include <ATen/ops/ge_native.h>
|
||||
#include <ATen/ops/gt_native.h>
|
||||
#include <ATen/ops/hypot_native.h>
|
||||
#include <ATen/ops/le_native.h>
|
||||
#include <ATen/ops/logaddexp2_native.h>
|
||||
#include <ATen/ops/logaddexp_native.h>
|
||||
@ -277,6 +278,22 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
|
||||
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
|
||||
secondaryTensor:twoTensor
|
||||
name:nil]
|
||||
secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor
|
||||
secondaryTensor:twoTensor
|
||||
name:nil]
|
||||
name:nil];
|
||||
return [mpsGraph squareRootWithTensor:sumTensor name:nil];
|
||||
};
|
||||
mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
|
||||
@ -8,9 +8,6 @@
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -31,7 +28,6 @@
|
||||
#include <ATen/ops/linalg_solve_triangular_native.h>
|
||||
#include <ATen/ops/lu_unpack_native.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/orgqr_native.h>
|
||||
#include <ATen/ops/slice.h>
|
||||
#include <ATen/ops/stack.h>
|
||||
#include <ATen/ops/triangular_solve_native.h>
|
||||
@ -342,8 +338,6 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
|
||||
". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
|
||||
}
|
||||
}
|
||||
|
||||
map_mps_decomposition_error_code_to_blas(info);
|
||||
}
|
||||
|
||||
static void linalg_solve_out_mps_impl(const Tensor& A,
|
||||
@ -1239,69 +1233,6 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper
|
||||
}
|
||||
}
|
||||
|
||||
static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) {
|
||||
if (self.numel() == 0) {
|
||||
return self;
|
||||
}
|
||||
|
||||
auto m = self.size(-2);
|
||||
auto n = self.size(-1);
|
||||
auto k = tau.size(-1);
|
||||
|
||||
if (tau.numel() == 0) {
|
||||
auto I = eye(m, self.scalar_type(), std::nullopt, self.device());
|
||||
return self.copy_(I.slice(-1, 0, n));
|
||||
}
|
||||
|
||||
auto num_batch_dims = self.dim() - 2;
|
||||
auto batch_sizes = self.sizes().slice(0, num_batch_dims);
|
||||
|
||||
std::vector<int64_t> H_sizes(num_batch_dims + 2);
|
||||
for (auto dim : c10::irange(num_batch_dims)) {
|
||||
H_sizes[dim] = self.size(dim);
|
||||
}
|
||||
H_sizes[num_batch_dims] = m;
|
||||
H_sizes[num_batch_dims + 1] = m;
|
||||
|
||||
auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous));
|
||||
auto H_prod = at::empty_like(H);
|
||||
|
||||
OrgqrParams params;
|
||||
|
||||
params.num_batch_dims = num_batch_dims;
|
||||
params.m = m;
|
||||
params.n = n;
|
||||
params.k = k;
|
||||
|
||||
for (const auto dim : c10::irange(self.dim())) {
|
||||
params.A_strides[dim] = self.stride(dim);
|
||||
|
||||
if (dim < tau.dim()) {
|
||||
params.tau_strides[dim] = tau.stride(dim);
|
||||
}
|
||||
|
||||
params.H_strides[dim] = H.stride(dim);
|
||||
params.H_sizes[dim] = H.size(dim);
|
||||
}
|
||||
|
||||
auto num_threads = H.numel();
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder();
|
||||
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self)));
|
||||
getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau});
|
||||
[compute_encoder setComputePipelineState:pipeline_state];
|
||||
mtl_setArgs(compute_encoder, self, tau, H, H_prod, params);
|
||||
mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads);
|
||||
getMPSProfiler().endProfileKernel(pipeline_state);
|
||||
}
|
||||
});
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
|
||||
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
|
||||
@ -1517,6 +1448,20 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps)
|
||||
mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info);
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
|
||||
Tensor info = at::empty({}, A.options().dtype(kInt));
|
||||
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
|
||||
return std::tie(LU, pivots);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
|
||||
Tensor LU = at::empty({0}, A.options());
|
||||
Tensor pivots = at::empty({0}, A.options().dtype(kInt));
|
||||
Tensor info = at::empty({}, A.options().dtype(kInt));
|
||||
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
|
||||
return std::make_tuple(std::move(LU), std::move(pivots));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(lu_unpack_out_mps)
|
||||
(const Tensor& LU_data,
|
||||
const Tensor& LU_pivots,
|
||||
@ -1538,6 +1483,4 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
|
||||
REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl);
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -34,7 +34,6 @@ REGISTER_UNARY_TI_DISPATCH(sinc);
|
||||
REGISTER_UNARY_TI_DISPATCH(sinh);
|
||||
REGISTER_UNARY_TI_DISPATCH(cosh);
|
||||
REGISTER_UNARY_TI_DISPATCH(tanh);
|
||||
REGISTER_UNARY_TI_DISPATCH(angle);
|
||||
REGISTER_UNARY_TI_DISPATCH(abs);
|
||||
REGISTER_UNARY_TI_DISPATCH(sin);
|
||||
REGISTER_UNARY_TI_DISPATCH(cos);
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
#include <ATen/ops/_copy_from_and_resize.h>
|
||||
#include <ATen/ops/acos_native.h>
|
||||
#include <ATen/ops/acosh_native.h>
|
||||
#include <ATen/ops/angle_native.h>
|
||||
#include <ATen/ops/asin_native.h>
|
||||
#include <ATen/ops/asinh_native.h>
|
||||
#include <ATen/ops/atan_native.h>
|
||||
@ -203,6 +204,23 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
|
||||
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
|
||||
auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
|
||||
return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor angle_mps(const Tensor& self) {
|
||||
const auto float_type = c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)
|
||||
? c10::typeMetaToScalarType(c10::get_default_dtype())
|
||||
: c10::toRealValueType(self.scalar_type());
|
||||
Tensor result = at::empty({0}, self.options().dtype(float_type));
|
||||
return angle_out_mps(self, result);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types");
|
||||
mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
|
||||
@ -403,14 +403,16 @@
|
||||
device_check: NoCheck # TensorIterator
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: angle
|
||||
CPU, CUDA: angle
|
||||
MPS: angle_mps
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr
|
||||
tags: pointwise
|
||||
|
||||
- func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: angle_out
|
||||
CPU, CUDA: angle_out
|
||||
MPS: angle_out_mps
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr_out
|
||||
tags: pointwise
|
||||
|
||||
@ -10040,7 +10042,8 @@
|
||||
structured: True
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: hypot_out
|
||||
CPU, CUDA: hypot_out
|
||||
MPS: hypot_out_mps
|
||||
tags: pointwise
|
||||
|
||||
- func: hypot(Tensor self, Tensor other) -> Tensor
|
||||
@ -14154,10 +14157,16 @@
|
||||
- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeImplicitAutograd: linalg_lu_factor
|
||||
MPS: linalg_lu_factor_mps
|
||||
|
||||
- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeImplicitAutograd: linalg_lu_factor_out
|
||||
MPS: linalg_lu_factor_out_mps
|
||||
|
||||
- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
|
||||
python_module: linalg
|
||||
@ -14359,12 +14368,12 @@
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: linalg_householder_product
|
||||
CPU, CUDA: linalg_householder_product
|
||||
|
||||
- func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: linalg
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: linalg_householder_product_out
|
||||
CPU, CUDA: linalg_householder_product_out
|
||||
|
||||
- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
|
||||
python_module: linalg
|
||||
|
||||
@ -40,7 +40,15 @@
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
|
||||
|
||||
#if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
|
||||
#define IS_CUSPARSE11_AVAILABLE() 1
|
||||
#else
|
||||
#define IS_CUSPARSE11_AVAILABLE() 0
|
||||
#endif
|
||||
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
#include <library_types.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
@ -95,9 +103,17 @@ struct csrMatrixRef {
|
||||
int nnz_{0};
|
||||
std::vector<int> size_{};
|
||||
|
||||
cusparseSpMatDescr_t description_{0};
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
cusparseSpMatDescr_t description_{0};
|
||||
#else
|
||||
cusparseMatDescr_t description_{0};
|
||||
#endif
|
||||
|
||||
csrMatrixRef() = default;
|
||||
csrMatrixRef() {
|
||||
#if !IS_CUSPARSE11_AVAILABLE()
|
||||
create_general_description_(description_);
|
||||
#endif
|
||||
}
|
||||
|
||||
csrMatrixRef(
|
||||
int* csr_indices,
|
||||
@ -110,6 +126,7 @@ struct csrMatrixRef {
|
||||
csr_values_{csr_values},
|
||||
nnz_{nnz},
|
||||
size_{size} {
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
|
||||
&description_,
|
||||
@ -123,10 +140,17 @@ struct csrMatrixRef {
|
||||
CUSPARSE_INDEX_32I,
|
||||
CUSPARSE_INDEX_BASE_ZERO,
|
||||
cuda_data_type));
|
||||
#else
|
||||
create_general_description_(description_);
|
||||
#endif
|
||||
}
|
||||
|
||||
~csrMatrixRef() {
|
||||
cusparseDestroySpMat(description_);
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
cusparseDestroySpMat(description_);
|
||||
#else
|
||||
cusparseDestroyMatDescr(description_);
|
||||
#endif
|
||||
}
|
||||
|
||||
int size(int index) const {
|
||||
@ -172,6 +196,8 @@ struct csrOutput {
|
||||
}
|
||||
};
|
||||
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
|
||||
// RAII guard helps to support cuSparse 11 API for `A @ B` operation
|
||||
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
|
||||
template <class scalar_t>
|
||||
@ -370,6 +396,284 @@ template struct CusparseMatrixMultiplyOp<float>;
|
||||
|
||||
template struct CusparseMatrixMultiplyOp<double>;
|
||||
|
||||
#else // if not IS_CUSPARSE11_AVAILABLE()
|
||||
|
||||
using DcsrMatrixRef = csrMatrixRef<double>;
|
||||
using ScsrMatrixRef = csrMatrixRef<float>;
|
||||
|
||||
// RAII guard helps to support cuSparse 10 API for `A @ B` operation
|
||||
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
|
||||
template <class scalar_t>
|
||||
struct CusparseMatrixMultiplyOp {
|
||||
csrOutput operator()(
|
||||
const csrMatrixRef<scalar_t>& lhs,
|
||||
const csrMatrixRef<scalar_t>& rhs,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices)
|
||||
{
|
||||
static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double.");
|
||||
}
|
||||
};
|
||||
|
||||
// Specializacion for `A @ B` operation for double values with cuSparse
|
||||
template<> struct CusparseMatrixMultiplyOp<double> {
|
||||
csrgemm2Info_t gemm2Info_;
|
||||
|
||||
CusparseMatrixMultiplyOp() {
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
|
||||
}
|
||||
~CusparseMatrixMultiplyOp() {
|
||||
cusparseDestroyCsrgemm2Info(gemm2Info_);
|
||||
}
|
||||
|
||||
csrOutput operator ()(
|
||||
const DcsrMatrixRef& lhs,
|
||||
const DcsrMatrixRef& rhs,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
double alpha = 1.0;
|
||||
DcsrMatrixRef empty;
|
||||
return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
|
||||
}
|
||||
|
||||
csrOutput Dgemm2(
|
||||
const DcsrMatrixRef& A,
|
||||
const DcsrMatrixRef& B,
|
||||
const DcsrMatrixRef& C,
|
||||
const double* alpha,
|
||||
const double* beta,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
void* buffer_{nullptr};
|
||||
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
|
||||
|
||||
csrOutput out({A.size(0), B.size(1)});
|
||||
int innerSize = confirm_mult_size(A.size_, B.size_);
|
||||
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
|
||||
|
||||
// Compute needed buffer size
|
||||
size_t new_bubber_sz;
|
||||
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
gemm2Info_,
|
||||
&new_bubber_sz));
|
||||
|
||||
// (Re)allocate buffer if needed
|
||||
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
||||
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
|
||||
buffer_ = data_ptr.get();
|
||||
|
||||
// Find the resulting non-zero pattern.
|
||||
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
&out.nnz_,
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
|
||||
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
|
||||
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
|
||||
|
||||
// Perform the gemm2 operation for doubles
|
||||
// out = alpha ∗ A ∗ B + beta ∗ C
|
||||
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_values_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_values_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_values_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_values_.data_ptr<double>(),
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
out.csr_indices_.data_ptr<int>(),
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// Specializacion for `A @ B` operation for float values with cuSparse
|
||||
template<> struct CusparseMatrixMultiplyOp<float> {
|
||||
csrgemm2Info_t gemm2Info_;
|
||||
|
||||
CusparseMatrixMultiplyOp() {
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
|
||||
|
||||
}
|
||||
~CusparseMatrixMultiplyOp() {
|
||||
cusparseDestroyCsrgemm2Info(gemm2Info_);
|
||||
}
|
||||
csrOutput operator()(
|
||||
const ScsrMatrixRef& lhs,
|
||||
const ScsrMatrixRef& rhs,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
float alpha = 1.0;
|
||||
ScsrMatrixRef empty;
|
||||
return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
|
||||
}
|
||||
|
||||
csrOutput Sgemm2(
|
||||
const ScsrMatrixRef& A,
|
||||
const ScsrMatrixRef& B,
|
||||
const ScsrMatrixRef& C,
|
||||
const float* alpha,
|
||||
const float* beta,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
void* buffer_{nullptr};
|
||||
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
|
||||
|
||||
csrOutput out({A.size(0), B.size(1)});
|
||||
|
||||
int innerSize = confirm_mult_size(A.size_, B.size_);
|
||||
|
||||
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
|
||||
|
||||
// Compute needed buffer size
|
||||
size_t new_bubber_sz;
|
||||
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
gemm2Info_,
|
||||
&new_bubber_sz));
|
||||
|
||||
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
||||
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
|
||||
buffer_ = data_ptr.get();
|
||||
|
||||
// Find the resulting non-zero pattern.
|
||||
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
&out.nnz_,
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
|
||||
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
|
||||
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
|
||||
|
||||
// Perform the gemm2 operation for doubles
|
||||
// out = alpha ∗ A ∗ B + beta ∗ C
|
||||
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_values_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_values_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_values_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_values_.data_ptr<float>(),
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
out.csr_indices_.data_ptr<int>(),
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
#endif // IS_CUSPARSE11_AVAILABLE()
|
||||
|
||||
template <typename scalar_t>
|
||||
void sparse_sparse_matmul_cuda_kernel(
|
||||
Tensor& result,
|
||||
@ -511,15 +815,19 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
|
||||
auto output = at::native::empty_like(mat1_);
|
||||
output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
#if IS_CUSPARSE11_AVAILABLE() && !defined(USE_ROCM)
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
|
||||
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
|
||||
});
|
||||
#else
|
||||
#elif IS_CUSPARSE11_AVAILABLE() && defined(USE_ROCM)
|
||||
// ROCm does not support half and bfloat16 types for sparse_matmul
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
|
||||
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
|
||||
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
|
||||
});
|
||||
#endif
|
||||
return output;
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ using namespace mps;
|
||||
#ifndef PYTORCH_JIT_COMPILE_SHADERS
|
||||
static auto& lib = MetalShaderLibrary::getBundledLibrary();
|
||||
#else
|
||||
#include <ATen/native/mps/SparseTensorMath_metallib.h>
|
||||
#include <ATen/native/mps/Mul_metallib.h>
|
||||
#endif
|
||||
|
||||
static Tensor& s_addmm_out_sparse_dense_mps(
|
||||
@ -369,7 +369,12 @@ static SparseTensor& mul_out_dense_sparse_mps(
|
||||
}
|
||||
|
||||
if (scalar_like) {
|
||||
auto out_vals = values.mul(dense.to(values.options()));
|
||||
auto scalar = dense;
|
||||
if (dense.numel() == 1 && dense.dim() > 0) {
|
||||
scalar = dense.view({});
|
||||
}
|
||||
scalar = scalar.to(values.options());
|
||||
auto out_vals = values.mul(scalar);
|
||||
if (out.scalar_type() != commonDtype) {
|
||||
out_vals = out_vals.to(out.scalar_type());
|
||||
}
|
||||
@ -503,14 +508,14 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
|
||||
const auto device = r_.device();
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
auto lhs_indices = lhs._indices().contiguous();
|
||||
auto rhs_indices = rhs._indices().contiguous();
|
||||
auto lhs_values = lhs._values().to(commonDtype).contiguous();
|
||||
auto rhs_values = rhs._values().to(commonDtype).contiguous();
|
||||
auto lhs_indices = lhs._indices();
|
||||
auto rhs_indices = rhs._indices();
|
||||
auto lhs_values = lhs._values().to(commonDtype);
|
||||
auto rhs_values = rhs._values().to(commonDtype);
|
||||
|
||||
// Flatten sparse indices to keys
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i));
|
||||
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i));
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes());
|
||||
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes());
|
||||
|
||||
// Intersect sorted keys (search the shorter in the longer)
|
||||
const bool A_is_lhs = (lhs_nnz <= rhs_nnz);
|
||||
@ -541,54 +546,35 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
|
||||
auto out_indices = at::empty({ndim_i, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong));
|
||||
auto lhs_match = outA_idx.narrow(0, 0, M);
|
||||
auto rhs_match = outB_idx.narrow(0, 0, M);
|
||||
auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec();
|
||||
int64_t cols64 = 1;
|
||||
for (auto s : dense_sizes_vec) cols64 *= s;
|
||||
const uint32_t cols = static_cast<uint32_t>(std::max<int64_t>(cols64, 1));
|
||||
|
||||
auto to2d = [&](Tensor t, int64_t nnz) -> Tensor {
|
||||
const int64_t t_cols = t.numel() / nnz;
|
||||
if (t_cols == cols64) {
|
||||
return t.view({nnz, cols64});
|
||||
}
|
||||
return t.view({nnz, 1}).expand({nnz, cols64}).contiguous();
|
||||
};
|
||||
|
||||
// make both sides 2d [nnz, cols] buffers so the kernel can index it
|
||||
auto lhs_vals2d = to2d(lhs_values, lhs_nnz);
|
||||
auto rhs_vals2d = to2d(rhs_values, rhs_nnz);
|
||||
|
||||
std::vector<int64_t> out_val_sizes;
|
||||
out_val_sizes.reserve(1 + dense_sizes_vec.size());
|
||||
out_val_sizes.push_back(static_cast<int64_t>(M));
|
||||
out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end());
|
||||
auto out_val_sizes = lhs_values.sizes().vec();
|
||||
out_val_sizes[0] = static_cast<int64_t>(M);
|
||||
auto out_values = at::empty(out_val_sizes, lhs_values.options());
|
||||
|
||||
if (M > 0) {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc(
|
||||
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
const uint32_t cols = static_cast<uint32_t>(
|
||||
lhs_values.numel() / std::max<int64_t>(1, lhs_nnz));
|
||||
|
||||
const uint32_t tew = pso.threadExecutionWidth;
|
||||
const uint32_t gridW = std::max<uint32_t>(cols, 1u);
|
||||
const uint32_t tgW = std::min(gridW, tew);
|
||||
MTLSize grid = MTLSizeMake(gridW, 1, M);
|
||||
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc(
|
||||
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
|
||||
mtl_setArgs(enc,
|
||||
lhs_vals2d, rhs_vals2d,
|
||||
lhs_match, rhs_match,
|
||||
lhs_indices, out_indices,
|
||||
out_values,
|
||||
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
|
||||
std::array<uint32_t, 2>{M, cols});
|
||||
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
|
||||
}
|
||||
});
|
||||
}
|
||||
const uint32_t tew = pso.threadExecutionWidth;
|
||||
uint32_t tgW = std::min(cols, tew);
|
||||
MTLSize grid = MTLSizeMake(cols, 1, M);
|
||||
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
|
||||
|
||||
mtl_setArgs(enc,
|
||||
lhs_values, rhs_values,
|
||||
lhs_match, rhs_match,
|
||||
lhs_indices, out_indices,
|
||||
out_values,
|
||||
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
|
||||
std::array<uint32_t, 2>{M, cols});
|
||||
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
|
||||
}
|
||||
});
|
||||
|
||||
if (r_.scalar_type() != commonDtype) {
|
||||
out_values = out_values.to(r_.scalar_type());
|
||||
|
||||
@ -62,6 +62,7 @@ kernel void build_row_ptr_from_sorted_rows_by_batch(
|
||||
|
||||
template <typename T>
|
||||
kernel void spmm_bmm_coo_rows_grouped(
|
||||
device const long* rows [[buffer(0)]],
|
||||
device const long* cols [[buffer(1)]],
|
||||
device const T* vals [[buffer(2)]],
|
||||
device const T* dense [[buffer(3)]],
|
||||
@ -72,6 +73,7 @@ kernel void spmm_bmm_coo_rows_grouped(
|
||||
uint3 ltid [[thread_position_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]])
|
||||
{
|
||||
const uint B = dims.x;
|
||||
const uint I = dims.y;
|
||||
const uint J = dims.z;
|
||||
const uint K = dims.w;
|
||||
@ -195,9 +197,9 @@ kernel void fused_gather_mul_kernel(
|
||||
const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
|
||||
const ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
|
||||
|
||||
const auto a = static_cast<accum_t<T>>(lhs_vals[offL]);
|
||||
const auto b = static_cast<accum_t<T>>(rhs_vals[offR]);
|
||||
out_vals[offO] = static_cast<T>(mul(a, b));
|
||||
const float a = (float)lhs_vals[offL];
|
||||
const float b = (float)rhs_vals[offR];
|
||||
out_vals[offO] = (T)(a * b);
|
||||
}
|
||||
|
||||
// One thread per match copies the indices column
|
||||
@ -319,6 +321,7 @@ INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);
|
||||
#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \
|
||||
template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void \
|
||||
spmm_bmm_coo_rows_grouped<DTYPE>( \
|
||||
device const long* rows [[buffer(0)]], \
|
||||
device const long* cols [[buffer(1)]], \
|
||||
device const DTYPE* vals [[buffer(2)]], \
|
||||
device const DTYPE* dense [[buffer(3)]], \
|
||||
@ -76,21 +76,14 @@ bool priority_order_init_ = false;
|
||||
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
|
||||
// Needs to be kept in-sync with test_fused_chocie in test_transformers.py
|
||||
bool check_prefer_cudnn_attention() {
|
||||
static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_DEPRIORITIZED") != true;
|
||||
static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") != false;
|
||||
if (!prefer_cudnn) {
|
||||
return false;
|
||||
}
|
||||
#if (defined(CUDNN_VERSION) && (CUDNN_VERSION >= 90900))
|
||||
try {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
auto major = dprops->major;
|
||||
return (major == 9 || major == 10) && !dprops->minor;
|
||||
} catch (c10::Error const& e) {
|
||||
#ifdef DEBUG
|
||||
TORCH_WARN("check_prefer_cudnn_attention() caught exception ", e.what());
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
auto major = dprops->major;
|
||||
return (major == 9 || major == 10) && !dprops->minor;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
|
||||
@ -37,10 +37,6 @@ TEST(SingletonOrSharedTypePtr, Comparison) {
|
||||
|
||||
EXPECT_NE(empty, p);
|
||||
EXPECT_NE(p, p2);
|
||||
|
||||
EXPECT_EQ(empty, empty);
|
||||
EXPECT_EQ(p, p);
|
||||
EXPECT_EQ(p2, p2);
|
||||
}
|
||||
|
||||
TEST(SingletonOrSharedTypePtr, SingletonComparison) {
|
||||
@ -51,8 +47,6 @@ TEST(SingletonOrSharedTypePtr, SingletonComparison) {
|
||||
c10::TypePtr type = c10::NoneType::get();
|
||||
EXPECT_NE(type, c10::StringType::get());
|
||||
EXPECT_NE(type, c10::DeviceObjType::get());
|
||||
EXPECT_EQ(type, type);
|
||||
EXPECT_EQ(type, c10::NoneType::get());
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -526,41 +526,6 @@ namespace {
|
||||
[](const vec& v) { return v.expm1(); },
|
||||
createDefaultUnaryTestCase<vec>(TestSeed(), false, true));
|
||||
}
|
||||
TYPED_TEST(Exponents, ExpU20) {
|
||||
using vec = TypeParam;
|
||||
using VT = ValueType<TypeParam>;
|
||||
using UVT = UvalueType<TypeParam>;
|
||||
|
||||
// Explicit edge values
|
||||
VT v_too_small = VT(-100.0); // much less than -87.3
|
||||
VT exp_too_small = std::exp(v_too_small);
|
||||
VT v_neg_edge = VT(-0x1.5d5e2ap+6f); // just at the edge
|
||||
VT exp_neg_edge = std::exp(v_neg_edge);
|
||||
VT v_zero = VT(0.0); // middle, normal case
|
||||
VT exp_zero = std::exp(v_zero);
|
||||
VT v_pos_edge = VT(0x1.5d5e2ap+6f); // just at the edge
|
||||
VT exp_pos_edge = std::exp(v_pos_edge);
|
||||
VT v_too_large = VT(100.0); // much more than 87.3
|
||||
VT exp_too_large = std::exp(v_too_large);
|
||||
|
||||
auto test_case = TestingCase<vec>::getBuilder()
|
||||
// Randoms in normal range, but the .addCustom() below guarantees we hit the special/fallback cases
|
||||
.addDomain(CheckWithinDomains<UVT>{{{-100, 100}}, false, getDefaultTolerance<UVT>()})
|
||||
.addCustom({ {v_too_small}, exp_too_small })
|
||||
.addCustom({ {v_neg_edge}, exp_neg_edge })
|
||||
.addCustom({ {v_zero}, exp_zero })
|
||||
.addCustom({ {v_pos_edge}, exp_pos_edge })
|
||||
.addCustom({ {v_too_large}, exp_too_large })
|
||||
.setTrialCount(65536)
|
||||
.setTestSeed(TestSeed());
|
||||
|
||||
test_unary<vec>(
|
||||
NAME_INFO(exp_u20_edge_cases),
|
||||
RESOLVE_OVERLOAD(std::exp),
|
||||
[](const vec& v) { return v.exp_u20(); },
|
||||
test_case
|
||||
);
|
||||
}
|
||||
TYPED_TEST(ErrorFunctions, Erf) {
|
||||
using vec = TypeParam;
|
||||
test_unary<vec>(
|
||||
|
||||
@ -58,7 +58,8 @@ def list_benchmarks():
|
||||
|
||||
def run_benchmark(
|
||||
benchmark_name: str,
|
||||
script_args,
|
||||
should_visualize: bool = False,
|
||||
compile_mode: str = "max-autotune-no-cudagraphs",
|
||||
):
|
||||
"""Run a specific benchmark."""
|
||||
if benchmark_name not in BENCHMARK_REGISTRY:
|
||||
@ -67,29 +68,29 @@ def run_benchmark(
|
||||
return False
|
||||
|
||||
print(f"Running benchmark: {benchmark_name}")
|
||||
print(f"Torch compile mode: {script_args.compile_mode}")
|
||||
print(f"Torch compile mode: {compile_mode}")
|
||||
print("=" * 60)
|
||||
|
||||
benchmark_class = BENCHMARK_REGISTRY[benchmark_name]
|
||||
benchmark = benchmark_class(script_args)
|
||||
benchmark = benchmark_class(compile_mode)
|
||||
benchmark.benchmark()
|
||||
if script_args.visualize:
|
||||
if should_visualize:
|
||||
benchmark.visualize()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def run_all_benchmarks(script_args):
|
||||
def run_all_benchmarks(should_visualize: bool = False, compile_mode: str = "default"):
|
||||
"""Run all available benchmarks."""
|
||||
print("Running all benchmarks...")
|
||||
print(f"Torch compile mode: {script_args.compile_mode}")
|
||||
print(f"Torch compile mode: {compile_mode}")
|
||||
print("=" * 60)
|
||||
|
||||
for name, cls in BENCHMARK_REGISTRY.items():
|
||||
print(f"\n{'=' * 20} {name.upper()} {'=' * 20}")
|
||||
benchmark = cls(script_args)
|
||||
benchmark = cls(compile_mode)
|
||||
benchmark.benchmark()
|
||||
if script_args.visualize:
|
||||
if should_visualize:
|
||||
benchmark.visualize()
|
||||
print()
|
||||
|
||||
@ -136,19 +137,6 @@ Examples:
|
||||
help="Torch compile mode to use (default: default)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tolerance",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Tolerance for the accuracy check",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exit-on-accuracy-failure",
|
||||
action="store_true",
|
||||
help="Whether to exit with an error message for accuracy failure",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle list option
|
||||
@ -158,7 +146,7 @@ Examples:
|
||||
|
||||
# Handle all option
|
||||
if args.all:
|
||||
run_all_benchmarks(args)
|
||||
run_all_benchmarks(args.visualize, args.compile_mode)
|
||||
return
|
||||
|
||||
# Handle specific benchmarks
|
||||
@ -169,7 +157,7 @@ Examples:
|
||||
sys.exit(1)
|
||||
|
||||
for benchmark_name in args.benchmarks:
|
||||
run_benchmark(benchmark_name, args)
|
||||
run_benchmark(benchmark_name, args.visualize, args.compile_mode)
|
||||
print() # Add spacing between benchmarks
|
||||
|
||||
|
||||
|
||||
@ -9,8 +9,8 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class CrossEntropyForward(BenchmarkKernel):
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -106,8 +106,8 @@ class CrossEntropyForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class CrossEntropyBackward(BenchmarkKernel):
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -194,8 +194,8 @@ class CrossEntropyBackward(BenchmarkKernel):
|
||||
|
||||
|
||||
class SoftmaxForward(BenchmarkKernel):
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -259,8 +259,8 @@ class SoftmaxForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class SoftmaxBackward(BenchmarkKernel):
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -329,8 +329,8 @@ class SoftmaxBackward(BenchmarkKernel):
|
||||
|
||||
|
||||
class RMSNormForward(BenchmarkKernel):
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -383,22 +383,7 @@ class RMSNormForward(BenchmarkKernel):
|
||||
from quack.rmsnorm import _rmsnorm_fwd
|
||||
|
||||
x, w = args
|
||||
y = torch.empty_like(x)
|
||||
|
||||
def quack_fwd():
|
||||
_rmsnorm_fwd(
|
||||
x,
|
||||
w,
|
||||
out=y,
|
||||
bias=None,
|
||||
rstd=None,
|
||||
residual=None,
|
||||
residual_out=None,
|
||||
eps=1e-6,
|
||||
)
|
||||
return y
|
||||
|
||||
return quack_fwd
|
||||
return lambda: _rmsnorm_fwd(x, w, eps=1e-6)
|
||||
|
||||
def liger(self, args, kwargs) -> Any:
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
@ -419,14 +404,9 @@ class RMSNormForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class RMSNormBackward(BenchmarkKernel):
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = [
|
||||
"eager",
|
||||
"compiled",
|
||||
"quack",
|
||||
"liger",
|
||||
]
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
# TODO: OOM for (32768, 65536) on h100
|
||||
@ -474,11 +454,8 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
y, [x, w], grad_outputs=dy, retain_graph=True
|
||||
)
|
||||
|
||||
def compute_rstd(self, x, eps):
|
||||
return torch.rsqrt(torch.mean(x.float().square(), dim=-1, keepdim=True) + eps)
|
||||
|
||||
def quack(self, args, kwargs=None) -> Any:
|
||||
from quack.rmsnorm import _get_sm_count, _rmsnorm_bwd
|
||||
from quack.rmsnorm import _rmsnorm_backward
|
||||
|
||||
(
|
||||
x,
|
||||
@ -486,40 +463,15 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
dy,
|
||||
) = args
|
||||
M, N = x.shape
|
||||
|
||||
rstd = self.compute_rstd(x, eps=1e-6)
|
||||
dx = torch.empty_like(x)
|
||||
sm_count = _get_sm_count(x.size(1), x.device)
|
||||
dw_partial = torch.empty(
|
||||
sm_count, x.size(1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
def quack_bwd():
|
||||
_rmsnorm_bwd(
|
||||
x,
|
||||
w,
|
||||
dy,
|
||||
rstd,
|
||||
dx,
|
||||
dw_partial,
|
||||
db_partial=None,
|
||||
dresidual_out=None,
|
||||
dresidual=None,
|
||||
sm_count=sm_count,
|
||||
)
|
||||
dw = dw_partial.sum(dim=0).to(w.dtype)
|
||||
return dx, dw
|
||||
|
||||
return quack_bwd
|
||||
rstd = torch.randn(M, device="cuda", dtype=torch.float32)
|
||||
return lambda: _rmsnorm_backward(x, w, dy, rstd)
|
||||
|
||||
def liger(self, args, kwargs=None) -> Any:
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
|
||||
x, w, dy = args
|
||||
M, N = x.shape
|
||||
liger_rmsnorm = LigerRMSNorm(
|
||||
hidden_size=N, eps=1e-6, casting_mode="gemma"
|
||||
).cuda()
|
||||
liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda()
|
||||
liger_rmsnorm.weight.data.copy_(w)
|
||||
y = liger_rmsnorm(x)
|
||||
return lambda: torch.autograd.grad(
|
||||
@ -537,8 +489,8 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
|
||||
|
||||
class LayerNormForward(BenchmarkKernel):
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -611,8 +563,8 @@ class LayerNormForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class LayerNormBackward(BenchmarkKernel):
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -662,31 +614,20 @@ class LayerNormBackward(BenchmarkKernel):
|
||||
y, [x, w], grad_outputs=dy, retain_graph=True
|
||||
)
|
||||
|
||||
def compute_mean_rstd(self, x, eps):
|
||||
x = x.float()
|
||||
|
||||
var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
|
||||
rstd = torch.rsqrt(var + eps)
|
||||
return mean, rstd
|
||||
|
||||
def liger(self, args, kwargs) -> Any:
|
||||
"""
|
||||
Call layer_norm_backward directly rather than calling
|
||||
liger_kernel.transformers.layer_norm.LigerLayerNorm and
|
||||
torch.autograd.grad.
|
||||
|
||||
The latter fashion saves mean/rstd in x.dtype which can fail
|
||||
accuracy test. We call layer_norm_backward with fp32 mean and
|
||||
rstd.
|
||||
"""
|
||||
from liger_kernel.ops.layer_norm import layer_norm_backward
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
|
||||
x, w, dy = args
|
||||
eps = 1e-6
|
||||
mean, rstd = self.compute_mean_rstd(x, eps)
|
||||
M, N = x.shape
|
||||
|
||||
return lambda: layer_norm_backward(dy, x, w, None, mean, rstd)[0:2]
|
||||
liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda()
|
||||
liger_layernorm.weight.data.copy_(w)
|
||||
liger_layernorm.bias.data.copy_(
|
||||
torch.zeros(N, device="cuda", dtype=torch.float32)
|
||||
)
|
||||
y = liger_layernorm(x)
|
||||
return lambda: torch.autograd.grad(
|
||||
y, [x, liger_layernorm.weight], grad_outputs=dy, retain_graph=True
|
||||
)
|
||||
|
||||
def benchmark(self):
|
||||
for M, N in self.get_shapes():
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
@ -44,11 +43,10 @@ class Performance:
|
||||
|
||||
|
||||
class BenchmarkKernel:
|
||||
def __init__(self, script_args):
|
||||
self.script_args = script_args
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
self.name = self.__class__.__name__
|
||||
self.available_backends: list[str] = []
|
||||
self.compile_mode: str = script_args.compile_mode
|
||||
self.compile_mode: str = compile_mode
|
||||
|
||||
# mapping from backend to list of performance results
|
||||
self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list)
|
||||
@ -108,21 +106,14 @@ class BenchmarkKernel:
|
||||
args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
|
||||
res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
|
||||
gold = res["eager"]
|
||||
|
||||
tol = {}
|
||||
if self.script_args.tolerance:
|
||||
tol = {
|
||||
"atol": self.script_args.tolerance,
|
||||
"rtol": self.script_args.tolerance,
|
||||
}
|
||||
for backend in self.available_backends:
|
||||
if backend == "eager":
|
||||
continue
|
||||
try:
|
||||
torch.testing.assert_close(res[backend], gold, **tol)
|
||||
torch.testing.assert_close(res[backend], gold)
|
||||
for t, gold_t in zip(res[backend], gold):
|
||||
if t.requires_grad:
|
||||
torch.testing.assert_close(t.grad, gold_t.grad, **tol)
|
||||
torch.testing.assert_close(t.grad, gold_t.grad)
|
||||
print(
|
||||
f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel"
|
||||
)
|
||||
@ -130,9 +121,6 @@ class BenchmarkKernel:
|
||||
print(
|
||||
f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
|
||||
)
|
||||
if self.script_args.exit_on_accuracy_failure:
|
||||
print("Exit right away since --exit-on-accuracy-failure is set")
|
||||
sys.exit(1)
|
||||
|
||||
def benchmark_single_shape(
|
||||
self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
add_loop_eager,compile_time_instruction_count,3184000000,0.1
|
||||
add_loop_eager,compile_time_instruction_count,3070000000,0.1
|
||||
|
||||
|
||||
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
|
||||
|
||||
|
||||
|
||||
update_hint_regression,compile_time_instruction_count,1645000000,0.1
|
||||
update_hint_regression,compile_time_instruction_count,1719000000,0.1
|
||||
|
||||
|
||||
|
||||
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
|
||||
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
|
||||
|
||||
|
||||
|
||||
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
|
||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
|
||||
|
||||
|
||||
|
||||
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
|
||||
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
|
||||
|
||||
|
||||
|
||||
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
|
||||
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
|
||||
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
|
||||
|
||||
|
@ -43,7 +43,6 @@ tolerance:
|
||||
- doctr_reco_predictor
|
||||
- drq
|
||||
- phlippe_resnet
|
||||
- pytorch_CycleGAN_and_pix2pix
|
||||
|
||||
higher_bf16:
|
||||
- doctr_reco_predictor
|
||||
|
||||
@ -44,101 +44,21 @@ PyTorch,div_,div__M1_N1_K1_cpu_dtype_onetorch.float32_dtype_twotorch.float32,sho
|
||||
PyTorch,div_,div__M64_N64_K64_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.241161,0.000000
|
||||
PyTorch,div_,div__M64_N64_K128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.852816,0.000000
|
||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,57.006677,0.000000
|
||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,88.167000,0.000000
|
||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.519000,0.000000
|
||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,55.606088,0.000000
|
||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,86.551000,0.000000
|
||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.864088,0.000000
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,71.641000,0.000000
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,83.073000,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,67.570000,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.895000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.0240000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
|
||||
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
|
||||
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
|
||||
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
|
||||
@ -151,9 +71,6 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
|
||||
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
|
||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
|
||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
|
||||
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
|
||||
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
|
||||
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
|
||||
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
|
||||
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
|
||||
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000
|
||||
|
||||
|
@ -580,9 +580,6 @@ class BenchmarkRunner:
|
||||
else "unknown"
|
||||
)
|
||||
|
||||
# Extract operator name from test_name
|
||||
operator_name = test_name.split("_")[0]
|
||||
|
||||
# Create the record
|
||||
@dataclass
|
||||
class BenchmarkInfo:
|
||||
@ -596,7 +593,6 @@ class BenchmarkRunner:
|
||||
name: str
|
||||
type: str
|
||||
origins: list[str]
|
||||
extra_info: dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class MetricInfo:
|
||||
@ -622,14 +618,10 @@ class BenchmarkRunner:
|
||||
"device": device,
|
||||
"arch": device_arch,
|
||||
"use_compile": use_compile,
|
||||
"operator_name": operator_name,
|
||||
},
|
||||
),
|
||||
model=ModelInfo(
|
||||
name=test_name,
|
||||
type="micro-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={"operator_name": operator_name},
|
||||
name=test_name, type="micro-benchmark", origins=["pytorch"]
|
||||
),
|
||||
metric=MetricInfo(
|
||||
name="latency",
|
||||
|
||||
@ -25,7 +25,7 @@ binary_configs_broadcast = op_bench.config_list(
|
||||
],
|
||||
cross_product_configs={
|
||||
"device": ["cpu"],
|
||||
"dtype": [torch.float, torch.bfloat16, torch.float64],
|
||||
"dtype": [torch.float],
|
||||
},
|
||||
tags=["short"],
|
||||
)
|
||||
@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
|
||||
],
|
||||
cross_product_configs={
|
||||
"device": ["cpu", "cuda"],
|
||||
"dtype_one": [torch.int32, torch.uint8],
|
||||
"dtype_two": [torch.int32, torch.uint8],
|
||||
"dtype_one": [torch.int32],
|
||||
"dtype_two": [torch.int32],
|
||||
},
|
||||
tags=["short"],
|
||||
)
|
||||
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
|
||||
N=[32, 64],
|
||||
K=[256, 512],
|
||||
device=["cpu", "cuda"],
|
||||
dtype_one=[torch.int8, torch.int32, torch.uint8],
|
||||
dtype_two=[torch.int8, torch.int32, torch.uint8],
|
||||
dtype_one=[torch.int8, torch.int32],
|
||||
dtype_two=[torch.int8, torch.int32],
|
||||
tags=["long"],
|
||||
)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -176,8 +176,8 @@ THIRD_PARTY_LIBS = {
|
||||
"omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
|
||||
"pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
|
||||
"psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
|
||||
"pthreadpool": ["fbsource//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
|
||||
"pthreadpool_header": ["fbsource//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
|
||||
"pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
|
||||
"pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
|
||||
"moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
|
||||
"pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
|
||||
"rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
|
||||
@ -1729,10 +1729,8 @@ def define_buck_targets(
|
||||
"torch/csrc/jit/backends/backend_debug_info.cpp",
|
||||
"torch/csrc/jit/backends/backend_interface.cpp",
|
||||
],
|
||||
compiler_flags = get_pt_compiler_flags() + select({
|
||||
"DEFAULT": [],
|
||||
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
|
||||
}),
|
||||
compiler_flags = get_pt_compiler_flags(),
|
||||
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
|
||||
# @lint-ignore BUCKLINT link_whole
|
||||
link_whole = True,
|
||||
linker_flags = get_no_as_needed_linker_flag(),
|
||||
@ -2025,9 +2023,6 @@ def define_buck_targets(
|
||||
"ovr_config//os:android-x86_64": [
|
||||
"-mssse3",
|
||||
],
|
||||
}) + select({
|
||||
"DEFAULT": [],
|
||||
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
|
||||
}),
|
||||
exported_preprocessor_flags = get_aten_preprocessor_flags(),
|
||||
exported_deps = [
|
||||
|
||||
@ -855,7 +855,6 @@ libtorch_python_cuda_core_sources = [
|
||||
"torch/csrc/cuda/Stream.cpp",
|
||||
"torch/csrc/cuda/Graph.cpp",
|
||||
"torch/csrc/cuda/MemPool.cpp",
|
||||
"torch/csrc/cuda/GreenContext.cpp",
|
||||
"torch/csrc/cuda/shared/cudart.cpp",
|
||||
"torch/csrc/cuda/shared/nvtx.cpp",
|
||||
"torch/csrc/cuda/utils.cpp",
|
||||
|
||||
@ -13,17 +13,7 @@
|
||||
namespace c10::CachingAllocator {
|
||||
|
||||
// "large" allocations may be packed in 20 MiB blocks
|
||||
constexpr size_t kLargeBuffer = 20971520;
|
||||
// "small" allocations are packed in 2 MiB blocks
|
||||
constexpr size_t kSmallBuffer = 2097152;
|
||||
// all sizes are rounded to at least 512 bytes
|
||||
constexpr size_t kMinBlockSize = 512;
|
||||
// largest "small" allocation is 1 MiB
|
||||
constexpr size_t kSmallSize = 1048576;
|
||||
// allocations between 1 and 10 MiB may use kLargeBuffer
|
||||
constexpr size_t kMinLargeAlloc = 10485760;
|
||||
// round up large allocations to 2 MiB
|
||||
constexpr size_t kRoundLarge = 2097152;
|
||||
const size_t kLargeBuffer = 20971520;
|
||||
|
||||
// A utility class for tokenizing allocator configuration strings into discrete
|
||||
// parts. For example, the config string:
|
||||
|
||||
@ -223,7 +223,7 @@ inline DispatchKey backendToDispatchKey(Backend b) {
|
||||
case Backend::PrivateUse1:
|
||||
return DispatchKey::PrivateUse1;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unknown backend");
|
||||
throw std::runtime_error("Unknown backend");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -336,7 +336,7 @@ class C10_API Scalar {
|
||||
} else if (isBoolean()) {
|
||||
return ScalarType::Bool;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unknown scalar type.");
|
||||
throw std::runtime_error("Unknown scalar type.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -228,7 +228,7 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
|
||||
case c10::ScalarType::Float4_e2m1fn_x2:
|
||||
return std::make_pair("float4_e2m1fn_x2", "");
|
||||
default:
|
||||
TORCH_CHECK(false, "Unimplemented scalar type");
|
||||
throw std::runtime_error("Unimplemented scalar type");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -52,6 +52,19 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
|
||||
#undef DEFINE_CONSTANT
|
||||
|
||||
inline const char* toString(ScalarType t) {
|
||||
#define DEFINE_CASE(_, name) \
|
||||
case ScalarType::name: \
|
||||
return #name;
|
||||
|
||||
switch (t) {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
|
||||
default:
|
||||
return "UNKNOWN_SCALAR";
|
||||
}
|
||||
#undef DEFINE_CASE
|
||||
}
|
||||
|
||||
inline size_t elementSize(ScalarType t) {
|
||||
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
|
||||
case ScalarType::name: \
|
||||
@ -137,6 +150,22 @@ inline ScalarType toQIntType(ScalarType t) {
|
||||
}
|
||||
}
|
||||
|
||||
inline ScalarType toUnderlying(ScalarType t) {
|
||||
switch (t) {
|
||||
case ScalarType::QUInt8:
|
||||
case ScalarType::QUInt4x2:
|
||||
[[fallthrough]];
|
||||
case ScalarType::QUInt2x4:
|
||||
return ScalarType::Byte;
|
||||
case ScalarType::QInt8:
|
||||
return ScalarType::Char;
|
||||
case ScalarType::QInt32:
|
||||
return ScalarType::Int;
|
||||
default:
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool isSignedType(ScalarType t) {
|
||||
#define CASE_ISSIGNED(name) \
|
||||
case ScalarType::name: \
|
||||
@ -279,6 +308,12 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
|
||||
|
||||
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& stream,
|
||||
at::ScalarType scalar_type) {
|
||||
return stream << toString(scalar_type);
|
||||
}
|
||||
|
||||
// Returns a pair of strings representing the names for each dtype.
|
||||
// The returned pair is (name, legacy_name_if_applicable)
|
||||
C10_API std::pair<std::string, std::string> getDtypeNames(
|
||||
|
||||
@ -87,7 +87,9 @@ bool ThreadPool::inThreadPool() const {
|
||||
}
|
||||
|
||||
void ThreadPool::run(std::function<void()> func) {
|
||||
TORCH_CHECK(threads_.size() > 0, "No threads to run a task");
|
||||
if (threads_.empty()) {
|
||||
throw std::runtime_error("No threads to run a task");
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
|
||||
// Set task and signal condition variable so that a worker thread will
|
||||
|
||||
@ -131,6 +131,15 @@ namespace Native {
|
||||
* notifyCaptureDestroy.
|
||||
*/
|
||||
|
||||
constexpr size_t kMinBlockSize =
|
||||
512; // all sizes are rounded to at least 512 bytes
|
||||
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
|
||||
constexpr size_t kSmallBuffer =
|
||||
2097152; // "small" allocations are packed in 2 MiB blocks
|
||||
constexpr size_t kMinLargeAlloc =
|
||||
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
|
||||
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
|
||||
|
||||
static char SHAREABLE_HANDLE_VERSION = 2;
|
||||
enum ShareableHandleType : char {
|
||||
SHAREABLE_CUDA_MALLOC = 'c',
|
||||
@ -4469,10 +4478,7 @@ struct BackendStaticInitializer {
|
||||
if (key == "backend") {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
i++; // Move to the value after the colon
|
||||
// break up token to trick hipify
|
||||
if (tokenizer[i] ==
|
||||
"c"
|
||||
"udaMallocAsync"
|
||||
if (tokenizer[i] == "cudaMallocAsync"
|
||||
#ifdef USE_ROCM
|
||||
// convenience for ROCm users to allow either CUDA or HIP env var
|
||||
|| tokenizer[i] == "hipMallocAsync"
|
||||
|
||||
@ -913,9 +913,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
|
||||
}
|
||||
}
|
||||
std::string name() override {
|
||||
// break up token to trick hipify
|
||||
return "c"
|
||||
"udaMallocAsync";
|
||||
return "cudaMallocAsync";
|
||||
}
|
||||
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
||||
C10_CUDA_CHECK(
|
||||
|
||||
@ -51,17 +51,6 @@
|
||||
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
||||
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
|
||||
_(cuCtxFromGreenCtx, 12080) \
|
||||
_(cuCtxGetCurrent, 12080) \
|
||||
_(cuCtxPopCurrent, 12080) \
|
||||
_(cuCtxPushCurrent, 12080) \
|
||||
_(cuCtxSetCurrent, 12080) \
|
||||
_(cuGreenCtxCreate, 12080) \
|
||||
_(cuGreenCtxDestroy, 12080) \
|
||||
_(cuDevSmResourceSplitByCount, 12080) \
|
||||
_(cuDeviceGet, 12080) \
|
||||
_(cuDeviceGetDevResource, 12080) \
|
||||
_(cuDevResourceGenerateDesc, 12080) \
|
||||
_(cuMulticastAddDevice, 12030) \
|
||||
_(cuMulticastBindMem, 12030) \
|
||||
_(cuMulticastCreate, 12030) \
|
||||
|
||||
@ -328,21 +328,6 @@ struct pair {
|
||||
T2 second;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static T conj(T a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <>
|
||||
half2 conj(half2 a) {
|
||||
return half2(a.x, -a.y);
|
||||
}
|
||||
|
||||
template <>
|
||||
float2 conj(float2 a) {
|
||||
return float2(a.x, -a.y);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FOR_ALL_TYPES(MACRO) \
|
||||
MACRO(float); \
|
||||
MACRO(half); \
|
||||
|
||||
@ -45,7 +45,14 @@ constexpr bool is_pod_v = is_pod<T>::value;
|
||||
|
||||
namespace guts {
|
||||
|
||||
#if defined(__HIP__)
|
||||
#if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__)
|
||||
|
||||
template <class F, class Tuple>
|
||||
C10_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) {
|
||||
return std::apply(std::forward<F>(f), std::forward<Tuple>(t));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but
|
||||
// modified)
|
||||
|
||||
@ -14,6 +14,16 @@ using namespace c10::CachingDeviceAllocator;
|
||||
|
||||
// newly allocated memory with 512-byte alignment.
|
||||
constexpr size_t kDeviceAlignment = 512;
|
||||
// all sizes are rounded to at least 512 bytes
|
||||
constexpr size_t kMinBlockSize = 512;
|
||||
// largest "small" allocation is 1 MiB
|
||||
constexpr size_t kSmallSize = 1048576;
|
||||
// "small" allocations are packed in 2 MiB blocks
|
||||
constexpr size_t kSmallBuffer = 2097152;
|
||||
// allocations between 1 and 10 MiB may use kLargeBuffer
|
||||
constexpr size_t kMinLargeAlloc = 10485760;
|
||||
// round up large allocations to 2 MiB
|
||||
constexpr size_t kRoundLarge = 2097152;
|
||||
|
||||
namespace {
|
||||
using stream_set = ska::flat_hash_set<xpu::XPUStream>;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user