mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-28 02:04:53 +08:00
Compare commits
124 Commits
ciflow/tru
...
update_sub
| Author | SHA1 | Date | |
|---|---|---|---|
| 0d3f74bb3a | |||
| 90b30ebf7e | |||
| 173bcda436 | |||
| 6530bc70fb | |||
| 4c38887346 | |||
| 81fa4a204c | |||
| 4e6afa8c07 | |||
| 79aa88cc5d | |||
| fa4cb91846 | |||
| c58d0ad85d | |||
| 000f49551b | |||
| 9940e894ea | |||
| 27302a4932 | |||
| 507614ba43 | |||
| 86f9f1d0ab | |||
| 154e4d36e9 | |||
| a2b6afeac5 | |||
| 262830d86c | |||
| e4c01011c2 | |||
| a60d9e1f6d | |||
| f863550192 | |||
| 84b14f3a10 | |||
| 5121499f6b | |||
| 8f80892359 | |||
| cdb60e44eb | |||
| 25909d2629 | |||
| c7eee49525 | |||
| 621ba05107 | |||
| 39a70cead1 | |||
| d97f6550a2 | |||
| 516e58965a | |||
| b55b779ad3 | |||
| 74e53d0761 | |||
| 798a6d2be1 | |||
| b0e9c86971 | |||
| 661a56002f | |||
| c9bc00f016 | |||
| ec51b139e1 | |||
| eb83c3ca23 | |||
| 7924e3aacf | |||
| 78bcfcf870 | |||
| 1e2e7cb18b | |||
| 003601a70d | |||
| 1d58d5fe25 | |||
| de7fdfe41a | |||
| b31bad1b8f | |||
| 2efcf3ca98 | |||
| 761f946043 | |||
| 8aa465f18e | |||
| 0a5d68d92d | |||
| 42bd210fff | |||
| 1d13c314b3 | |||
| 0c9763a5a0 | |||
| 79a4a9c02e | |||
| 9d0b77f4cd | |||
| d486eee234 | |||
| cddd5f74ab | |||
| dfdb68e51f | |||
| 98c818320a | |||
| cc20b7ad72 | |||
| bc11a42b3f | |||
| 4fc06f2e0a | |||
| 82473c3d59 | |||
| b6a4236e5d | |||
| b04173be9b | |||
| 32ac38f85d | |||
| c9b49e506e | |||
| 6038e476e8 | |||
| 2c851c16e5 | |||
| 31584f2d91 | |||
| 0442125362 | |||
| fdcf402d82 | |||
| 13cda9b89e | |||
| fa6d911dda | |||
| 0db6bcc015 | |||
| 60ac039998 | |||
| 380d440d1c | |||
| 9038a30cee | |||
| 690c8c13b9 | |||
| 28ee6b62ed | |||
| 81577bdb3f | |||
| e67e3d95f3 | |||
| 27af8480ea | |||
| 6494cdc40c | |||
| ac7074efa2 | |||
| 263901cec4 | |||
| c12293dcbe | |||
| 5a4997dcae | |||
| 47f638eae7 | |||
| 882b834082 | |||
| b146ea411e | |||
| 8625ffbd45 | |||
| 0977cc4474 | |||
| d9a55faccc | |||
| 75b8295868 | |||
| defb6a80d8 | |||
| f8fccb1e48 | |||
| 5aac4cfce4 | |||
| baf91bbbfc | |||
| cbcb4f7768 | |||
| 2b93d5b450 | |||
| 6b7cd48e7e | |||
| bf5aa9e42e | |||
| b1eb6dede5 | |||
| 673060beae | |||
| 2e8e9a59a8 | |||
| fb277a5916 | |||
| 73fa0d0c63 | |||
| 36c21cc84e | |||
| 0b68814b44 | |||
| e64a814ae7 | |||
| 0b58d87aec | |||
| 757975ad50 | |||
| 291712026b | |||
| 3e77a2b478 | |||
| 82ef1b5db3 | |||
| 5f370f5c42 | |||
| 05b2e02cb4 | |||
| 12f742941d | |||
| 35180fafee | |||
| c746feb86a | |||
| c5f26db5bf | |||
| 18e99b6d45 | |||
| ab9e466928 |
354
.claude/skills/pytorch-docstring.md
Normal file
@ -0,0 +1,354 @@
# PyTorch Docstring Writing Guide

This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.

## General Principles

- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
- Follow **Sphinx/reStructuredText** (reST) format for documentation
- Be **concise but complete** - include all essential information
- Always include **examples** when possible
- Use **cross-references** to related functions/classes
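
Taken together, here is a minimal sketch of what these principles produce (the function name, shapes, and wording are illustrative, not copied from the PyTorch sources); the rest of this guide breaks each element down:

```python
def scaled_add(input: Tensor, other: Tensor, *, alpha: float = 1.0) -> Tensor:
    r"""scaled_add(input, other, *, alpha=1.0) -> Tensor

    Adds ``alpha * other`` to :attr:`input`, i.e. computes
    :math:`\text{out} = \text{input} + \alpha \times \text{other}`.

    See :func:`torch.add` for the closely related built-in operator.

    Args:
        input (Tensor): the input tensor
        other (Tensor): the tensor to add
        alpha (float, optional): multiplier for :attr:`other`. Default: ``1.0``

    Examples::

        >>> scaled_add(torch.ones(2), torch.ones(2), alpha=0.5)
        tensor([1.5000, 1.5000])
    """
    # implementation
    return input + alpha * other
```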

## Docstring Structure

### 1. Function Signature (First Line)

Start with the function signature showing all parameters:

```python
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
```

**Notes:**
- Include the function name
- Show positional and keyword-only arguments (use `*` separator)
- Include default values
- Show return type annotation
- This line should NOT end with a period

### 2. Brief Description

Provide a one-line description of what the function does:

```python
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 2D convolution over an input image composed of several input
planes.
```

### 3. Mathematical Formulas (if applicable)

Use Sphinx math directives for mathematical expressions:

```python
.. math::
    \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
```

Or inline math: `:math:\`x^2\``
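
In context, the directive sits at the docstring's base indentation with the formula indented one extra level, and inline math can appear mid-sentence. A short sketch (the surrounding wording is illustrative, not copied from the PyTorch sources):

```python
r"""softmax(input, dim) -> Tensor

Applies the Softmax function along dimension :attr:`dim`.

.. math::
    \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

Each output lies in :math:`[0, 1]` and the values sum to 1 along :attr:`dim`.
"""
```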

### 4. Cross-References

Link to related classes and functions using Sphinx roles:

- `:class:\`~torch.nn.ModuleName\`` - Link to a class
- `:func:\`torch.function_name\`` - Link to a function
- `:meth:\`~Tensor.method_name\`` - Link to a method
- `:attr:\`attribute_name\`` - Reference an attribute
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)

**Example:**
```python
See :class:`~torch.nn.Conv2d` for details and output shape.
```

### 5. Notes and Warnings

Use admonitions for important information:

```python
.. note::
    This function doesn't work directly with NLLLoss,
    which expects the Log to be computed between the Softmax and itself.
    Use log_softmax instead (it's faster and has better numerical properties).

.. warning::
    :func:`new_tensor` always copies :attr:`data`. If you have a Tensor
    ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
    or :func:`torch.Tensor.detach`.
```

### 6. Args Section

Document all parameters with type annotations and descriptions:

```python
Args:
    input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight (Tensor): filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
    bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
      tuple `(sH, sW)`. Default: 1
```

**Formatting rules:**
- Parameter name in **lowercase**
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
- Description follows the type
- For optional parameters, include "Default: ``value``" at the end
- Use double backticks for inline code: ``` ``None`` ```
- Indent continuation lines by 2 spaces

### 7. Keyword Args Section (if applicable)

Sometimes keyword arguments are documented separately:

```python
Keyword args:
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
        Default: if None, same :class:`torch.dtype` as this tensor.
    device (:class:`torch.device`, optional): the desired device of returned tensor.
        Default: if None, same :class:`torch.device` as this tensor.
    requires_grad (bool, optional): If autograd should record operations on the
        returned tensor. Default: ``False``.
```

### 8. Returns Section (if needed)

Document the return value:

```python
Returns:
    Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
        If ``hard=True``, the returned samples will be one-hot, otherwise they will
        be probability distributions that sum to 1 across `dim`.
```

Or simply include it in the function signature line if obvious from context.

### 9. Examples Section

Always include examples when possible:

```python
Examples::

    >>> inputs = torch.randn(33, 16, 30)
    >>> filters = torch.randn(20, 16, 5)
    >>> F.conv1d(inputs, filters)

    >>> # With square kernels and equal stride
    >>> filters = torch.randn(8, 4, 3, 3)
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> F.conv2d(inputs, filters, padding=1)
```

**Formatting rules:**
- Use `Examples::` with double colon
- Use the `>>>` prompt for Python code
- Include comments with `#` when helpful
- Show actual output when it helps understanding (indent it without `>>>`), as in the sketch below
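
A small sketch of that last rule: the expected output goes on its own line, indented like the code but without the `>>>` prompt (the call and printed value here are illustrative):

```python
Examples::

    >>> torch.tensor([1.0, 2.0, 3.0]).sum()
    tensor(6.)
```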

### 10. External References

Link to papers or external documentation:

```python
.. _Link Name:
    https://arxiv.org/abs/1611.00712
```

Reference them in text: ```See `Link Name`_```
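
Putting both pieces together, a sketch of how the inline reference and its target commonly sit in one docstring, with the target placed at the end after the examples (wording is illustrative):

```python
r"""
Samples from the Gumbel-Softmax distribution; see `Link 1`_ for the derivation.

Examples::

    >>> F.gumbel_softmax(torch.randn(20, 32), tau=1)

.. _Link 1:
    https://arxiv.org/abs/1611.00712
"""
```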

## Method Types

### Native Python Functions

For regular Python functions, use a standard docstring:

```python
def relu(input: Tensor, inplace: bool = False) -> Tensor:
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise. See
    :class:`~torch.nn.ReLU` for more details.
    """
    # implementation
```

### C-Bound Functions (using add_docstr)

For C-bound functions, use `_add_docstr`:

```python
conv1d = _add_docstr(
    torch.conv1d,
    r"""
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 1D convolution over an input signal composed of several input
planes.

See :class:`~torch.nn.Conv1d` for details and output shape.

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)`
    ...
""",
)
```

### In-Place Variants

For in-place operations (ending with `_`), reference the original:

```python
add_docstr_all(
    "abs_",
    r"""
abs_() -> Tensor

In-place version of :meth:`~Tensor.abs`
""",
)
```

### Alias Functions

For aliases, simply reference the original:

```python
add_docstr_all(
    "absolute",
    r"""
absolute() -> Tensor

Alias for :func:`abs`
""",
)
```

## Common Patterns

### Shape Documentation

Use LaTeX math notation for tensor shapes:

```python
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
```

### Reusable Argument Definitions

For commonly used arguments, define them once and reuse:

```python
common_args = parse_kwargs(
    """
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
        Default: if None, same as this tensor.
"""
)

# Then use with .format():
r"""
...

Keyword args:
    {dtype}
    {device}
""".format(**common_args)
```

### Template Insertion

Insert reproducibility notes or other common text:

```python
r"""
{tf32_note}

{cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)
```
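
The note collections are plain string mappings that `str.format` splices into the docstring. A minimal sketch of the mechanism (the dictionary contents below are invented placeholders, not the actual note text shipped in the torch sources):

```python
# Illustrative stand-ins for the real note dictionaries in the torch sources.
tf32_notes = {
    "tf32_note": ".. note:: This operator may use TF32 on supported GPUs.",
}
reproducibility_notes = {
    "cudnn_reproducibility_note": (
        ".. note:: This operator is nondeterministic when cuDNN benchmarking is enabled."
    ),
}

doc = r"""
{tf32_note}

{cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)
```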

## Complete Example

Here's a complete example showing all elements:

```python
def gumbel_softmax(
    logits: Tensor,
    tau: float = 1,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
) -> Tensor:
    r"""
    Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits (Tensor): `[..., num_features]` unnormalized log probabilities
        tau (float): non-negative scalar temperature
        hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
            but will be differentiated as if it is the soft sample in autograd. Default: ``False``
        dim (int): A dimension along which softmax will be computed. Default: -1

    Returns:
        Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
            If ``hard=True``, the returned samples will be one-hot, otherwise they will
            be probability distributions that sum to 1 across `dim`.

    .. note::
        This function is here for legacy reasons, may be removed from nn.Functional in the future.

    Examples::
        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)

    .. _Link 1:
        https://arxiv.org/abs/1611.00712
    """
    # implementation
```

## Quick Checklist

When writing a PyTorch docstring, ensure:

- [ ] Use raw string (`r"""`)
- [ ] Include function signature on first line
- [ ] Provide brief description
- [ ] Document all parameters in Args section with types
- [ ] Include default values for optional parameters
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
- [ ] Add mathematical formulas if applicable
- [ ] Include at least one example in Examples section
- [ ] Add warnings/notes for important caveats
- [ ] Link to related module class with `:class:`
- [ ] Use proper math notation for tensor shapes
- [ ] Follow consistent formatting and indentation

## Common Sphinx Roles Reference

- `:class:\`~torch.nn.Module\`` - Class reference
- `:func:\`torch.function\`` - Function reference
- `:meth:\`~Tensor.method\`` - Method reference
- `:attr:\`attribute\`` - Attribute reference
- `:math:\`equation\`` - Inline math
- `:ref:\`label\`` - Internal reference
- ``` ``code`` ``` - Inline code (use double backticks)

## Additional Notes

- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
- **Line length**: Try to keep lines under 100 characters when possible
- **Periods**: End sentences with periods, but not the signature line
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.
7
.github/actions/setup-rocm/action.yml
vendored
@ -124,3 +124,10 @@ runs:
      id: login-ecr
      continue-on-error: true
      uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1

    - name: Preserve github env variables for use in docker
      shell: bash
      run: |
        env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
        env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
        env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"

2
.github/ci_commit_pins/vision.txt
vendored
@ -1 +1 @@
faffd5cf673615583da6517275e361cb3dbc77e6
1752fe6809b74921644866275ab80244b96e80bc

5
.github/ci_configs/vllm/Dockerfile
vendored
@ -283,6 +283,9 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
fi

RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system --pre apache-tvm-ffi==0.1.0b15

# Install the vllm wheel from previous stage
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system /wheels/vllm/*.whl --verbose
@ -295,6 +298,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}

# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
ARG FLASHINFER_GIT_REF="v0.2.14.post1"

9
.github/label_to_label.yml
vendored
@ -15,6 +15,11 @@
  - "module: reinplacing"
  then:
  - "module: pt2-dispatcher"
- any:
  - "vllm-compile"
  then:
  - "module: vllm"
  - "oncall: pt2"
- any:
  - "module: vmap"
  then:
@ -27,10 +32,6 @@
  - "module: pt2 optimizer"
  then:
  - "module: dynamo"
- any:
  - "module: flex attention"
  then:
  - "module: higher order operators"
- any:
  - "module: aotinductor"
  then:

1
.github/workflows/inductor-periodic.yml
vendored
@ -88,7 +88,6 @@ jobs:
    with:
      build-environment: linux-jammy-rocm-py3_10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },

3
.github/workflows/pull.yml
vendored
@ -347,7 +347,8 @@ jobs:
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      sync-tag: linux-xpu-n-build
      # This should sync with the build in xpu.yml but xpu uses a larger runner
      # sync-tag: linux-xpu-n-build
      runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
      build-environment: linux-jammy-xpu-n-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3

1
.github/workflows/rocm-mi300.yml
vendored
@ -45,7 +45,6 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-noble-rocm-py3.12-mi300
      docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },

1
.github/workflows/rocm-mi355.yml
vendored
@ -42,7 +42,6 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-noble-rocm-py3.12-mi355
      docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },

12
.github/workflows/rocm-navi31.yml
vendored
@ -26,11 +26,23 @@ jobs:
      id-token: write
      contents: read

  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-jammy-rocm-py3_10-build:
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      sync-tag: rocm-build

12
.github/workflows/rocm.yml
vendored
@ -26,11 +26,23 @@ jobs:
      id-token: write
      contents: read

  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-jammy-rocm-py3_10-build:
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      sync-tag: rocm-build

@ -833,8 +833,7 @@ exclude_patterns = [
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/grep_linter.py',
|
||||
'--pattern=cudaSetDevice(',
|
||||
'--pattern=cudaGetDevice(',
|
||||
'--pattern=(cudaSetDevice|cudaGetDevice)\\(',
|
||||
'--linter-name=RAWCUDADEVICE',
|
||||
'--error-name=raw CUDA API usage',
|
||||
"""--error-description=\
|
||||
@ -1138,11 +1137,8 @@ command = [
|
||||
[[linter]]
|
||||
code = 'WORKFLOWSYNC'
|
||||
include_patterns = [
|
||||
'.github/workflows/pull.yml',
|
||||
'.github/workflows/trunk.yml',
|
||||
'.github/workflows/periodic.yml',
|
||||
'.github/workflows/mac-mps.yml',
|
||||
'.github/workflows/slow.yml',
|
||||
'.github/workflows/*.yml',
|
||||
'.github/workflows/*.yaml',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
||||
@ -289,14 +289,15 @@ IF(USE_FBGEMM_GENAI)
|
||||
|
||||
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
set(fbgemm_genai_mx8mx8bf16_grouped
|
||||
set(fbgemm_genai_cuh
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
|
||||
"${FBGEMM_GENAI_SRCS}/"
|
||||
)
|
||||
|
||||
target_include_directories(fbgemm_genai PRIVATE
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||
${fbgemm_genai_mx8mx8bf16_grouped}
|
||||
${fbgemm_genai_cuh}
|
||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||
)
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include <ATen/detail/MPSHooksInterface.h>
|
||||
#include <ATen/detail/MTIAHooksInterface.h>
|
||||
#include <ATen/detail/PrivateUse1HooksInterface.h>
|
||||
#include <ATen/detail/XLAHooksInterface.h>
|
||||
#include <ATen/detail/XPUHooksInterface.h>
|
||||
#include <c10/core/QEngine.h>
|
||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
||||
@ -88,6 +89,8 @@ class TORCH_API Context {
|
||||
return at::detail::getHIPHooks();
|
||||
} else if (opt_device_type == at::kHPU) {
|
||||
return at::detail::getHPUHooks();
|
||||
} else if (opt_device_type == at::kXLA) {
|
||||
return at::detail::getXLAHooks();
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
@ -196,7 +199,7 @@ class TORCH_API Context {
|
||||
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
|
||||
}
|
||||
static bool hasXLA() {
|
||||
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
|
||||
return detail::getXLAHooks().hasXLA();
|
||||
}
|
||||
static bool hasXPU() {
|
||||
return detail::getXPUHooks().hasXPU();
|
||||
|
||||
@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
|
||||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
|
||||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
#undef AUTOGRAD_FALLBACK
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
|
||||
|
||||
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
|
||||
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
|
||||
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
|
||||
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
|
||||
TORCH_CHECK(
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
|
||||
backendFallbackKernels_[idx].debug, ", new registration ", debug
|
||||
);
|
||||
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ",
|
||||
dispatchKey,
|
||||
"; previous registration ",
|
||||
backendFallbackKernels_[idx].debug,
|
||||
", new registration ",
|
||||
debug);
|
||||
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
|
||||
// cannot be unboxed
|
||||
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
|
||||
|
||||
@ -185,11 +185,11 @@ struct TORCH_API Type {
|
||||
: repr_(nullptr) {}
|
||||
|
||||
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
|
||||
: repr_(p) {}
|
||||
: repr_(makeSingletonSharedPtr(p.get())) {}
|
||||
|
||||
template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
|
||||
: repr_(SingletonTypePtr<T>(p.get())) {}
|
||||
: repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {}
|
||||
|
||||
|
||||
// We need to support construction from T* for pybind. The problem
|
||||
@ -202,8 +202,8 @@ struct TORCH_API Type {
|
||||
// Case 2: if T is exactly Type, we need to do a dynamic_cast to
|
||||
// check if it's a SharedType and do the right thing.
|
||||
//
|
||||
// Case 3: Otherwise, T is not a SharedType. (debug-check this
|
||||
// assumption!) Use a singleton pointer.
|
||||
// Case 3: Otherwise, T is not a SharedType. Use a singleton
|
||||
// pointer.
|
||||
|
||||
template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
|
||||
@ -211,15 +211,15 @@ struct TORCH_API Type {
|
||||
template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(T* p) {
|
||||
if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
|
||||
repr_ = Repr(shared_p->shared_from_this());
|
||||
repr_ = shared_p->shared_from_this();
|
||||
} else {
|
||||
repr_ = Repr(p);
|
||||
repr_ = makeSingletonSharedPtr(p);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(T* p)
|
||||
: repr_(p) {
|
||||
: repr_(makeSingletonSharedPtr(p)) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
|
||||
}
|
||||
|
||||
@ -230,19 +230,19 @@ struct TORCH_API Type {
|
||||
~SingletonOrSharedTypePtr() = default;
|
||||
|
||||
T* get() const {
|
||||
return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
|
||||
return repr_.get();
|
||||
}
|
||||
|
||||
operator bool() const {
|
||||
return repr_.isNonNull();
|
||||
return repr_ != nullptr;
|
||||
}
|
||||
|
||||
bool operator==(std::nullptr_t) const {
|
||||
return !repr_.isNonNull();
|
||||
return repr_ == nullptr;
|
||||
}
|
||||
|
||||
bool operator!=(std::nullptr_t) const {
|
||||
return repr_.isNonNull();
|
||||
return repr_ != nullptr;
|
||||
}
|
||||
|
||||
template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
|
||||
@ -255,138 +255,14 @@ struct TORCH_API Type {
|
||||
}
|
||||
|
||||
private:
|
||||
// NOTE: SharedPtrWrapper exists to work around a baffling bug in
|
||||
// nvcc; see comment in destroy() below.
|
||||
struct SharedPtrWrapper {
|
||||
SharedPtrWrapper(std::shared_ptr<T> &&x)
|
||||
: repr_(std::move(x)) {}
|
||||
std::shared_ptr<T> repr_;
|
||||
};
|
||||
union Repr {
|
||||
Repr() : Repr(nullptr) {}
|
||||
// Use shared_ptr's aliasing constructor to create a non-owning pointer
|
||||
// to a singleton. The lifetime is tied to the null shared_ptr, so there's
|
||||
// no reference counting overhead for the singleton itself.
|
||||
static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) {
|
||||
return std::shared_ptr<T>(std::shared_ptr<T>(), ptr);
|
||||
}
|
||||
|
||||
explicit Repr(std::shared_ptr<T> x)
|
||||
: shared_(std::move(x)) {}
|
||||
|
||||
explicit Repr(std::nullptr_t)
|
||||
: singletonRepr_(nullptr) {}
|
||||
|
||||
explicit Repr(SingletonTypePtr<T> p)
|
||||
: singletonRepr_(p.get()) {}
|
||||
|
||||
~Repr() {
|
||||
destroy();
|
||||
}
|
||||
|
||||
// NOTE: the only non-UB way to access our null state is through
|
||||
// rawRepr(), because our copy operation doesn't preserve which
|
||||
// union member is active for null pointers.
|
||||
Repr(const Repr& rhs) {
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
new (&shared_) SharedPtrWrapper(rhs.shared_);
|
||||
} else {
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Repr(Repr&& rhs) noexcept {
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
|
||||
} else {
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Repr& operator=(const Repr& rhs) {
|
||||
if (&rhs == this) {
|
||||
return *this;
|
||||
}
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
if (isSharedAndNonNull()) {
|
||||
shared_ = rhs.shared_;
|
||||
} else {
|
||||
new (&shared_) SharedPtrWrapper(rhs.shared_);
|
||||
}
|
||||
} else {
|
||||
if (isSharedAndNonNull()) {
|
||||
destroy();
|
||||
}
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Repr& operator=(Repr&& rhs) noexcept {
|
||||
if (&rhs == this) {
|
||||
return *this;
|
||||
}
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
if (isSharedAndNonNull()) {
|
||||
shared_ = std::move(rhs.shared_);
|
||||
} else {
|
||||
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
|
||||
}
|
||||
} else {
|
||||
if (isSharedAndNonNull()) {
|
||||
destroy();
|
||||
}
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
SharedPtrWrapper shared_;
|
||||
|
||||
struct SingletonRepr {
|
||||
explicit SingletonRepr(T* s) : singleton_(s) {}
|
||||
T* singleton_;
|
||||
void* unused_ = nullptr;
|
||||
} singletonRepr_;
|
||||
struct RawRepr {
|
||||
void* first;
|
||||
void* nullIfSingleton_;
|
||||
};
|
||||
|
||||
// It is UB to read the singleton part of Repr if it was
|
||||
// constructed as a shared_ptr and vice versa, but memcpying out
|
||||
// the representation is always OK, so here's an accessor to obey
|
||||
// the letter of the law.
|
||||
RawRepr rawRepr() const {
|
||||
RawRepr repr{};
|
||||
memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
|
||||
return repr;
|
||||
}
|
||||
|
||||
bool isNonNull() const {
|
||||
auto repr = rawRepr();
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
|
||||
return repr.first != nullptr;
|
||||
}
|
||||
|
||||
bool isSharedAndNonNull() const {
|
||||
return rawRepr().nullIfSingleton_ != nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
void destroy() {
|
||||
if (isSharedAndNonNull()) {
|
||||
// Without SharedPtrWrapper, this line would read
|
||||
// `shared_.~shared_ptr()` and nvcc would complain with
|
||||
// "error: expected primary-expression before '>' token"
|
||||
// referring to the "t" in "shared_ptr". SharedPtrWrapper
|
||||
// exists to work around this compiler bug.
|
||||
shared_.~SharedPtrWrapper();
|
||||
}
|
||||
}
|
||||
} repr_;
|
||||
std::shared_ptr<T> repr_;
|
||||
};
|
||||
|
||||
using TypePtr = SingletonOrSharedTypePtr<Type>;
|
||||
|
||||
@ -104,71 +104,6 @@ class Vectorized<float> {
|
||||
}
|
||||
return b;
|
||||
}
|
||||
// Implementation is picked from
|
||||
// https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
|
||||
inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
|
||||
const auto c1 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
|
||||
const auto c2 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
|
||||
const auto c3 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
|
||||
const auto c4 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
|
||||
const auto c5 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
|
||||
const auto shift = svreinterpret_f32_u32(
|
||||
svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
|
||||
const auto inv_ln2 = svreinterpret_f32_u32(
|
||||
svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
|
||||
const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
|
||||
0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
|
||||
const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
|
||||
0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
|
||||
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
|
||||
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
|
||||
const auto zero = svdup_n_f32(0.f);
|
||||
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
|
||||
// Range reduction:
|
||||
// e^x = 2^n * e^r
|
||||
// where:
|
||||
// n = floor(x / ln(2))
|
||||
// r = x - n * ln(2)
|
||||
//
|
||||
// By adding x / ln(2) with 2^23 + 127 (shift):
|
||||
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
|
||||
// forces decimal part
|
||||
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e.
|
||||
// n) + 127 will occupy the whole fraction part of z in FP32 format.
|
||||
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
|
||||
// of x / ln(2) (i.e. n) because the decimal part has been pushed out
|
||||
// and lost.
|
||||
// * The addition of 127 makes the FP32 fraction part of z ready to be
|
||||
// used as the exponent
|
||||
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
|
||||
const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
|
||||
const auto n = svsub_f32_z(pg, z, shift);
|
||||
const auto scale = svreinterpret_f32_u32(
|
||||
svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
|
||||
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
|
||||
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in
|
||||
// term of accuracy and performance.
|
||||
const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
|
||||
const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
|
||||
// Compute the truncated Taylor series of e^r.
|
||||
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
|
||||
const auto r2 = svmul_f32_z(pg, r, r);
|
||||
const auto p1 = svmul_f32_z(pg, c1, r);
|
||||
const auto p23 = svmla_f32_z(pg, c2, c3, r);
|
||||
const auto p45 = svmla_f32_z(pg, c4, c5, r);
|
||||
const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
|
||||
const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
|
||||
auto poly = svmla_f32_z(pg, scale, p12345, scale);
|
||||
// Handle underflow and overflow.
|
||||
poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
|
||||
poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
|
||||
return poly;
|
||||
}
|
||||
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
|
||||
if (count == size())
|
||||
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
|
||||
@ -313,11 +248,41 @@ class Vectorized<float> {
|
||||
return USE_SLEEF(
|
||||
Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
|
||||
}
|
||||
// Implementation copied from Arm Optimized Routines:
|
||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
|
||||
Vectorized<float> exp_u20() const {
|
||||
return exp();
|
||||
// special case to handle special inputs that are too large or too small
|
||||
// i.e. where there's at least one element x, s.t. |x| >= 87.3...
|
||||
svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f);
|
||||
if (svptest_any(svptrue_b32(), is_special_case)) {
|
||||
return exp();
|
||||
}
|
||||
const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
|
||||
const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
|
||||
const svfloat32_t c1 = svdup_n_f32(0.5f);
|
||||
const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
|
||||
|
||||
const float shift = 0x1.803f8p17f;
|
||||
|
||||
/* n = round(x/(ln2/N)). */
|
||||
svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift);
|
||||
svfloat32_t n = svsub_x(svptrue_b32(), z, shift);
|
||||
|
||||
/* r = x - n*ln2/N. */
|
||||
svfloat32_t r = values;
|
||||
r = svmls_x(svptrue_b32(), r, n, ln2_hi);
|
||||
r = svmls_x(svptrue_b32(), r, n, ln2_lo);
|
||||
|
||||
/* scale = 2^(n/N). */
|
||||
svfloat32_t scale = svexpa(svreinterpret_u32(z));
|
||||
|
||||
/* poly(r) = exp(r) - 1 ~= r + 0.5 r^2. */
|
||||
svfloat32_t r2 = svmul_x(svptrue_b32(), r, r);
|
||||
svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
|
||||
return svmla_x(svptrue_b32(), scale, scale, poly);
|
||||
}
|
||||
Vectorized<float> fexp_u20() const {
|
||||
return exp();
|
||||
return exp_u20();
|
||||
}
|
||||
Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
|
||||
{ return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
|
||||
@ -453,9 +418,11 @@ class Vectorized<float> {
|
||||
ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
|
||||
|
||||
// Step 2: Calculate exp(2 * x), where x is the clamped value.
|
||||
// svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
|
||||
// the result.
|
||||
svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
|
||||
// svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of
|
||||
// the result (via Vectorized<float>, then auto-converts back to
|
||||
// svfloat32_t).
|
||||
svfloat32_t exp2x =
|
||||
Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20();
|
||||
|
||||
// Step 3: Calculate the numerator of the tanh function, which is exp(2x)
|
||||
// - 1.
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#ifdef __aarch64__
|
||||
#if !defined(CPU_CAPABILITY_SVE)
|
||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
|
||||
|
||||
@ -354,9 +354,47 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
|
||||
Vectorized frac() const;
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
Vectorized<c10::BFloat16> neg() const {
|
||||
return -values;
|
||||
}
|
||||
Vectorized<c10::BFloat16> reciprocal() const {
|
||||
return 1.0f / values;
|
||||
}
|
||||
Vectorized<c10::BFloat16> operator==(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values == other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator!=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values != other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator<(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values < other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator<=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values <= other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator>(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values > other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator>=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values >= other.values;
|
||||
}
|
||||
#else
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
|
||||
@ -364,6 +402,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
|
||||
#endif
|
||||
|
||||
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||
@ -412,28 +451,52 @@ template <>
|
||||
Vectorized<c10::BFloat16> inline operator+(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x + y;
|
||||
#else
|
||||
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator-(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x - y;
|
||||
#else
|
||||
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator*(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x * y;
|
||||
#else
|
||||
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator/(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x / y;
|
||||
#else
|
||||
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
// frac. Implement this here so we can use subtraction
|
||||
@ -544,12 +607,19 @@ Vectorized<c10::BFloat16> inline fmadd(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return x * y + z;
|
||||
#else
|
||||
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
|
||||
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
|
||||
// elements, not the bottom and top half, so they don't seem
|
||||
// particularly useful here. Ideally we would include dot product in
|
||||
// the Vectorized interface...
|
||||
return a * b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -557,8 +627,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return (-x) * y + z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return -a * b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -566,8 +643,15 @@ Vectorized<c10::BFloat16> inline fmsub(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return x * y - z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return a * b - c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -575,8 +659,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return (-x) * y - z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return -a * b - c;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
|
||||
|
||||
@ -5,6 +5,114 @@
|
||||
namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||
|
||||
// Enable auto-vectorization for GCC-13+ and clang-17+
|
||||
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
|
||||
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
|
||||
|
||||
template <typename from_type, typename to_type>
|
||||
inline void convertImpl(
|
||||
const from_type* __restrict src,
|
||||
to_type* __restrict dst,
|
||||
int64_t n) {
|
||||
uint64_t len = static_cast<uint64_t>(n);
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
dst[i] = static_cast<to_type>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#define CONVERT_TEMPLATE(from_type, to_type) \
|
||||
template <> \
|
||||
inline void convert(const from_type* src, to_type* dst, int64_t n) { \
|
||||
return convertImpl<from_type, to_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_TEMPLATE(uint8_t, uint8_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int8_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int16_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int32_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int64_t)
|
||||
CONVERT_TEMPLATE(uint8_t, float)
|
||||
CONVERT_TEMPLATE(uint8_t, double)
|
||||
CONVERT_TEMPLATE(int8_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int8_t, int8_t)
|
||||
CONVERT_TEMPLATE(int8_t, int16_t)
|
||||
CONVERT_TEMPLATE(int8_t, int32_t)
|
||||
CONVERT_TEMPLATE(int8_t, int64_t)
|
||||
CONVERT_TEMPLATE(int8_t, float)
|
||||
CONVERT_TEMPLATE(int8_t, double)
|
||||
CONVERT_TEMPLATE(int16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int16_t, int8_t)
|
||||
CONVERT_TEMPLATE(int16_t, int16_t)
|
||||
CONVERT_TEMPLATE(int16_t, int32_t)
|
||||
CONVERT_TEMPLATE(int16_t, int64_t)
|
||||
CONVERT_TEMPLATE(int16_t, float)
|
||||
CONVERT_TEMPLATE(int16_t, double)
|
||||
CONVERT_TEMPLATE(int32_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int32_t, int8_t)
|
||||
CONVERT_TEMPLATE(int32_t, int16_t)
|
||||
CONVERT_TEMPLATE(int32_t, int32_t)
|
||||
CONVERT_TEMPLATE(int32_t, int64_t)
|
||||
CONVERT_TEMPLATE(int32_t, float)
|
||||
CONVERT_TEMPLATE(int32_t, double)
|
||||
CONVERT_TEMPLATE(int64_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int64_t, int8_t)
|
||||
CONVERT_TEMPLATE(int64_t, int16_t)
|
||||
CONVERT_TEMPLATE(int64_t, int32_t)
|
||||
CONVERT_TEMPLATE(int64_t, int64_t)
|
||||
CONVERT_TEMPLATE(int64_t, float)
|
||||
CONVERT_TEMPLATE(int64_t, double)
|
||||
CONVERT_TEMPLATE(float, uint8_t)
|
||||
CONVERT_TEMPLATE(float, int8_t)
|
||||
CONVERT_TEMPLATE(float, int16_t)
|
||||
CONVERT_TEMPLATE(float, int32_t)
|
||||
CONVERT_TEMPLATE(float, int64_t)
|
||||
CONVERT_TEMPLATE(float, float)
|
||||
CONVERT_TEMPLATE(float, double)
|
||||
CONVERT_TEMPLATE(double, uint8_t)
|
||||
CONVERT_TEMPLATE(double, int8_t)
|
||||
CONVERT_TEMPLATE(double, int16_t)
|
||||
CONVERT_TEMPLATE(double, int32_t)
|
||||
CONVERT_TEMPLATE(double, int64_t)
|
||||
CONVERT_TEMPLATE(double, float)
|
||||
CONVERT_TEMPLATE(double, double)
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
CONVERT_TEMPLATE(float16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(float16_t, int8_t)
|
||||
CONVERT_TEMPLATE(float16_t, int16_t)
|
||||
CONVERT_TEMPLATE(float16_t, int32_t)
|
||||
CONVERT_TEMPLATE(float16_t, int64_t)
|
||||
CONVERT_TEMPLATE(float16_t, float16_t)
|
||||
CONVERT_TEMPLATE(float16_t, float)
|
||||
CONVERT_TEMPLATE(float16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, float16_t)
|
||||
CONVERT_TEMPLATE(int8_t, float16_t)
|
||||
CONVERT_TEMPLATE(int16_t, float16_t)
|
||||
CONVERT_TEMPLATE(int32_t, float16_t)
|
||||
CONVERT_TEMPLATE(int64_t, float16_t)
|
||||
CONVERT_TEMPLATE(float, float16_t)
|
||||
CONVERT_TEMPLATE(double, float16_t)
|
||||
#endif
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, float)
|
||||
CONVERT_TEMPLATE(bfloat16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(float, bfloat16_t)
|
||||
CONVERT_TEMPLATE(double, bfloat16_t)
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
||||
586
aten/src/ATen/cpu/vec/vec128/vec128_double_neon.h
Normal file
@ -0,0 +1,586 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <cmath>
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
template <>
|
||||
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
|
||||
|
||||
template <>
|
||||
class Vectorized<double> {
|
||||
private:
|
||||
float64x2_t values;
|
||||
|
||||
public:
|
||||
using value_type = double;
|
||||
using size_type = int;
|
||||
static constexpr size_type size() {
|
||||
return 2;
|
||||
}
|
||||
Vectorized() {
|
||||
values = vdupq_n_f64(0.0);
|
||||
}
|
||||
Vectorized(float64x2_t v) : values(v) {}
|
||||
Vectorized(double val) {
|
||||
values = vdupq_n_f64(val);
|
||||
}
|
||||
template <
|
||||
typename... Args,
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>>
|
||||
Vectorized(Args... vals) {
|
||||
__at_align__ double buffer[size()] = {vals...};
|
||||
values = vld1q_f64(buffer);
|
||||
}
|
||||
operator float64x2_t() const {
|
||||
return values;
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<double> blend(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint64x2_t maskArray = {
|
||||
(mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0,
|
||||
(mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_f64(maskArray, b.values, a.values);
|
||||
}
|
||||
static Vectorized<double> blendv(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& mask_) {
|
||||
return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values);
|
||||
}
|
||||
template <typename step_t>
|
||||
static Vectorized<double> arange(
|
||||
double base = 0.,
|
||||
step_t step = static_cast<step_t>(1)) {
|
||||
return {base, base + static_cast<double>(step)};
|
||||
}
|
||||
static inline Vectorized<double> set(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
int64_t count = size()) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 2) {
|
||||
return b;
|
||||
} else {
|
||||
float64x2_t c = {b.values[0], a.values[1]};
|
||||
return c;
|
||||
}
|
||||
}
|
||||
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
|
||||
if (count == size()) {
|
||||
return vld1q_f64(reinterpret_cast<const double*>(ptr));
|
||||
} else if (count == 1) {
|
||||
float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr));
|
||||
float64x1_t z = {0.0};
|
||||
return vcombine_f64(x, z);
|
||||
} else {
|
||||
return vdupq_n_f64(0.0);
|
||||
}
|
||||
}
|
||||
void store(void* ptr, int64_t count = size()) const {
|
||||
if (count == size()) {
|
||||
vst1q_f64(reinterpret_cast<double*>(ptr), values);
|
||||
} else if (count == 1) {
|
||||
vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values));
|
||||
}
|
||||
}
|
||||
const double& operator[](int idx) const = delete;
|
||||
double& operator[](int idx) = delete;
|
||||
int64_t zero_mask() const {
|
||||
// returns an integer mask where all zero elements are translated to 1-bit
|
||||
// and others are translated to 0-bit
|
||||
uint64x2_t cmpReg = vceqzq_f64(values);
|
||||
uint64x2_t mask = {1, 2};
|
||||
uint64x2_t res = vandq_u64(cmpReg, mask);
|
||||
return res[0] | res[1];
|
||||
}
|
||||
Vectorized<double> isnan() const {
|
||||
// NaN check
|
||||
return vreinterpretq_f64_u32(
|
||||
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values))));
|
||||
}
|
||||
bool has_inf_nan() const {
|
||||
Vectorized<double> x = vsubq_f64(values, values);
|
||||
float64x2_t r = x.isnan();
|
||||
uint64x2_t u = vreinterpretq_u64_f64(r);
|
||||
return u[0] | u[1];
|
||||
}
|
||||
Vectorized<double> map(double (*f)(double)) const {
|
||||
float64x2_t result;
|
||||
result[0] = f(values[0]);
|
||||
result[1] = f(values[1]);
|
||||
return result;
|
||||
}
|
||||
Vectorized<double> map2(
|
||||
const Vectorized<double>& second,
|
||||
double (*const f)(double, double)) const {
|
||||
float64x2_t result;
|
||||
result[0] = f(values[0], second.values[0]);
|
||||
result[1] = f(values[1], second.values[1]);
|
||||
return result;
|
||||
}
|
||||
Vectorized<double> abs() const {
|
||||
return vabsq_f64(values);
|
||||
}
|
||||
Vectorized<double> angle() const {
|
||||
auto zero = Vectorized<double>(0.0);
|
||||
auto pi = Vectorized<double>(c10::pi<double>);
|
||||
auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values)));
|
||||
return blendv(tmp, *this, isnan());
|
||||
}
|
||||
Vectorized<double> real() const {
|
||||
return *this;
|
||||
}
|
||||
Vectorized<double> imag() const {
|
||||
return Vectorized<double>(0.0);
|
||||
}
|
||||
Vectorized<double> conj() const {
|
||||
return *this;
|
||||
}
|
||||
Vectorized<double> acos() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos));
|
||||
}
|
||||
Vectorized<double> acosh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh));
|
||||
}
|
||||
Vectorized<double> asin() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin));
|
||||
}
|
||||
Vectorized<double> asinh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh));
|
||||
}
|
||||
Vectorized<double> atan() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan));
|
||||
}
|
||||
Vectorized<double> atanh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh));
|
||||
}
|
||||
Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_atan2d2_u10(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> copysign(const Vectorized<double>& sign) const {
|
||||
USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_copysignd2(values, sign)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_sign[size()];
|
||||
store(tmp);
|
||||
sign.store(tmp_sign);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> erf() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf));
|
||||
}
|
||||
Vectorized<double> erfc() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc));
|
||||
}
|
||||
Vectorized<double> exp() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp));
|
||||
}
|
||||
Vectorized<double> exp2() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2));
|
||||
}
|
||||
Vectorized<double> expm1() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1));
|
||||
}
|
||||
Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_fmodd2(values, q)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_q[size()];
|
||||
store(tmp);
|
||||
q.store(tmp_q);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> hypot(const Vectorized<double>& b) const {
|
||||
USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_hypotd2_u05(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> i0() const {
|
||||
return map(calc_i0);
|
||||
}
|
||||
Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_nextafterd2(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); ++i) {
|
||||
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> log() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_logd2_u10(values)), map(std::log));
|
||||
}
|
||||
Vectorized<double> log2() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2));
|
||||
}
|
||||
Vectorized<double> log10() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10));
|
||||
}
|
||||
Vectorized<double> log1p() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p));
|
||||
}
|
||||
Vectorized<double> frac() const;
|
||||
Vectorized<double> sin() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin));
|
||||
}
|
||||
Vectorized<double> sinh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh));
|
||||
}
|
||||
Vectorized<double> cos() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos));
|
||||
}
|
||||
Vectorized<double> cosh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh));
|
||||
}
|
||||
Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_powd2_u10(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::pow(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} // Comparison using the _CMP_**_OQ predicate.
|
||||
// `O`: get false if an operand is NaN
|
||||
// `Q`: do not raise if an operand is NaN
|
||||
Vectorized<double> tan() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan));
|
||||
}
|
||||
Vectorized<double> tanh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh));
|
||||
}
|
||||
Vectorized<double> lgamma() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma));
|
||||
}
|
||||
Vectorized<double> erfinv() const {
|
||||
return map(calc_erfinv);
|
||||
}
|
||||
Vectorized<double> exp_u20() const {
|
||||
return exp();
|
||||
}
|
||||
Vectorized<double> fexp_u20() const {
|
||||
return exp();
|
||||
}
|
||||
Vectorized<double> i0e() const {
|
||||
return map(calc_i0e);
|
||||
}
|
||||
Vectorized<double> digamma() const {
|
||||
return map(calc_digamma);
|
||||
}
|
||||
Vectorized<double> igamma(const Vectorized<double>& x) const {
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_x[size()];
|
||||
store(tmp);
|
||||
x.store(tmp_x);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
}
|
||||
Vectorized<double> igammac(const Vectorized<double>& x) const {
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_x[size()];
|
||||
store(tmp);
|
||||
x.store(tmp_x);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
}
|
||||
Vectorized<double> ceil() const {
|
||||
return vrndpq_f64(values);
|
||||
}
|
||||
Vectorized<double> floor() const {
|
||||
return vrndmq_f64(values);
|
||||
}
|
||||
Vectorized<double> neg() const {
|
||||
return vnegq_f64(values);
|
||||
}
|
||||
Vectorized<double> round() const {
|
||||
return vrndiq_f64(values);
|
||||
}
|
||||
Vectorized<double> trunc() const {
|
||||
return vrndq_f64(values);
|
||||
}
|
||||
Vectorized<double> sqrt() const {
|
||||
return vsqrtq_f64(values);
|
||||
}
|
||||
Vectorized<double> reciprocal() const {
|
||||
return vdivq_f64(vdupq_n_f64(1.0), values);
|
||||
}
|
||||
Vectorized<double> rsqrt() const {
|
||||
return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values));
|
||||
}
|
||||
double reduce_add() const {
|
||||
return vaddvq_f64(values);
|
||||
}
|
||||
double reduce_max() const {
|
||||
return vmaxvq_f64(values);
|
||||
}
|
||||
Vectorized<double> operator==(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vceqq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator!=(const Vectorized<double>& other) const {
|
||||
float64x2_t r0 = vreinterpretq_f64_u32(
|
||||
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values))));
|
||||
return Vectorized<double>(r0);
|
||||
}
|
||||
|
||||
Vectorized<double> operator<(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcltq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator<=(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcleq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator>(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcgtq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator>=(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcgeq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> eq(const Vectorized<double>& other) const;
|
||||
Vectorized<double> ne(const Vectorized<double>& other) const;
|
||||
Vectorized<double> gt(const Vectorized<double>& other) const;
|
||||
Vectorized<double> ge(const Vectorized<double>& other) const;
|
||||
Vectorized<double> lt(const Vectorized<double>& other) const;
|
||||
Vectorized<double> le(const Vectorized<double>& other) const;
|
||||
};
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator+(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vaddq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator-(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vsubq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator*(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vmulq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator/(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vdivq_f64(a, b);
|
||||
}
|
||||
|
||||
// frac. Implement this here so we can use subtraction
|
||||
Vectorized<double> inline Vectorized<double>::frac() const {
|
||||
return *this - this->trunc();
|
||||
}
|
||||
|
||||
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
|
||||
// either input is a NaN.
|
||||
template <>
|
||||
Vectorized<double> inline maximum(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vmaxq_f64(a, b);
|
||||
}
|
||||
|
||||
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
|
||||
// either input is a NaN.
|
||||
template <>
|
||||
Vectorized<double> inline minimum(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vminq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline clamp(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& min,
|
||||
const Vectorized<double>& max) {
|
||||
return vminq_f64(max, vmaxq_f64(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline clamp_max(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& max) {
|
||||
return vminq_f64(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline clamp_min(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& min) {
|
||||
return vmaxq_f64(min, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator&(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vreinterpretq_f64_u64(
|
||||
vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator|(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vreinterpretq_f64_u64(
|
||||
vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator^(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vreinterpretq_f64_u64(
|
||||
veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::eq(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this == other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::ne(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this != other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::gt(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this > other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::ge(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this >= other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::lt(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this < other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::le(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this <= other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fmadd(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmaq_f64(c, a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fnmadd(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmsq_f64(c, a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fmsub(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmaq_f64(vnegq_f64(c), a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fnmsub(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmsq_f64(vnegq_f64(c), a, b);
|
||||
}
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
@ -307,11 +307,49 @@ class Vectorized<float> {
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
|
||||
// Implementation copied from Arm Optimized Routine
|
||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
|
||||
Vectorized<float> exp_u20() const {
|
||||
return exp();
|
||||
// bail out to sleef if it's a special case:
|
||||
// i.e. there's an input s.t. |input| > 87.3....
|
||||
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
|
||||
uint32x4_t cmp = vcagtq_f32(values, special_bound);
|
||||
if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) {
|
||||
return exp();
|
||||
}
|
||||
|
||||
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
|
||||
const float ln2_hi = 0x1.62e4p-1f;
|
||||
const float ln2_lo = 0x1.7f7d1cp-20f;
|
||||
const float c0 = 0x1.0e4020p-7f;
|
||||
const float c2 = 0x1.555e66p-3f;
|
||||
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
|
||||
|
||||
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
|
||||
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
|
||||
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
|
||||
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
|
||||
|
||||
/* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
|
||||
x = ln2*n + r, with r in [-ln2/2, ln2/2]. */
|
||||
|
||||
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
|
||||
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0);
|
||||
r = vfmsq_laneq_f32(r, n, ln2_c02, 1);
|
||||
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
|
||||
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias));
|
||||
|
||||
float32x4_t r2 = vmulq_f32(r, r);
|
||||
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2);
|
||||
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3);
|
||||
q = vfmaq_f32(q, p, r2);
|
||||
p = vmulq_f32(c4, r);
|
||||
float32x4_t poly = vfmaq_f32(p, q, r2);
|
||||
|
||||
return vfmaq_f32(scale, poly, scale);
|
||||
}
|
||||
Vectorized<float> fexp_u20() const {
|
||||
return exp();
|
||||
return exp_u20();
|
||||
}
|
||||
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
|
||||
fmod,
|
||||
@ -540,42 +578,6 @@ inline Vectorized<float> Vectorized<float>::le(
|
||||
return (*this <= other) & Vectorized<float>(1.0f);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const float* src, int32_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||
i += Vectorized<float>::size()) {
|
||||
vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<int32_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const int32_t* src, float* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||
i += Vectorized<float>::size()) {
|
||||
vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<float>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<float> inline fmadd(
|
||||
const Vectorized<float>& a,
|
||||
|
||||
@ -569,46 +569,6 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
|
||||
return (*this <= other) & Vectorized<c10::Half>(1);
|
||||
}
|
||||
|
||||
// These are global functions, so the defaults in vec_base.h should
|
||||
// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available.
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
template <>
|
||||
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||
i += Vectorized<c10::Half>::size()) {
|
||||
vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<int16_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||
i += Vectorized<c10::Half>::size()) {
|
||||
vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<float16_t>(src[i]);
|
||||
}
|
||||
}
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template <>
|
||||
Vectorized<c10::Half> inline fmadd(
|
||||
const Vectorized<c10::Half>& a,
|
||||
|
||||
@ -168,11 +168,9 @@ void CUDAGraph::instantiate() {
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
|
||||
// cudaGraphInstantiateWithFlags
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
|
||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
|
||||
int version = 0;
|
||||
AT_CUDA_CHECK(cudaDriverGetVersion(&version));
|
||||
if (version < 11040) {
|
||||
#endif
|
||||
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
|
||||
// who prefer not to report error message through these arguments moving forward
|
||||
// (they prefer return value, or errors on api calls internal to the capture)
|
||||
@ -183,13 +181,11 @@ void CUDAGraph::instantiate() {
|
||||
#endif
|
||||
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
|
||||
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
|
||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
|
||||
} else {
|
||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||
graph_,
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
||||
}
|
||||
#endif
|
||||
has_graph_exec_ = true;
|
||||
}
|
||||
|
||||
|
||||
192
aten/src/ATen/cuda/CUDAGreenContext.cpp
Normal file
192
aten/src/ATen/cuda/CUDAGreenContext.cpp
Normal file
@ -0,0 +1,192 @@
|
||||
#include <ATen/cuda/CUDAGreenContext.h>
|
||||
|
||||
namespace at::cuda {
|
||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int driver_version;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||
TORCH_CHECK(
|
||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||
CUcontext pctx = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
TORCH_WARN(
|
||||
"Attempted to create a green context but"
|
||||
" there was no primary context! Creating a primary context...");
|
||||
|
||||
cudaFree(0);
|
||||
}
|
||||
|
||||
CUdevice device;
|
||||
device_id_ = device_id;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||
|
||||
// Get device resources
|
||||
CUdevResource device_resource;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||
|
||||
// Split resources
|
||||
std::vector<CUdevResource> result(1);
|
||||
auto result_data = result.data();
|
||||
unsigned int nb_groups = 1;
|
||||
CUdevResource remaining;
|
||||
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||
result_data,
|
||||
&nb_groups,
|
||||
&device_resource,
|
||||
&remaining,
|
||||
0, // default flags
|
||||
num_sms));
|
||||
|
||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||
|
||||
// Generate resource descriptor
|
||||
CUdevResourceDesc desc;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||
&desc, result_data, 1));
|
||||
|
||||
// Create green context
|
||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||
|
||||
// Convert to regular context
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::unique_ptr<GreenContext> GreenContext::create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
if (!device_id.has_value()) {
|
||||
device_id = at::cuda::current_device();
|
||||
}
|
||||
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Implement move operations
|
||||
GreenContext::GreenContext(GreenContext&& other) noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
if (this != &other) {
|
||||
// Clean up current resources
|
||||
if (green_ctx_) {
|
||||
CUcontext current = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
||||
if (current == context_) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"attempting to overwrite current green ctx "
|
||||
"when it is active!");
|
||||
}
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||
}
|
||||
|
||||
// Take ownership of other's resources
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
||||
}
|
||||
return *this;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
GreenContext::~GreenContext() noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext GreenContext::getContext() const {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
return context_;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx GreenContext::getGreenContext() const {
|
||||
return green_ctx_;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void GreenContext::setContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
||||
parent_stream_ = current_stream.stream();
|
||||
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(current_stream);
|
||||
|
||||
CUcontext current = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
||||
if (!current) {
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
|
||||
} else {
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
|
||||
}
|
||||
// currently hardcodes the new green context to use the default stream
|
||||
// TODO(eqy): consider creating a new stream if e.g., it allows interop
|
||||
// with CUDA Graph captures etc.
|
||||
auto default_stream = c10::cuda::getDefaultCUDAStream();
|
||||
ev.block(default_stream);
|
||||
c10::cuda::setCurrentCUDAStream(default_stream);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
void GreenContext::popContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
// see above note about stream being hardcoded to the default stream
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(c10::cuda::getCurrentCUDAStream());
|
||||
CUcontext popped;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
popped == context_, "expected popped context to be the current ctx");
|
||||
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
} // namespace at::cuda
|
||||
53
aten/src/ATen/cuda/CUDAGreenContext.h
Normal file
53
aten/src/ATen/cuda/CUDAGreenContext.h
Normal file
@ -0,0 +1,53 @@
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define CUDA_HAS_GREEN_CONTEXT 1
|
||||
#else
|
||||
#define CUDA_HAS_GREEN_CONTEXT 0
|
||||
#endif
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
class TORCH_CUDA_CPP_API GreenContext {
|
||||
public:
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
|
||||
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
||||
|
||||
// Delete copy constructor and assignment
|
||||
GreenContext(const GreenContext&) = delete;
|
||||
GreenContext& operator=(const GreenContext&) = delete;
|
||||
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
~GreenContext() noexcept;
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext getContext() const;
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx getGreenContext() const;
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void setContext();
|
||||
|
||||
void popContext();
|
||||
|
||||
private:
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int32_t device_id_ = -1;
|
||||
CUgreenCtx green_ctx_ = nullptr;
|
||||
CUcontext context_ = nullptr;
|
||||
cudaStream_t parent_stream_ = nullptr;
|
||||
#endif
|
||||
};
|
||||
} // namespace at::cuda
|
||||
270
aten/src/ATen/cuda/CUDAScaledBlas.cpp
Normal file
270
aten/src/ATen/cuda/CUDAScaledBlas.cpp
Normal file
@ -0,0 +1,270 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace at::cuda::scaled {
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs a single scale, {Tensorwise (float)}
|
||||
*/
|
||||
bool check_tensorwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs scales, {Rowwise (float)}
|
||||
*/
|
||||
bool check_rowwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {RowWise, dp32} for A & B
|
||||
if (recipe_a[0] != ScalingType::RowWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::RowWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Two-level scaling, canonical NVFP4
|
||||
* Both inputs must be fp4
|
||||
* A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
|
||||
*/
|
||||
bool check_nvfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Single-level scaling, what PyT currently understands
|
||||
* Both inputs must be fp4
|
||||
* A, B need 1 scale, {Blockwise_1x16 (e4m3)}
|
||||
*/
|
||||
bool check_nvfp4_recipe_single_scale
|
||||
(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float)
|
||||
*/
|
||||
bool check_deepseek_recipe(ScalingType expected_recipe_a,
|
||||
ScalingType expected_recipe_b,
|
||||
c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
|
||||
if (recipe_a[0] != expected_recipe_a) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != expected_recipe_b) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp8_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp4
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace at::native::cuda::blas::scaled
|
||||
174
aten/src/ATen/cuda/CUDAScaledBlas.h
Normal file
174
aten/src/ATen/cuda/CUDAScaledBlas.h
Normal file
@ -0,0 +1,174 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace at::cuda::scaled {
|
||||
|
||||
static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
|
||||
#ifdef USE_ROCM
|
||||
static const std::vector<std::string> archs = {
|
||||
"gfx942",
|
||||
#if ROCM_VERSION >= 60300
|
||||
"gfx1200", "gfx1201",
|
||||
#endif
|
||||
#if ROCM_VERSION >= 60500
|
||||
"gfx950"
|
||||
#endif
|
||||
};
|
||||
return at::detail::getCUDAHooks().isGPUArch(archs);
|
||||
#else
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
|
||||
if (sm90_only || sm100_only) {
|
||||
return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
|
||||
} else {
|
||||
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool _scaled_mm_is_fnuz() {
|
||||
return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
|
||||
}
|
||||
#endif
|
||||
/**
|
||||
* Track concrete implementations available
|
||||
*/
|
||||
enum class ScaledGemmImplementation {
|
||||
NONE = 0,
|
||||
TENSORWISE_TENSORWISE = 1,
|
||||
ROWWISE_ROWWISE = 2,
|
||||
BLOCK_128x128_1x128 = 3,
|
||||
BLOCK_1x128_128x128 = 4,
|
||||
BLOCK_1x128_1x128 = 5,
|
||||
MXFP8_MXFP8 = 6,
|
||||
NVFP4_NVFP4 = 7,
|
||||
NVFP4_NVFP4_SINGLE_SCALE = 8,
|
||||
MXFP4_MXFP4 = 9,
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert passed int (enum) from python back into a
|
||||
* strictly-typed enum
|
||||
*/
|
||||
template <class EnumType, class ArrayType>
|
||||
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
|
||||
std::vector<EnumType> converted;
|
||||
converted.reserve(v.size());
|
||||
|
||||
for (auto vi : v) {
|
||||
converted.push_back(static_cast<EnumType>(vi));
|
||||
}
|
||||
return converted;
|
||||
}
|
||||
|
||||
bool check_tensorwise_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
|
||||
bool check_rowwise_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_nvfp4_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_nvfp4_recipe_single_scale
|
||||
(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_deepseek_recipe(ScalingType,
|
||||
ScalingType,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_mxfp8_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_mxfp4_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
} // namespace at::native::cuda::blas::scaled
|
||||
23
aten/src/ATen/detail/XLAHooksInterface.cpp
Normal file
23
aten/src/ATen/detail/XLAHooksInterface.cpp
Normal file
@ -0,0 +1,23 @@
|
||||
#include <ATen/detail/XLAHooksInterface.h>
|
||||
|
||||
namespace at {
|
||||
namespace detail {
|
||||
|
||||
const XLAHooksInterface& getXLAHooks() {
|
||||
auto create_impl = [] {
|
||||
// Create XLA hooks using the registry
|
||||
auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{});
|
||||
if (hooks) {
|
||||
return hooks;
|
||||
}
|
||||
// If hooks creation fails, fall back to default implementation
|
||||
return std::make_unique<XLAHooksInterface>();
|
||||
};
|
||||
static auto hooks = create_impl();
|
||||
return *hooks;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs)
|
||||
|
||||
} // namespace at
|
||||
79
aten/src/ATen/detail/XLAHooksInterface.h
Normal file
79
aten/src/ATen/detail/XLAHooksInterface.h
Normal file
@ -0,0 +1,79 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Registry.h>
|
||||
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
|
||||
|
||||
namespace at {
|
||||
|
||||
constexpr const char* XLA_HELP =
|
||||
"This error has occurred because you are trying "
|
||||
"to use some XLA functionality, but the XLA library has not been "
|
||||
"loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
|
||||
|
||||
struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
|
||||
~XLAHooksInterface() override = default;
|
||||
|
||||
void init() const override {
|
||||
TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
virtual bool hasXLA() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual std::string showConfig() const {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Cannot query detailed XLA version without torch_xla library. ",
|
||||
XLA_HELP);
|
||||
}
|
||||
|
||||
const Generator& getDefaultGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(
|
||||
false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
Generator getNewGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
virtual DeviceIndex getCurrentDevice() const override {
|
||||
TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
Device getDeviceFromPtr(void* /*data*/) const override {
|
||||
TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
Allocator* getPinnedMemoryAllocator() const override {
|
||||
TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
bool isPinnedPtr(const void* data) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool hasPrimaryContext(DeviceIndex device_index) const override {
|
||||
TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct TORCH_API XLAHooksArgs {};
|
||||
|
||||
TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
|
||||
#define REGISTER_XLA_HOOKS(clsname) \
|
||||
C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
|
||||
|
||||
namespace detail {
|
||||
TORCH_API const XLAHooksInterface& getXLAHooks();
|
||||
} // namespace detail
|
||||
} // namespace at
|
||||
C10_DIAGNOSTIC_POP()
|
||||
@ -259,11 +259,20 @@ inline void winograd_f2k3_input_transform_inplace__rvv(
|
||||
const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
|
||||
const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
|
||||
const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);
|
||||
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
|
||||
/* GCC 14.2 (RISC-V RVV) ICE workaround:
|
||||
* Avoid single-statement read-modify-write on MEM_REF like:
|
||||
* *input_tile_val =
|
||||
* __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
|
||||
* This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin)
|
||||
* with -march=rv64gcv. Use a temporary then write back.
|
||||
* Do NOT refactor into the single-statement form. Clang is unaffected.
|
||||
*/
|
||||
vfloat32m1x4_t tmp_input_tile_val = *input_tile_val;
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0);
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1);
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2);
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3);
|
||||
*input_tile_val = tmp_input_tile_val;
|
||||
}
|
||||
|
||||
inline void winograd_f2k3_output_transform_inplace__rvv(
|
||||
@ -277,9 +286,15 @@ inline void winograd_f2k3_output_transform_inplace__rvv(
|
||||
const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
|
||||
const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
|
||||
const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);
|
||||
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
|
||||
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
|
||||
* Keep the temporary + write-back pattern to avoid ICE.
|
||||
* Do NOT rewrite into:
|
||||
* *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
|
||||
*/
|
||||
vfloat32m1x4_t tmp_output_tile_val = *input_tile_val;
|
||||
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0);
|
||||
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1);
|
||||
*input_tile_val = tmp_output_tile_val;
|
||||
}
|
||||
|
||||
inline vfloat32m1_t
|
||||
@ -300,11 +315,17 @@ inline void winograd_f2k3_kernel_transform__rvv(
|
||||
const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
|
||||
const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
|
||||
vfloat32m1_t half_g0_plus_g2 = __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);
|
||||
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
|
||||
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
|
||||
* Keep the temporary + write-back pattern to avoid ICE.
|
||||
* Do NOT rewrite into:
|
||||
* *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val);
|
||||
*/
|
||||
vfloat32m1x4_t tmp_transform = *transform;
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0);
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2);
|
||||
*transform = tmp_transform;
|
||||
}
|
||||
|
||||
inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/CUDAScaledBlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
@ -360,7 +361,7 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
|
||||
// and the leading stride is at least max(1, other dim length), so we might
|
||||
// end up with contiguous cols but not rows (i.e. holes between different rows)
|
||||
// and vice versa.
|
||||
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
&& (
|
||||
// filter by dtype
|
||||
@ -1628,104 +1629,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 1,
|
||||
"scale must be a 1D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D, arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(dim) * scale_multiplier,
|
||||
"scale must have the same length as mat for arg ",
|
||||
arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 2,
|
||||
"scale must be a 2D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.stride(1) == 1,
|
||||
"scale must be contiguous in the last dimension for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(0),
|
||||
"scale must have the same batch dimension as mat for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(1) == mat.size(1 + dim),
|
||||
"scale must have the same first dimension as mat for arg ",
|
||||
arg_idx);
|
||||
}
|
||||
}
|
||||
|
||||
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
|
||||
// that are converted to blocked padded format individually,
|
||||
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim(),
|
||||
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
|
||||
|
||||
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
|
||||
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
|
||||
// * weight is transposed prior to the call, scale stays non-transposed.
|
||||
bool LHS = arg_idx == 0;
|
||||
int scale_dim_to_check = 0;
|
||||
int mat_dim_to_check = LHS ? 0 : 1;
|
||||
TORCH_CHECK(
|
||||
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
|
||||
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
|
||||
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
|
||||
} else {
|
||||
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
|
||||
// so we can check the exact expected scale sizes here without a d2h sync.
|
||||
auto round_up = [](auto x, auto y) {
|
||||
return ((x + y - 1) / y) * y;
|
||||
};
|
||||
|
||||
// TODO: this is for 3d tensor in 2d-3d case specifically.
|
||||
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
|
||||
int64_t G = mat.size(0);
|
||||
int64_t K = mat.size(1);
|
||||
int64_t N = mat.size(2);
|
||||
int64_t blocked_scale_K = round_up(K/32, 4);
|
||||
int64_t blocked_scale_N = round_up(N, 128);
|
||||
|
||||
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim() - 1,
|
||||
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
|
||||
);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
|
||||
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
|
||||
);
|
||||
}
|
||||
}
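The 3-D branch's expected sizes are plain round-up arithmetic; a standalone sketch with example sizes (the concrete K and N are made up for illustration):

```cpp
#include <cstdint>
#include <cstdio>

// Round x up to the next multiple of y, as in the lambda above.
constexpr int64_t round_up(int64_t x, int64_t y) { return ((x + y - 1) / y) * y; }

int main() {
  // Hypothetical 3-D weight of shape (G, K, N) with K = 4096, N = 8192.
  const int64_t K = 4096, N = 8192;
  const int64_t blocked_scale_K = round_up(K / 32, 4);  // 128
  const int64_t blocked_scale_N = round_up(N, 128);     // 8192
  // fbgemm then expects a (G, blocked_scale_K * blocked_scale_N) scale tensor.
  std::printf("%lld x %lld -> %lld scale elements per group\n",
              (long long)blocked_scale_K, (long long)blocked_scale_N,
              (long long)(blocked_scale_K * blocked_scale_N));
}
```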
|
||||
|
||||
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
|
||||
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
|
||||
if (using_fp8_rowwise) {
|
||||
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
|
||||
} else if (using_mxfp8) {
|
||||
_check_scales_mxfp8(mat, scale, dim, arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor
|
||||
_scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
@ -1740,261 +1643,26 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
|
||||
}
|
||||
|
||||
/**
|
||||
* Track concrete implementations available
|
||||
*/
|
||||
enum class ScaledGemmImplementation {
|
||||
NONE = 0,
|
||||
TENSORWISE_TENSORWISE = 1,
|
||||
ROWWISE_ROWWISE = 2,
|
||||
BLOCK_128x128_1x128 = 3,
|
||||
BLOCK_1x128_128x128 = 4,
|
||||
BLOCK_1x128_1x128 = 5,
|
||||
MXFP8_MXFP8 = 6,
|
||||
NVFP4_NVFP4 = 7,
|
||||
NVFP4_NVFP4_SINGLE_SCALE = 8,
|
||||
MXFP4_MXFP4 = 9,
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert passed int (enum) from python back into a
|
||||
* strictly-typed enum
|
||||
*/
|
||||
template <class EnumType, class ArrayType>
|
||||
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
|
||||
std::vector<EnumType> converted;
|
||||
converted.reserve(v.size());
|
||||
|
||||
for (auto vi : v) {
|
||||
converted.push_back(static_cast<EnumType>(vi));
|
||||
}
|
||||
return converted;
|
||||
}
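Usage is a one-liner on the binding side: plain integers come in and are re-typed in a single pass. A minimal self-contained illustration with a stand-in enum and a std::vector in place of IntArrayRef:

```cpp
#include <cstdint>
#include <vector>

enum class Recipe : int64_t { TensorWise = 0, RowWise = 1, BlockWise1x32 = 2 };

// Same shape as the helper above: re-type a span of ints into a typed enum vector.
template <class EnumType, class ArrayType>
std::vector<EnumType> convert_ints(const ArrayType& v) {
  std::vector<EnumType> out;
  out.reserve(v.size());
  for (auto vi : v) out.push_back(static_cast<EnumType>(vi));
  return out;
}

int main() {
  std::vector<int64_t> raw{0, 1, 2};       // what the binding layer hands over
  auto typed = convert_ints<Recipe>(raw);  // {TensorWise, RowWise, BlockWise1x32}
  return typed.size() == 3 ? 0 : 1;
}
```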
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs a single scale, {Tensorwise (float)}
|
||||
*/
|
||||
bool check_tensorwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs scales, {Rowwise (float)}
|
||||
*/
|
||||
bool check_rowwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {RowWise, fp32} for A & B
|
||||
if (recipe_a[0] != ScalingType::RowWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::RowWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Two-level scaling, canonical NVFP4
|
||||
* Both inputs must be fp4
|
||||
* A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
|
||||
*/
|
||||
bool check_nvfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
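For context, "two-level" here means each 1x16 block carries an e4m3 scale and the whole tensor carries one extra fp32 scale; dequantizing a single element would then look roughly like this (an assumption about the recipe the check encodes, not code from this diff):

```cpp
// Illustrative two-level NVFP4 dequantization for one element.
// fp4_value is the already-decoded e2m1 element; block_scale_e4m3 is the
// per-1x16-block scale (stored as e4m3, decoded to float here); the tensor
// carries one additional fp32 scale on top.
float dequant_nvfp4(float fp4_value,
                    float block_scale_e4m3,
                    float tensor_scale_fp32) {
  return fp4_value * block_scale_e4m3 * tensor_scale_fp32;
}
```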
|
||||
|
||||
/**
|
||||
* Single-level scaling, what PyT currently understands
|
||||
* Both inputs must be fp4
|
||||
* A, B need 1 scale, {Blockwise_1x16 (e4m3)}
|
||||
*/
|
||||
bool check_nvfp4_recipe_single_scale
|
||||
(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float)
|
||||
*/
|
||||
bool check_deepseek_recipe(ScalingType expected_recipe_a,
|
||||
ScalingType expected_recipe_b,
|
||||
c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
|
||||
if (recipe_a[0] != expected_recipe_a) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != expected_recipe_b) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp8_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
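The e8m0 scale type the check insists on is an exponent-only byte, i.e. a power-of-two scale. A sketch of the decode step under that assumption (bias 127, the NaN encoding ignored):

```cpp
#include <cmath>
#include <cstdint>

// Assumption about the e8m0 format, not code from this diff: the byte stores
// only an exponent, so the scale is 2^(e - 127).
float e8m0_to_float(uint8_t e) {
  return std::ldexp(1.0f, static_cast<int>(e) - 127);
}
// An MXFP8 element then dequantizes as fp8_value * e8m0_to_float(block_scale),
// with one such scale per 1x32 block, matching the BlockWise1x32 recipe above.
```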
|
||||
|
||||
/**
|
||||
* Both inputs must be fp4
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
using namespace std::placeholders;

namespace scaled_blas = at::cuda::scaled;
using scaled_blas::ScaledGemmImplementation;
using scaled_blas::convert_int_to_enum;

std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 9> scale_kernel_dispatch = {{
  { "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
  { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
  { "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),
  { "tensorwise_tensorwise", scaled_blas::check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
  { "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
  { "block_1x128_128x128", std::bind(scaled_blas::check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),
    ScaledGemmImplementation::BLOCK_1x128_128x128},
  { "block_128x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise128x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
  { "block_128x128_1x128", std::bind(scaled_blas::check_deepseek_recipe, ScalingType::BlockWise128x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
    ScaledGemmImplementation::BLOCK_128x128_1x128},
  { "block_1x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
  { "block_1x128_1x128", std::bind(scaled_blas::check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
    ScaledGemmImplementation::BLOCK_1x128_1x128},
  { "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
  { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
  { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
  { "mxfp4_mxfp4", check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}};
  { "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
  { "nvfp4_nvfp4_single_scale", scaled_blas::check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
  { "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
  { "mxfp4_mxfp4", scaled_blas::check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}};
|
||||
|
||||
Tensor&
|
||||
_scaled_tensorwise_tensorwise(
|
||||
@ -2596,410 +2264,6 @@ _scaled_mm_cuda_v2(
|
||||
out);
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d
|
||||
// scaling=MXFP8
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
bool b_is_3d = mat_b.dim() == 3;
|
||||
bool is_2d_2d = a_is_2d && b_is_2d;
|
||||
bool is_2d_3d = a_is_2d && b_is_3d;
|
||||
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
|
||||
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs require offsets");
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
|
||||
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
|
||||
#else
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
|
||||
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
|
||||
#endif
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
|
||||
fbgemm_gpu::mx8mx8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs.value(),
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI");
|
||||
#endif
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_b to be Float8_e4m3 matrix got ", mat_b.scalar_type());
|
||||
|
||||
at::cuda::detail::f8f8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// only being called for rocm
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_b to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM)
|
||||
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
// FBGEMM expects B matrix shape to be (.., N, K)
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
|
||||
#endif
|
||||
return out;
|
||||
|
||||
}
|
||||
|
||||
// Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"For grouped FP8 rowwise, both scales must be float32 tensors");
|
||||
#ifndef USE_ROCM
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
#else
|
||||
// NOTE: ignore use_fast_accum
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias")
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
std::optional<c10::ScalarType> out_dtype,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(
|
||||
(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
|
||||
(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
|
||||
"For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
|
||||
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
|
||||
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
|
||||
// MXFP8 grouped GEMM dispatching
|
||||
bool is_mx8mx8bf16 = (
|
||||
mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
|
||||
scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
|
||||
);
|
||||
#else
|
||||
bool is_mx8mx8bf16 = false;
|
||||
#endif
|
||||
|
||||
if (is_mx8mx8bf16) {
|
||||
// Note: Passing implied SwizzleType here, correctness of scale previously checked
|
||||
// in `check_scale` call
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
scale_b,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
|
||||
// If we're not MXFP8, then we're row-wise scaling.
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
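For the 2d-2d / 2d-3d paths, `offs` partitions the grouped dimension. A host-side sketch of how the int32 offsets map to group extents, assuming they store cumulative end indices as is usual for grouped GEMM front ends (an illustration of the convention, not code from this diff):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Turn cumulative end offsets into (start, length) pairs per group.
// e.g. offs = {128, 384, 512} -> groups [0,128), [128,384), [384,512).
std::vector<std::pair<int32_t, int32_t>> groups_from_offsets(const std::vector<int32_t>& offs) {
  std::vector<std::pair<int32_t, int32_t>> groups;
  int32_t start = 0;
  for (int32_t end : offs) {
    groups.emplace_back(start, end - start);
    start = end;
  }
  return groups;
}
```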
|
||||
|
||||
namespace {
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda_v2(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
if (contraction_dim.size() > 0) {
|
||||
const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
|
||||
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
|
||||
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
|
||||
mat_b.size(dim_b));
|
||||
// Note: only (-1, -2) is currently supported
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
|
||||
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
|
||||
"No gemm implementation was found");
|
||||
|
||||
switch (gemm_impl) {
|
||||
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
|
||||
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
swizzle_a_enum[0],
|
||||
scale_b[0],
|
||||
swizzle_b_enum[0],
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
std::optional<c10::ScalarType> out_dtype) {
|
||||
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
|
||||
bool a_b_and_out_are_bf16 = (
|
||||
mat_a.dtype() == at::kBFloat16 &&
|
||||
mat_b.dtype() == at::kBFloat16 &&
|
||||
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
|
||||
);
|
||||
#ifndef USE_ROCM
|
||||
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
|
||||
#else
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
if (use_fast_path) {
|
||||
// fast path, no d2h sync needed
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
|
||||
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
|
||||
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
|
||||
|
||||
574
aten/src/ATen/native/cuda/GroupedBlas.cpp
Normal file
@ -0,0 +1,574 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/CUDAScaledBlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace scaled_blas = at::cuda::scaled;
|
||||
using scaled_blas::ScaledGemmImplementation;
|
||||
using scaled_blas::convert_int_to_enum;
|
||||
using scaled_blas::_scaled_mm_allowed_device;
|
||||
|
||||
namespace at::native {
|
||||
|
||||
namespace {
|
||||
|
||||
// 2d-2d and 2d-3d
|
||||
// scaling=MXFP8
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
bool b_is_3d = mat_b.dim() == 3;
|
||||
bool is_2d_2d = a_is_2d && b_is_2d;
|
||||
bool is_2d_3d = a_is_2d && b_is_3d;
|
||||
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
|
||||
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs require offsets");
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
|
||||
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
|
||||
#else
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
|
||||
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
|
||||
#endif
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
|
||||
fbgemm_gpu::mx8mx8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs.value(),
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI");
|
||||
#endif
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_b to be Float8_e4m3 matrix got ", mat_b.scalar_type());
|
||||
|
||||
at::cuda::detail::f8f8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// only being called for rocm
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_b to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM)
|
||||
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
// FBGEMM expects B matrix shape to be (.., N, K)
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
|
||||
#endif
|
||||
return out;
|
||||
|
||||
}
|
||||
|
||||
// Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"For grouped FP8 rowwise, both scales must be float32 tensors");
|
||||
#ifndef USE_ROCM
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
#else
|
||||
// NOTE: ignore use_fast_accum
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias")
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#endif
|
||||
}
|
||||
|
||||
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 1,
|
||||
"scale must be a 1D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D, arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(dim) * scale_multiplier,
|
||||
"scale must have the same length as mat for arg ",
|
||||
arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 2,
|
||||
"scale must be a 2D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.stride(1) == 1,
|
||||
"scale must be contiguous in the last dimension for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(0),
|
||||
"scale must have the same batch dimension as mat for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(1) == mat.size(1 + dim),
|
||||
"scale must have the same first dimension as mat for arg ",
|
||||
arg_idx);
|
||||
}
|
||||
}
|
||||
|
||||
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
|
||||
// that are converted to blocked padded format individually,
|
||||
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim(),
|
||||
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
|
||||
|
||||
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
|
||||
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
|
||||
// * weight is transposed prior to the call, scale stays non-transposed.
|
||||
bool LHS = arg_idx == 0;
|
||||
int scale_dim_to_check = 0;
|
||||
int mat_dim_to_check = LHS ? 0 : 1;
|
||||
TORCH_CHECK(
|
||||
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
|
||||
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
|
||||
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
|
||||
} else {
|
||||
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
|
||||
// so we can check the exact expected scale sizes here without a d2h sync.
|
||||
auto round_up = [](auto x, auto y) {
|
||||
return ((x + y - 1) / y) * y;
|
||||
};
|
||||
|
||||
// TODO: this is for 3d tensor in 2d-3d case specifically.
|
||||
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
|
||||
int64_t G = mat.size(0);
|
||||
int64_t K = mat.size(1);
|
||||
int64_t N = mat.size(2);
|
||||
int64_t blocked_scale_K = round_up(K/32, 4);
|
||||
int64_t blocked_scale_N = round_up(N, 128);
|
||||
|
||||
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim() - 1,
|
||||
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
|
||||
);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
|
||||
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
|
||||
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
|
||||
if (using_fp8_rowwise) {
|
||||
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
|
||||
} else if (using_mxfp8) {
|
||||
_check_scales_mxfp8(mat, scale, dim, arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
std::optional<c10::ScalarType> out_dtype,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(
|
||||
(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
|
||||
(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
|
||||
"For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
|
||||
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
|
||||
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
|
||||
// MXFP8 grouped GEMM dispatching
|
||||
bool is_mx8mx8bf16 = (
|
||||
mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
|
||||
scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
|
||||
);
|
||||
#else
|
||||
bool is_mx8mx8bf16 = false;
|
||||
#endif
|
||||
|
||||
if (is_mx8mx8bf16) {
|
||||
// Note: Passing implied SwizzleType here, correctness of scale previously checked
|
||||
// in `check_scale` call
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
scale_b,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
|
||||
// If we're not MXFP8, then we're row-wise scaling.
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda_v2(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
if (contraction_dim.size() > 0) {
|
||||
const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
|
||||
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
|
||||
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
|
||||
mat_b.size(dim_b));
|
||||
// Note: only (-1, -2) is currently supported
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
|
||||
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
|
||||
"No gemm implementation was found");
|
||||
|
||||
switch (gemm_impl) {
|
||||
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
|
||||
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
swizzle_a_enum[0],
|
||||
scale_b[0],
|
||||
swizzle_b_enum[0],
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
std::optional<c10::ScalarType> out_dtype) {
|
||||
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
|
||||
bool a_b_and_out_are_bf16 = (
|
||||
mat_a.dtype() == at::kBFloat16 &&
|
||||
mat_b.dtype() == at::kBFloat16 &&
|
||||
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
|
||||
);
|
||||
#ifndef USE_ROCM
|
||||
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
|
||||
#else
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
if (use_fast_path) {
|
||||
// fast path, no d2h sync needed
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
}
|
||||
return out;
|
||||
}
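When the fast path is not taken, the fallback comment above says it is "just calling typical mm/bmm". A conceptual, plain-C++ sketch of what that reduces to for a 2d-3d grouped case, assuming cumulative row offsets into the 2-D operand (names, layout, and semantics here are assumptions for illustration, not the ATen implementation):

```cpp
#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// One ordinary GEMM per group: rows [start, end) of A (M, K) times B[g] (K, N).
void grouped_mm_fallback_2d3d(const Matrix& A,
                              const std::vector<Matrix>& B,
                              const std::vector<int32_t>& offs,  // cumulative row ends into A
                              Matrix& out) {                     // preallocated (M, N)
  int32_t start = 0;
  for (size_t g = 0; g < B.size(); ++g) {
    const int32_t end = offs[g];
    for (int32_t m = start; m < end; ++m)
      for (size_t n = 0; n < B[g][0].size(); ++n) {
        float acc = 0.f;
        for (size_t k = 0; k < B[g].size(); ++k)
          acc += A[m][k] * B[g][k][n];
        out[m][n] = acc;
      }
    start = end;
  }
}
```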
|
||||
|
||||
} // namespace at::native
|
||||
@ -6,7 +6,7 @@
#endif

// ROCm 6.3 is planned to have these functions, but until then here they are.
#if defined(USE_ROCM) && ROCM_VERSION >= 60201
#if defined(USE_ROCM)
#include <device_functions.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
@ -115,9 +115,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
    index_t index,
    const index_t numel,
    scalar_t value) {
#if ( \
    (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::Half*>(tensor) + index,
      static_cast<at::Half>(value));
@ -160,9 +158,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
    index_t index,
    const index_t numel,
    scalar_t value) {
#if ( \
    (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::BFloat16*>(tensor) + index,
      static_cast<at::BFloat16>(value));

@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
|
||||
// first the reductions each thread does separately
|
||||
scalar_t sum = static_cast<scalar_t>(0);
|
||||
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int UNRL = 4; // load unroll factor
|
||||
scalar_t tmp[UNRL];
|
||||
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
tmp[u] = op(batch, plane, std::min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
if (x+u*blockDim.x < tensor.size(2))
|
||||
sum += tmp[u];
|
||||
}
|
||||
#else
|
||||
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
|
||||
sum += op(batch, plane, x);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
__shared__ scalar_t shared[C10_WARP_SIZE];
|
||||
SumReduceOp<scalar_t> reduce_op;
|
||||
@ -292,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel(
  stat_accscalar_t var_n = 0;
  int n = 0;
  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
#if defined(USE_ROCM)
    constexpr int UNRL = 4;
    stat_accscalar_t v_[UNRL];
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
      for (int u = 0; u < UNRL; u++)
        v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
      for (int u = 0; u < UNRL; u++) {
        if (x+u*blockDim.x < input.size(2)) {
          stat_accscalar_t d1 = v_[u] - avg;
          n++;
          avg += d1 / n;
          var_n += d1 * (v_[u] - avg);
        }
      }
    }
#else
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
      stat_accscalar_t v = input[batch][plane][x];
      stat_accscalar_t d1 = v - avg;
@ -299,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel(
      avg += d1 / n;
      var_n += d1 * (v - avg);
    }
#endif
  }
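// Illustrative sketch (not from the PyTorch sources): the avg/var_n updates above are
// Welford's online mean/variance algorithm. A minimal scalar C++ version of the same
// recurrence, for reference:
#include <cstddef>
#include <utility>

inline std::pair<double, double> welford_mean_var(const double* data, std::size_t len) {
  double avg = 0.0;    // running mean
  double var_n = 0.0;  // running sum of squared deviations (often called M2)
  std::size_t n = 0;
  for (std::size_t i = 0; i < len; ++i) {
    const double d1 = data[i] - avg;
    ++n;
    avg += d1 / n;
    var_n += d1 * (data[i] - avg);  // note: uses the *updated* mean, exactly as in the kernel
  }
  const double var = (n > 0) ? var_n / n : 0.0;  // biased (population) variance
  return {avg, var};
}
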
|
||||
|
||||
// first warpSum to get one value per thread to
|
||||
|
||||
@ -43,6 +43,12 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
|
||||
TORCH_CHECK(k >= 1 && k <= slicesize,
|
||||
"kthvalue(): selected number k out of range for dimension ", dim);
|
||||
|
||||
TORCH_CHECK(
|
||||
slicesize <= std::numeric_limits<int32_t>::max(),
|
||||
"kthvalue(): dimension ", dim, " is too large (", slicesize,
|
||||
"). The current CUDA implementation supports dimension sizes up to ",
|
||||
std::numeric_limits<int32_t>::max());
|
||||
|
||||
at::assert_no_overlap(self, values);
|
||||
|
||||
_reduction_with_indices_allocate_or_resize_output(
|
||||
@ -163,10 +169,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
|
||||
bool keepdim,
|
||||
Tensor& values,
|
||||
Tensor& indices) {
|
||||
// See note [Writing Nondeterministic Operations]
|
||||
// If there are duplicate elements of the kth value, the procedure for choosing which
|
||||
// of the duplicates to use for the indices output is nondeterministic.
|
||||
at::globalContext().alertNotDeterministic("kthvalue CUDA");
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
// `kthvalue_out_impl_cuda` expects contiguous in input `self`.
|
||||
|
||||
@ -65,25 +65,34 @@ __global__ void gatherKthValue(
|
||||
&kValue);
|
||||
|
||||
// Find the index of the k-th highest element
|
||||
index_t kValueIndex = 0;
|
||||
bool foundKValue = false;
|
||||
__shared__ int32_t minIndexFound;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
minIndexFound = static_cast<int32_t>(inputSliceSize);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
|
||||
bool inRange = (i < inputSliceSize);
|
||||
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
|
||||
: static_cast<scalar_t>(0);
|
||||
bool isKValue = inRange &&
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
if (isKValue) {
|
||||
kValueIndex = i;
|
||||
foundKValue = true;
|
||||
break;
|
||||
}
|
||||
// Early exit based on best-so-far
|
||||
if (i >= minIndexFound) {
|
||||
break;
|
||||
}
|
||||
|
||||
scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
|
||||
bool isKValue =
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
|
||||
if (isKValue) {
|
||||
atomicMin(&minIndexFound, static_cast<int32_t>(i));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundKValue) {
|
||||
kthValueSliceStart[0] = kValue;
|
||||
indicesSliceStart[0] = kValueIndex;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
indicesSliceStart[0] = static_cast<index_t>(minIndexFound);
|
||||
kthValueSliceStart[0] = kValue;
|
||||
}
|
||||
}
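// Illustrative sketch (not from the PyTorch sources): the change above makes the choice of
// index deterministic when the k-th value occurs multiple times, by having every matching
// thread publish its index through atomicMin so the smallest index always wins. A minimal
// standalone kernel showing the same "smallest matching index" pattern (hypothetical names):
__global__ void first_match_index(const float* data, int n, float target, int* result) {
  __shared__ int min_idx;
  if (threadIdx.x == 0) {
    min_idx = n;  // sentinel meaning "no match found yet"
  }
  __syncthreads();
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    if (i >= min_idx) {
      break;  // the best-so-far index is already smaller than anything this thread can find
    }
    if (data[i] == target) {
      atomicMin(&min_idx, i);
      break;
    }
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    *result = min_idx;  // stays n if target was not present
  }
}
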
|
||||
|
||||
|
||||
@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
    int input_pos,
    accscalar_t scale,
    int output_size,
    bool align_corners,
    int& min_output,
    int& max_output) {
  accscalar_t lo, hi;
  if (align_corners) {
    lo = static_cast<accscalar_t>(input_pos - 1) / scale;
    hi = static_cast<accscalar_t>(input_pos + 1) / scale;
  } else {
    lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
    hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
  }
  min_output = max(0, static_cast<int>(std::ceil(lo)));
  max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
}
#endif
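// Illustrative note (not from the PyTorch sources): for the non-align_corners case the
// forward mapping is src = (dst + 0.5) * scale - 0.5, and an input pixel p receives weight
// from any output pixel whose src falls in (p - 1, p + 1); solving for dst at src = p - 1
// and src = p + 1 gives exactly the lo/hi bounds used above. A host-side sketch with a
// worked example (hypothetical helper name):
#include <algorithm>
#include <cmath>

inline void output_range_no_align_corners(int p, float scale, int out_size, int& lo_i, int& hi_i) {
  const float lo = (p - 0.5f) / scale - 0.5f;  // dst whose src lands on p - 1
  const float hi = (p + 1.5f) / scale - 0.5f;  // dst whose src lands on p + 1
  lo_i = std::max(0, static_cast<int>(std::ceil(lo)));
  hi_i = std::min(out_size - 1, static_cast<int>(std::floor(hi)));
}
// Example: p = 3, scale = 0.5 (2x upsampling, out_size = 16) gives lo_i = 5, hi_i = 8;
// output pixels 5..8 are the only ones whose bilinear footprint touches input row 3.
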
|
||||
|
||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
C10_LAUNCH_BOUNDS_1(1024)
|
||||
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
const bool align_corners,
|
||||
scalar_t* __restrict__ idata,
|
||||
const scalar_t* __restrict__ odata) {
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
|
||||
const size_t i_numel = nc * width1 * height1;
|
||||
#ifdef USE_ROCM
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
// Decode input pixel coordinates
|
||||
size_t index_temp = index;
|
||||
const int w1 = index_temp % width1;
|
||||
index_temp /= width1;
|
||||
const int h1 = index_temp % height1;
|
||||
const size_t nc_idx = index_temp / height1;
|
||||
|
||||
accscalar_t grad_sum = 0;
|
||||
|
||||
// Find range of output pixels that could interpolate from this input pixel
|
||||
int h2_min, h2_max, w2_min, w2_max;
|
||||
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
|
||||
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
|
||||
|
||||
// Iterate over potential output pixels
|
||||
for (int h2 = h2_min; h2 <= h2_max; h2++) {
|
||||
for (int w2 = w2_min; w2 <= w2_max; w2++) {
|
||||
// Compute source coordinates for this output pixel
|
||||
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rheight, h2, align_corners, /*cubic=*/false);
|
||||
const int h1_base = (int)h1r;
|
||||
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
|
||||
const accscalar_t h1lambda = h1r - h1_base;
|
||||
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
|
||||
|
||||
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rwidth, w2, align_corners, /*cubic=*/false);
|
||||
const int w1_base = (int)w1r;
|
||||
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
|
||||
const accscalar_t w1lambda = w1r - w1_base;
|
||||
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
|
||||
|
||||
// Check if our input pixel participates in this interpolation and accumulate all weights
|
||||
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
|
||||
// to the same pixel, so we need to accumulate weights from all matching positions
|
||||
accscalar_t weight = 0;
|
||||
|
||||
// Check all four interpolation positions and accumulate weights
|
||||
if (h1 == h1_base && w1 == w1_base) {
|
||||
weight += h0lambda * w0lambda; // top-left
|
||||
}
|
||||
if (h1 == h1_base && w1 == w1_base + w1p) {
|
||||
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base) {
|
||||
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
|
||||
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
|
||||
}
|
||||
|
||||
if (weight > 0) {
|
||||
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
|
||||
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write accumulated gradient (no atomics needed)
|
||||
idata[index] = static_cast<scalar_t>(grad_sum);
|
||||
}
|
||||
#else
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
size_t index_temp = index;
|
||||
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
|
||||
true);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
// threads are not covering the whole input tensor.
|
||||
grad_input.zero_();
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
const int num_threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
constexpr bool use_input = true;
|
||||
#else
|
||||
constexpr bool use_input = false;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half, at::ScalarType::BFloat16,
|
||||
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
|
||||
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
|
||||
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
|
||||
input_height,
|
||||
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
|
||||
|
||||
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
|
||||
num_threads,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
@ -133,7 +133,7 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
|
||||
#define CDNA2_OR_LATER 0
|
||||
#endif
|
||||
|
||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
// TODO: Support RDNA
|
||||
@ -1161,7 +1161,7 @@ at::Tensor _weight_int4pack_mm_cuda(
|
||||
auto C_final = at::empty(
|
||||
{m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device()));
|
||||
|
||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
|
||||
do { \
|
||||
@ -1327,7 +1327,7 @@ at::Tensor _convert_weight_to_int4pack_cuda(
|
||||
{nTilesTensor, kSuperTiles, 32, innerKTiles / 2},
|
||||
at::TensorOptions().dtype(at::kInt).device(in.device()));
|
||||
|
||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
dim3 grid(kSuperTiles, nTiles);
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
#include <ATen/native/xpu/Blas.h>
|
||||
#include <torch/library.h>
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
|
||||
@ -50,9 +51,13 @@ Tensor& addmm_out(
|
||||
mat1.dtype(),
|
||||
" != ",
|
||||
mat2.dtype())
|
||||
|
||||
// complex case
|
||||
TORCH_CHECK(
|
||||
!mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
|
||||
if (self.is_complex()) {
|
||||
at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
|
||||
result.resize_(result_shape);
|
||||
@ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
|
||||
if (self.is_complex()) {
|
||||
at::native::mm_complex_out_xpu(self, mat2, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
|
||||
return result;
|
||||
@ -208,9 +216,12 @@ Tensor& baddbmm_out(
|
||||
input.sizes());
|
||||
|
||||
// complex case
|
||||
TORCH_CHECK(
|
||||
!batch1.is_complex(),
|
||||
"Complex datatype matmul is not supported in oneDNN");
|
||||
if (input.is_complex()) {
|
||||
at::native::baddbmm_complex_out_xpu(
|
||||
input, batch1, batch2, beta, alpha, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// general case
|
||||
onednn::Attr attr;
|
||||
@ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
|
||||
// complex case
|
||||
if (self.is_complex()) {
|
||||
at::native::bmm_complex_out_xpu(self, batch2, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -222,6 +222,13 @@ struct nextafter_functor {
  }
};

struct hypot_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b));
  }
};
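// Illustrative sketch (not from the PyTorch sources): the functor above computes
// sqrt(a*a + b*b) directly in float, which can overflow for large inputs. The standard
// scaled formulation that avoids intermediate overflow, in plain C++ for reference:
#include <cmath>

inline float hypot_scaled(float a, float b) {
  a = std::fabs(a);
  b = std::fabs(b);
  const float m = std::fmax(a, b);
  if (m == 0.0f) {
    return 0.0f;
  }
  const float r = std::fmin(a, b) / m;  // r <= 1, so r * r cannot overflow
  return m * std::sqrt(1.0f + r * r);
}
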
|
||||
|
||||
// Complex binary functors
|
||||
struct polar_functor {
|
||||
template <typename U>
|
||||
@ -362,6 +369,7 @@ struct igammac_functor {
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
|
||||
|
||||
REGISTER_FLOAT_BINARY_OP(hypot);
|
||||
REGISTER_FLOAT_BINARY_OP(copysign);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(copysign);
|
||||
REGISTER_FLOAT_BINARY_OP(fmax);
|
||||
|
||||
aten/src/ATen/native/mps/kernels/LinearAlgebra.h (new file, 16 lines)
@ -0,0 +1,16 @@
#pragma once
#include <c10/metal/common.h>

template <unsigned N = c10::metal::max_ndim>
struct OrgqrParams {
  int32_t num_batch_dims;

  uint32_t m;
  uint32_t n;
  uint32_t k;

  ::c10::metal::array<uint32_t, N> A_strides;
  ::c10::metal::array<uint32_t, N> tau_strides;
  ::c10::metal::array<uint32_t, N> H_strides;
  ::c10::metal::array<uint32_t, N> H_sizes;
};
|
||||
@ -1,3 +1,4 @@
|
||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
|
||||
#include <c10/metal/utils.h>
|
||||
#include <metal_array>
|
||||
#include <metal_simdgroup>
|
||||
@ -640,6 +641,164 @@ kernel void applyPivots(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T bool_to_float(bool b) {
|
||||
return static_cast<T>(b);
|
||||
}
|
||||
|
||||
template <>
|
||||
half2 bool_to_float(bool b) {
|
||||
return half2(b ? 1 : 0, 0);
|
||||
}
|
||||
|
||||
template <>
|
||||
float2 bool_to_float(bool b) {
|
||||
return float2(b ? 1 : 0, 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T calc_H_irc(
|
||||
device T* A,
|
||||
uint32_t A_stride_r,
|
||||
uint32_t A_stride_c,
|
||||
constant T* tau,
|
||||
uint32_t tau_stride,
|
||||
uint32_t r,
|
||||
uint32_t c,
|
||||
uint32_t i) {
|
||||
T I_val = bool_to_float<T>(r == c);
|
||||
T tau_val = tau[i * tau_stride];
|
||||
|
||||
T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]);
|
||||
T A_ri = A[r * A_stride_r + i * A_stride_c];
|
||||
|
||||
T c_eq_i = bool_to_float<T>(c == i);
|
||||
T r_eq_i = bool_to_float<T>(r == i);
|
||||
|
||||
T A_ci_ = (c > i) ? A_ci : c_eq_i;
|
||||
T A_ri_ = (r > i) ? A_ri : r_eq_i;
|
||||
|
||||
return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_));
|
||||
}
|
||||
|
||||
// Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the
// result of matrix multiplying A and B together. A and B must be size m-by-m
// and have the same strides. The formula for this operation, written in Python
// syntax, is:
//   (A @ B)[r, c] = A[r, :].dot(B[:, c])
template <typename T>
static T calc_matmul_rc(
    device T* A,
    device T* B,
    uint32_t stride_r,
    uint32_t stride_c,
    uint32_t m,
    uint32_t r,
    uint32_t c) {
  T AB_rc = 0;
  auto A_row_offset = r * stride_r;
  auto B_col_offset = c * stride_c;

  uint32_t A_col_offset = 0;
  uint32_t B_row_offset = 0;

  for (uint32_t j = 0; j < m;
       j++, A_col_offset += stride_c, B_row_offset += stride_r) {
    AB_rc += c10::metal::mul(
        A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]);
  }
  return AB_rc;
}
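// Illustrative sketch (not from the PyTorch sources): calc_H_irc above evaluates, element
// by element, the Householder reflector H_i = I - tau_i * v_i * v_i^H, where v_i is column
// i of A with v_i[i] = 1 and v_i[0..i-1] = 0 (the usual geqrf storage convention); the
// orgqr kernel below then uses calc_matmul_rc to fold each H_i into the running product
// H_0 * ... * H_i. The same element formula for a real, row-major matrix in plain C++:
#include <vector>

inline float householder_elem(const std::vector<float>& A, int lda,
                              const std::vector<float>& tau,
                              int r, int c, int i) {
  const float I_rc = (r == c) ? 1.0f : 0.0f;
  const float v_r = (r > i) ? A[r * lda + i] : (r == i ? 1.0f : 0.0f);
  const float v_c = (c > i) ? A[c * lda + i] : (c == i ? 1.0f : 0.0f);
  return I_rc - tau[i] * v_r * v_c;  // real case, so the conj() in the kernel is a no-op
}
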
|
||||
|
||||
template <typename T>
|
||||
kernel void orgqr(
|
||||
device T* A [[buffer(0)]],
|
||||
constant T* tau [[buffer(1)]],
|
||||
device T* H [[buffer(2)]],
|
||||
device T* H_prod [[buffer(3)]],
|
||||
constant OrgqrParams<>& params [[buffer(4)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
constant auto& A_strides = params.A_strides;
|
||||
constant auto& tau_strides = params.tau_strides;
|
||||
constant auto& H_strides = params.H_strides;
|
||||
constant auto& H_sizes = params.H_sizes;
|
||||
|
||||
auto num_batch_dims = params.num_batch_dims;
|
||||
auto m = params.m;
|
||||
auto n = params.n;
|
||||
auto k = params.k;
|
||||
|
||||
auto m2 = m * m;
|
||||
auto batch_idx = tid / m2;
|
||||
|
||||
// Find the matrices for this thread's batch index
|
||||
uint32_t A_offset = 0;
|
||||
uint32_t tau_offset = 0;
|
||||
uint32_t H_offset = 0;
|
||||
|
||||
for (auto dim = num_batch_dims - 1; dim >= 0; dim--) {
|
||||
auto dim_size = H_sizes[dim];
|
||||
auto dim_idx = batch_idx % dim_size;
|
||||
|
||||
A_offset += dim_idx * A_strides[dim];
|
||||
tau_offset += dim_idx * tau_strides[dim];
|
||||
H_offset += dim_idx * H_strides[dim];
|
||||
|
||||
batch_idx /= dim_size;
|
||||
}
|
||||
|
||||
A += A_offset;
|
||||
tau += tau_offset;
|
||||
H += H_offset;
|
||||
H_prod += H_offset;
|
||||
|
||||
auto matrix_idx = tid % m2;
|
||||
auto r = matrix_idx / m;
|
||||
auto c = matrix_idx % m;
|
||||
auto A_stride_r = A_strides[num_batch_dims];
|
||||
auto A_stride_c = A_strides[num_batch_dims + 1];
|
||||
auto tau_stride = tau_strides[num_batch_dims];
|
||||
auto H_stride_r = H_strides[num_batch_dims];
|
||||
auto H_stride_c = H_strides[num_batch_dims + 1];
|
||||
|
||||
// Find the element of H and H_prod that this thread will calculate
|
||||
device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c);
|
||||
device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c);
|
||||
|
||||
for (uint32_t i = 0; i < k; i++) {
|
||||
// Calculate and write H_i
|
||||
|
||||
T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i);
|
||||
|
||||
// Calculate element [r, c] of prod(H_0, ..., H_i)
|
||||
if (i == 0) {
|
||||
*H_prod_elem_ptr = H_irc;
|
||||
} else {
|
||||
*H_elem_ptr = H_irc;
|
||||
|
||||
// Need this sync because the below matmul requires all threads to finish
|
||||
// writing their entries to `H_prod` and `H`.
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
T H_prod_0_to_i_rc =
|
||||
calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);
|
||||
|
||||
// Need this sync because the above matmul uses the current values in
|
||||
// `H_prod`, and we don't want to overwrite those until all threads are
|
||||
// finished using them.
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
*H_prod_elem_ptr = H_prod_0_to_i_rc;
|
||||
}
|
||||
}
|
||||
|
||||
device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c);
|
||||
|
||||
if (c < n) {
|
||||
*A_elem_ptr = *H_prod_elem_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MM_OPS(DTYPE) \
|
||||
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
|
||||
constant DTYPE * mat1Data [[buffer(0)]], \
|
||||
@ -679,3 +838,19 @@ INSTANTIATE_MM_OPS(int);
|
||||
INSTANTIATE_MM_OPS(short);
|
||||
INSTANTIATE_MM_OPS(char);
|
||||
INSTANTIATE_MM_OPS(uchar);
|
||||
|
||||
#define REGISTER_ORGQR(T) \
|
||||
template [[host_name("orgqr_" #T)]] \
|
||||
kernel void orgqr<T>( \
|
||||
device T * A [[buffer(0)]], \
|
||||
constant T * tau [[buffer(1)]], \
|
||||
device T * H [[buffer(2)]], \
|
||||
device T * H_prod [[buffer(3)]], \
|
||||
constant OrgqrParams<> & params [[buffer(4)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
REGISTER_ORGQR(float);
|
||||
REGISTER_ORGQR(half);
|
||||
REGISTER_ORGQR(bfloat);
|
||||
REGISTER_ORGQR(float2);
|
||||
REGISTER_ORGQR(half2);
|
||||
|
||||
@ -5,6 +5,21 @@
using namespace metal;
using namespace c10::metal;

struct angle_functor {
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    return T(atan2(x.y, x.x), 0);
  }
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    return T(isnan(x) ? x : x < 0 ? M_PI_F : 0.0);
  }
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return x < 0 ? M_PI_F : 0.0;
  }
};
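// Illustrative check (not from the PyTorch sources): angle() returns the phase of its
// input, i.e. atan2(imag, real) for complex values, pi for negative reals, 0 for
// non-negative reals, with NaN propagated. The same semantics in plain C++ for reference:
#include <cmath>
#include <complex>
#include <cstdio>

inline void angle_examples() {
  std::printf("%f\n", std::arg(std::complex<float>(0.f, 1.f)));   // ~1.5708 (pi / 2)
  std::printf("%f\n", std::arg(std::complex<float>(-1.f, 0.f)));  // ~3.1416 (pi)
  std::printf("%f\n", std::atan2(0.f, -2.f));                     // ~3.1416: negative real -> pi
  std::printf("%f\n", std::atan2(0.f, 3.f));                      // 0: non-negative real -> 0
}
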
|
||||
|
||||
// Implement exp wrapper for both real and complex types
|
||||
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
|
||||
inline T exp_(const T x) {
|
||||
@ -545,6 +560,7 @@ REGISTER_UNARY_OP(abs, float, float);
|
||||
REGISTER_UNARY_OP(abs, half, half);
|
||||
|
||||
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
|
||||
REGISTER_UNARY_OP(angle, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0); \
|
||||
@ -583,6 +599,7 @@ INSTANTIATE_UNARY_KERNELS2(float, int);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, long);
|
||||
|
||||
#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE) \
|
||||
REGISTER_UNARY_OP(angle, DTYPE##2, DTYPE##2); \
|
||||
REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2); \
|
||||
REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2); \
|
||||
REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2); \
|
||||
|
||||
@ -202,6 +202,10 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "igammac");
|
||||
}
|
||||
|
||||
static void hypot_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "hypot");
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
|
||||
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
|
||||
REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel)
|
||||
@ -229,4 +233,5 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
|
||||
REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
|
||||
REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel)
|
||||
REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel)
|
||||
REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel)
|
||||
} // namespace at::native
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
#include <ATen/ops/eq_native.h>
|
||||
#include <ATen/ops/ge_native.h>
|
||||
#include <ATen/ops/gt_native.h>
|
||||
#include <ATen/ops/hypot_native.h>
|
||||
#include <ATen/ops/le_native.h>
|
||||
#include <ATen/ops/logaddexp2_native.h>
|
||||
#include <ATen/ops/logaddexp_native.h>
|
||||
@ -278,22 +277,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
|
||||
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
|
||||
secondaryTensor:twoTensor
|
||||
name:nil]
|
||||
secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor
|
||||
secondaryTensor:twoTensor
|
||||
name:nil]
|
||||
name:nil];
|
||||
return [mpsGraph squareRootWithTensor:sumTensor name:nil];
|
||||
};
|
||||
mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
|
||||
@ -8,6 +8,9 @@
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -28,6 +31,7 @@
|
||||
#include <ATen/ops/linalg_solve_triangular_native.h>
|
||||
#include <ATen/ops/lu_unpack_native.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/orgqr_native.h>
|
||||
#include <ATen/ops/slice.h>
|
||||
#include <ATen/ops/stack.h>
|
||||
#include <ATen/ops/triangular_solve_native.h>
|
||||
@ -1235,6 +1239,69 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper
|
||||
}
|
||||
}
|
||||
|
||||
static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) {
|
||||
if (self.numel() == 0) {
|
||||
return self;
|
||||
}
|
||||
|
||||
auto m = self.size(-2);
|
||||
auto n = self.size(-1);
|
||||
auto k = tau.size(-1);
|
||||
|
||||
if (tau.numel() == 0) {
|
||||
auto I = eye(m, self.scalar_type(), std::nullopt, self.device());
|
||||
return self.copy_(I.slice(-1, 0, n));
|
||||
}
|
||||
|
||||
auto num_batch_dims = self.dim() - 2;
|
||||
auto batch_sizes = self.sizes().slice(0, num_batch_dims);
|
||||
|
||||
std::vector<int64_t> H_sizes(num_batch_dims + 2);
|
||||
for (auto dim : c10::irange(num_batch_dims)) {
|
||||
H_sizes[dim] = self.size(dim);
|
||||
}
|
||||
H_sizes[num_batch_dims] = m;
|
||||
H_sizes[num_batch_dims + 1] = m;
|
||||
|
||||
auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous));
|
||||
auto H_prod = at::empty_like(H);
|
||||
|
||||
OrgqrParams params;
|
||||
|
||||
params.num_batch_dims = num_batch_dims;
|
||||
params.m = m;
|
||||
params.n = n;
|
||||
params.k = k;
|
||||
|
||||
for (const auto dim : c10::irange(self.dim())) {
|
||||
params.A_strides[dim] = self.stride(dim);
|
||||
|
||||
if (dim < tau.dim()) {
|
||||
params.tau_strides[dim] = tau.stride(dim);
|
||||
}
|
||||
|
||||
params.H_strides[dim] = H.stride(dim);
|
||||
params.H_sizes[dim] = H.size(dim);
|
||||
}
|
||||
|
||||
auto num_threads = H.numel();
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder();
|
||||
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self)));
|
||||
getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau});
|
||||
[compute_encoder setComputePipelineState:pipeline_state];
|
||||
mtl_setArgs(compute_encoder, self, tau, H, H_prod, params);
|
||||
mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads);
|
||||
getMPSProfiler().endProfileKernel(pipeline_state);
|
||||
}
|
||||
});
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
|
||||
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
|
||||
@ -1471,4 +1538,6 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
|
||||
REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl);
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -34,6 +34,7 @@ REGISTER_UNARY_TI_DISPATCH(sinc);
|
||||
REGISTER_UNARY_TI_DISPATCH(sinh);
|
||||
REGISTER_UNARY_TI_DISPATCH(cosh);
|
||||
REGISTER_UNARY_TI_DISPATCH(tanh);
|
||||
REGISTER_UNARY_TI_DISPATCH(angle);
|
||||
REGISTER_UNARY_TI_DISPATCH(abs);
|
||||
REGISTER_UNARY_TI_DISPATCH(sin);
|
||||
REGISTER_UNARY_TI_DISPATCH(cos);
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
#include <ATen/ops/_copy_from_and_resize.h>
|
||||
#include <ATen/ops/acos_native.h>
|
||||
#include <ATen/ops/acosh_native.h>
|
||||
#include <ATen/ops/angle_native.h>
|
||||
#include <ATen/ops/asin_native.h>
|
||||
#include <ATen/ops/asinh_native.h>
|
||||
#include <ATen/ops/atan_native.h>
|
||||
@ -204,23 +203,6 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
|
||||
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
|
||||
auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
|
||||
return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor angle_mps(const Tensor& self) {
|
||||
const auto float_type = c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)
|
||||
? c10::typeMetaToScalarType(c10::get_default_dtype())
|
||||
: c10::toRealValueType(self.scalar_type());
|
||||
Tensor result = at::empty({0}, self.options().dtype(float_type));
|
||||
return angle_out_mps(self, result);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types");
|
||||
mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
|
||||
@ -403,16 +403,14 @@
|
||||
device_check: NoCheck # TensorIterator
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU, CUDA: angle
|
||||
MPS: angle_mps
|
||||
CPU, CUDA, MPS: angle
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr
|
||||
tags: pointwise
|
||||
|
||||
- func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
dispatch:
|
||||
CPU, CUDA: angle_out
|
||||
MPS: angle_out_mps
|
||||
CPU, CUDA, MPS: angle_out
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr_out
|
||||
tags: pointwise
|
||||
|
||||
@ -10042,8 +10040,7 @@
|
||||
structured: True
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA: hypot_out
|
||||
MPS: hypot_out_mps
|
||||
CPU, CUDA, MPS: hypot_out
|
||||
tags: pointwise
|
||||
|
||||
- func: hypot(Tensor self, Tensor other) -> Tensor
|
||||
@ -14362,12 +14359,12 @@
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA: linalg_householder_product
|
||||
CPU, CUDA, MPS: linalg_householder_product
|
||||
|
||||
- func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: linalg
|
||||
dispatch:
|
||||
CPU, CUDA: linalg_householder_product_out
|
||||
CPU, CUDA, MPS: linalg_householder_product_out
|
||||
|
||||
- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
|
||||
python_module: linalg
|
||||
|
||||
@ -575,24 +575,9 @@ void spmm(
|
||||
cusparseOperation_t opB = transpose_B ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
// CUDA < 11.0 doesn't support 64-bit indices and doesn't raise an error about this
|
||||
// silently returning incorrect results
|
||||
#if defined(USE_ROCM) && (ROCM_VERSION < 60300)
|
||||
auto mat1_32 = at::native::_sparse_csr_tensor_unsafe(
|
||||
mat1.crow_indices().to(kInt),
|
||||
mat1.col_indices().to(kInt),
|
||||
mat1.values(),
|
||||
mat1.sizes(),
|
||||
mat1.scalar_type(),
|
||||
mat1.layout(),
|
||||
mat1.device());
|
||||
auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1_32);
|
||||
auto algorithm = CUSPARSE_MM_ALG_DEFAULT;
|
||||
#else // defined(USE_ROCM) && (ROCM_VERSION < 60300)
|
||||
// TODO: update this to support COO sparse layout
|
||||
auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1);
|
||||
auto algorithm = CUSPARSE_SPMM_CSR_ALG2;
|
||||
#endif // defined(USE_ROCM) && (ROCM_VERSION < 60300)
|
||||
|
||||
auto descB = at::cuda::sparse::CuSparseConstDnMatDescriptor(
|
||||
transpose_B ? mat2_->mT() : *mat2_);
|
||||
|
||||
@ -33,7 +33,7 @@ using namespace mps;
|
||||
#ifndef PYTORCH_JIT_COMPILE_SHADERS
|
||||
static auto& lib = MetalShaderLibrary::getBundledLibrary();
|
||||
#else
|
||||
#include <ATen/native/mps/Mul_metallib.h>
|
||||
#include <ATen/native/mps/SparseTensorMath_metallib.h>
|
||||
#endif
|
||||
|
||||
static Tensor& s_addmm_out_sparse_dense_mps(
|
||||
@ -369,12 +369,7 @@ static SparseTensor& mul_out_dense_sparse_mps(
|
||||
}
|
||||
|
||||
if (scalar_like) {
|
||||
auto scalar = dense;
|
||||
if (dense.numel() == 1 && dense.dim() > 0) {
|
||||
scalar = dense.view({});
|
||||
}
|
||||
scalar = scalar.to(values.options());
|
||||
auto out_vals = values.mul(scalar);
|
||||
auto out_vals = values.mul(dense.to(values.options()));
|
||||
if (out.scalar_type() != commonDtype) {
|
||||
out_vals = out_vals.to(out.scalar_type());
|
||||
}
|
||||
@ -508,14 +503,14 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
|
||||
const auto device = r_.device();
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
auto lhs_indices = lhs._indices();
|
||||
auto rhs_indices = rhs._indices();
|
||||
auto lhs_values = lhs._values().to(commonDtype);
|
||||
auto rhs_values = rhs._values().to(commonDtype);
|
||||
auto lhs_indices = lhs._indices().contiguous();
|
||||
auto rhs_indices = rhs._indices().contiguous();
|
||||
auto lhs_values = lhs._values().to(commonDtype).contiguous();
|
||||
auto rhs_values = rhs._values().to(commonDtype).contiguous();
|
||||
|
||||
// Flatten sparse indices to keys
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes());
|
||||
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes());
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i));
|
||||
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i));
|
||||
|
||||
// Intersect sorted keys (search the shorter in the longer)
|
||||
const bool A_is_lhs = (lhs_nnz <= rhs_nnz);
|
||||
@ -546,35 +541,54 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
|
||||
auto out_indices = at::empty({ndim_i, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong));
|
||||
auto lhs_match = outA_idx.narrow(0, 0, M);
|
||||
auto rhs_match = outB_idx.narrow(0, 0, M);
|
||||
auto out_val_sizes = lhs_values.sizes().vec();
|
||||
out_val_sizes[0] = static_cast<int64_t>(M);
|
||||
auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec();
|
||||
int64_t cols64 = 1;
|
||||
for (auto s : dense_sizes_vec) cols64 *= s;
|
||||
const uint32_t cols = static_cast<uint32_t>(std::max<int64_t>(cols64, 1));
|
||||
|
||||
auto to2d = [&](Tensor t, int64_t nnz) -> Tensor {
|
||||
const int64_t t_cols = t.numel() / nnz;
|
||||
if (t_cols == cols64) {
|
||||
return t.view({nnz, cols64});
|
||||
}
|
||||
return t.view({nnz, 1}).expand({nnz, cols64}).contiguous();
|
||||
};
|
||||
|
||||
// make both sides 2d [nnz, cols] buffers so the kernel can index it
|
||||
auto lhs_vals2d = to2d(lhs_values, lhs_nnz);
|
||||
auto rhs_vals2d = to2d(rhs_values, rhs_nnz);
|
||||
|
||||
std::vector<int64_t> out_val_sizes;
|
||||
out_val_sizes.reserve(1 + dense_sizes_vec.size());
|
||||
out_val_sizes.push_back(static_cast<int64_t>(M));
|
||||
out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end());
|
||||
auto out_values = at::empty(out_val_sizes, lhs_values.options());
|
||||
|
||||
const uint32_t cols = static_cast<uint32_t>(
|
||||
lhs_values.numel() / std::max<int64_t>(1, lhs_nnz));
|
||||
if (M > 0) {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc(
|
||||
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc(
|
||||
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
const uint32_t tew = pso.threadExecutionWidth;
|
||||
const uint32_t gridW = std::max<uint32_t>(cols, 1u);
|
||||
const uint32_t tgW = std::min(gridW, tew);
|
||||
MTLSize grid = MTLSizeMake(gridW, 1, M);
|
||||
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
|
||||
|
||||
const uint32_t tew = pso.threadExecutionWidth;
|
||||
uint32_t tgW = std::min(cols, tew);
|
||||
MTLSize grid = MTLSizeMake(cols, 1, M);
|
||||
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
|
||||
|
||||
mtl_setArgs(enc,
|
||||
lhs_values, rhs_values,
|
||||
lhs_match, rhs_match,
|
||||
lhs_indices, out_indices,
|
||||
out_values,
|
||||
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
|
||||
std::array<uint32_t, 2>{M, cols});
|
||||
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
|
||||
}
|
||||
});
|
||||
mtl_setArgs(enc,
|
||||
lhs_vals2d, rhs_vals2d,
|
||||
lhs_match, rhs_match,
|
||||
lhs_indices, out_indices,
|
||||
out_values,
|
||||
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
|
||||
std::array<uint32_t, 2>{M, cols});
|
||||
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (r_.scalar_type() != commonDtype) {
|
||||
out_values = out_values.to(r_.scalar_type());
|
||||
|
||||
@ -195,9 +195,9 @@ kernel void fused_gather_mul_kernel(
|
||||
const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
|
||||
const ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
|
||||
|
||||
const float a = (float)lhs_vals[offL];
|
||||
const float b = (float)rhs_vals[offR];
|
||||
out_vals[offO] = (T)(a * b);
|
||||
const auto a = static_cast<accum_t<T>>(lhs_vals[offL]);
|
||||
const auto b = static_cast<accum_t<T>>(rhs_vals[offR]);
|
||||
out_vals[offO] = static_cast<T>(mul(a, b));
|
||||
}
|
||||
|
||||
// One thread per match copies the indices column
|
||||
@ -76,14 +76,21 @@ bool priority_order_init_ = false;
|
||||
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
|
||||
// Needs to be kept in sync with test_fused_choice in test_transformers.py
|
||||
bool check_prefer_cudnn_attention() {
|
||||
static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") != false;
|
||||
static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_DEPRIORITIZED") != true;
|
||||
if (!prefer_cudnn) {
|
||||
return false;
|
||||
}
|
||||
#if (defined(CUDNN_VERSION) && (CUDNN_VERSION >= 90900))
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
auto major = dprops->major;
|
||||
return (major == 9 || major == 10) && !dprops->minor;
|
||||
try {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
auto major = dprops->major;
|
||||
return (major == 9 || major == 10) && !dprops->minor;
|
||||
} catch (c10::Error const& e) {
|
||||
#ifdef DEBUG
|
||||
TORCH_WARN("check_prefer_cudnn_attention() caught exception ", e.what());
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
|
||||
@ -37,6 +37,10 @@ TEST(SingletonOrSharedTypePtr, Comparison) {
|
||||
|
||||
EXPECT_NE(empty, p);
|
||||
EXPECT_NE(p, p2);
|
||||
|
||||
EXPECT_EQ(empty, empty);
|
||||
EXPECT_EQ(p, p);
|
||||
EXPECT_EQ(p2, p2);
|
||||
}
|
||||
|
||||
TEST(SingletonOrSharedTypePtr, SingletonComparison) {
|
||||
@ -47,6 +51,8 @@ TEST(SingletonOrSharedTypePtr, SingletonComparison) {
|
||||
c10::TypePtr type = c10::NoneType::get();
|
||||
EXPECT_NE(type, c10::StringType::get());
|
||||
EXPECT_NE(type, c10::DeviceObjType::get());
|
||||
EXPECT_EQ(type, type);
|
||||
EXPECT_EQ(type, c10::NoneType::get());
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -526,6 +526,41 @@ namespace {
|
||||
[](const vec& v) { return v.expm1(); },
|
||||
createDefaultUnaryTestCase<vec>(TestSeed(), false, true));
|
||||
}
|
||||
TYPED_TEST(Exponents, ExpU20) {
|
||||
using vec = TypeParam;
|
||||
using VT = ValueType<TypeParam>;
|
||||
using UVT = UvalueType<TypeParam>;
|
||||
|
||||
// Explicit edge values
|
||||
VT v_too_small = VT(-100.0); // much less than -87.3
|
||||
VT exp_too_small = std::exp(v_too_small);
|
||||
VT v_neg_edge = VT(-0x1.5d5e2ap+6f); // just at the edge
|
||||
VT exp_neg_edge = std::exp(v_neg_edge);
|
||||
VT v_zero = VT(0.0); // middle, normal case
|
||||
VT exp_zero = std::exp(v_zero);
|
||||
VT v_pos_edge = VT(0x1.5d5e2ap+6f); // just at the edge
|
||||
VT exp_pos_edge = std::exp(v_pos_edge);
|
||||
VT v_too_large = VT(100.0); // much more than 87.3
|
||||
VT exp_too_large = std::exp(v_too_large);
|
||||
|
||||
auto test_case = TestingCase<vec>::getBuilder()
|
||||
// Randoms in normal range, but the .addCustom() below guarantees we hit the special/fallback cases
|
||||
.addDomain(CheckWithinDomains<UVT>{{{-100, 100}}, false, getDefaultTolerance<UVT>()})
|
||||
.addCustom({ {v_too_small}, exp_too_small })
|
||||
.addCustom({ {v_neg_edge}, exp_neg_edge })
|
||||
.addCustom({ {v_zero}, exp_zero })
|
||||
.addCustom({ {v_pos_edge}, exp_pos_edge })
|
||||
.addCustom({ {v_too_large}, exp_too_large })
|
||||
.setTrialCount(65536)
|
||||
.setTestSeed(TestSeed());
|
||||
|
||||
test_unary<vec>(
|
||||
NAME_INFO(exp_u20_edge_cases),
|
||||
RESOLVE_OVERLOAD(std::exp),
|
||||
[](const vec& v) { return v.exp_u20(); },
|
||||
test_case
|
||||
);
|
||||
}
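// Illustrative note (not from the PyTorch sources): 0x1.5d5e2ap+6f is about 87.34, which
// sits at the edge of what float exp() can represent; the overflow threshold is
// ln(FLT_MAX) ~ 88.72 and ln(FLT_MIN) ~ -87.34, so the custom cases above pin exp_u20
// right at its clamping boundaries plus far-out values. Quick check in plain C++:
#include <cmath>
#include <cstdio>

int main() {
  const float edge = 0x1.5d5e2ap+6f;                   // ~87.34
  std::printf("edge       = %.4f\n", edge);
  std::printf("exp(+edge) = %g\n", std::exp(edge));    // large but finite (~8.6e37)
  std::printf("exp(-edge) = %g\n", std::exp(-edge));   // ~1.2e-38, near the normal/subnormal boundary
  std::printf("exp(+100)  = %g\n", std::exp(100.0f));  // inf (overflow)
  std::printf("exp(-100)  = %g\n", std::exp(-100.0f)); // ~3.7e-44 subnormal (0 under FTZ)
  return 0;
}
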
|
||||
TYPED_TEST(ErrorFunctions, Erf) {
|
||||
using vec = TypeParam;
|
||||
test_unary<vec>(
|
||||
|
||||
@ -58,8 +58,7 @@ def list_benchmarks():
|
||||
|
||||
def run_benchmark(
|
||||
benchmark_name: str,
|
||||
should_visualize: bool = False,
|
||||
compile_mode: str = "max-autotune-no-cudagraphs",
|
||||
script_args,
|
||||
):
|
||||
"""Run a specific benchmark."""
|
||||
if benchmark_name not in BENCHMARK_REGISTRY:
|
||||
@ -68,29 +67,29 @@ def run_benchmark(
|
||||
return False
|
||||
|
||||
print(f"Running benchmark: {benchmark_name}")
|
||||
print(f"Torch compile mode: {compile_mode}")
|
||||
print(f"Torch compile mode: {script_args.compile_mode}")
|
||||
print("=" * 60)
|
||||
|
||||
benchmark_class = BENCHMARK_REGISTRY[benchmark_name]
|
||||
benchmark = benchmark_class(compile_mode)
|
||||
benchmark = benchmark_class(script_args)
|
||||
benchmark.benchmark()
|
||||
if should_visualize:
|
||||
if script_args.visualize:
|
||||
benchmark.visualize()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def run_all_benchmarks(should_visualize: bool = False, compile_mode: str = "default"):
|
||||
def run_all_benchmarks(script_args):
|
||||
"""Run all available benchmarks."""
|
||||
print("Running all benchmarks...")
|
||||
print(f"Torch compile mode: {compile_mode}")
|
||||
print(f"Torch compile mode: {script_args.compile_mode}")
|
||||
print("=" * 60)
|
||||
|
||||
for name, cls in BENCHMARK_REGISTRY.items():
|
||||
print(f"\n{'=' * 20} {name.upper()} {'=' * 20}")
|
||||
benchmark = cls(compile_mode)
|
||||
benchmark = cls(script_args)
|
||||
benchmark.benchmark()
|
||||
if should_visualize:
|
||||
if script_args.visualize:
|
||||
benchmark.visualize()
|
||||
print()
|
||||
|
||||
@ -137,6 +136,19 @@ Examples:
|
||||
help="Torch compile mode to use (default: default)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tolerance",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Tolerance for the accuracy check",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exit-on-accuracy-failure",
|
||||
action="store_true",
|
||||
help="Whether to exit with an error message for accuracy failure",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle list option
|
||||
@ -146,7 +158,7 @@ Examples:
|
||||
|
||||
# Handle all option
|
||||
if args.all:
|
||||
run_all_benchmarks(args.visualize, args.compile_mode)
|
||||
run_all_benchmarks(args)
|
||||
return
|
||||
|
||||
# Handle specific benchmarks
|
||||
@ -157,7 +169,7 @@ Examples:
|
||||
sys.exit(1)
|
||||
|
||||
for benchmark_name in args.benchmarks:
|
||||
run_benchmark(benchmark_name, args.visualize, args.compile_mode)
|
||||
run_benchmark(benchmark_name, args)
|
||||
print() # Add spacing between benchmarks
|
||||
|
||||
|
||||
|
||||
@ -9,8 +9,8 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class CrossEntropyForward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -106,8 +106,8 @@ class CrossEntropyForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class CrossEntropyBackward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -194,8 +194,8 @@ class CrossEntropyBackward(BenchmarkKernel):
|
||||
|
||||
|
||||
class SoftmaxForward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -259,8 +259,8 @@ class SoftmaxForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class SoftmaxBackward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -329,8 +329,8 @@ class SoftmaxBackward(BenchmarkKernel):
|
||||
|
||||
|
||||
class RMSNormForward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -383,7 +383,22 @@ class RMSNormForward(BenchmarkKernel):
|
||||
from quack.rmsnorm import _rmsnorm_fwd
|
||||
|
||||
x, w = args
|
||||
return lambda: _rmsnorm_fwd(x, w, eps=1e-6)
|
||||
y = torch.empty_like(x)
|
||||
|
||||
def quack_fwd():
|
||||
_rmsnorm_fwd(
|
||||
x,
|
||||
w,
|
||||
out=y,
|
||||
bias=None,
|
||||
rstd=None,
|
||||
residual=None,
|
||||
residual_out=None,
|
||||
eps=1e-6,
|
||||
)
|
||||
return y
|
||||
|
||||
return quack_fwd
|
||||
|
||||
def liger(self, args, kwargs) -> Any:
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
@ -404,9 +419,14 @@ class RMSNormForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class RMSNormBackward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = [
|
||||
"eager",
|
||||
"compiled",
|
||||
"quack",
|
||||
"liger",
|
||||
]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
# TODO: OOM for (32768, 65536) on h100
|
||||
@ -454,8 +474,11 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
y, [x, w], grad_outputs=dy, retain_graph=True
|
||||
)
|
||||
|
||||
def compute_rstd(self, x, eps):
|
||||
return torch.rsqrt(torch.mean(x.float().square(), dim=-1, keepdim=True) + eps)
|
||||
|
||||
def quack(self, args, kwargs=None) -> Any:
|
||||
from quack.rmsnorm import _rmsnorm_backward
|
||||
from quack.rmsnorm import _get_sm_count, _rmsnorm_bwd
|
||||
|
||||
(
|
||||
x,
|
||||
@ -463,15 +486,40 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
dy,
|
||||
) = args
|
||||
M, N = x.shape
|
||||
rstd = torch.randn(M, device="cuda", dtype=torch.float32)
|
||||
return lambda: _rmsnorm_backward(x, w, dy, rstd)
|
||||
|
||||
rstd = self.compute_rstd(x, eps=1e-6)
|
||||
dx = torch.empty_like(x)
|
||||
sm_count = _get_sm_count(x.size(1), x.device)
|
||||
dw_partial = torch.empty(
|
||||
sm_count, x.size(1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
def quack_bwd():
|
||||
_rmsnorm_bwd(
|
||||
x,
|
||||
w,
|
||||
dy,
|
||||
rstd,
|
||||
dx,
|
||||
dw_partial,
|
||||
db_partial=None,
|
||||
dresidual_out=None,
|
||||
dresidual=None,
|
||||
sm_count=sm_count,
|
||||
)
|
||||
dw = dw_partial.sum(dim=0).to(w.dtype)
|
||||
return dx, dw
|
||||
|
||||
return quack_bwd
|
||||
|
||||
def liger(self, args, kwargs=None) -> Any:
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
|
||||
x, w, dy = args
|
||||
M, N = x.shape
|
||||
liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda()
|
||||
liger_rmsnorm = LigerRMSNorm(
|
||||
hidden_size=N, eps=1e-6, casting_mode="gemma"
|
||||
).cuda()
|
||||
liger_rmsnorm.weight.data.copy_(w)
|
||||
y = liger_rmsnorm(x)
|
||||
return lambda: torch.autograd.grad(
|
||||
@ -489,8 +537,8 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
|
||||
|
||||
class LayerNormForward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -563,8 +611,8 @@ class LayerNormForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class LayerNormBackward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -614,20 +662,31 @@ class LayerNormBackward(BenchmarkKernel):
|
||||
y, [x, w], grad_outputs=dy, retain_graph=True
|
||||
)
|
||||
|
||||
def compute_mean_rstd(self, x, eps):
|
||||
x = x.float()
|
||||
|
||||
var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
|
||||
rstd = torch.rsqrt(var + eps)
|
||||
return mean, rstd
|
||||
|
||||
def liger(self, args, kwargs) -> Any:
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
"""
|
||||
Call layer_norm_backward directly rather than calling
|
||||
liger_kernel.transformers.layer_norm.LigerLayerNorm and
|
||||
torch.autograd.grad.
|
||||
|
||||
The latter fashion saves mean/rstd in x.dtype which can fail
|
||||
accuracy test. We call layer_norm_backward with fp32 mean and
|
||||
rstd.
|
||||
"""
|
||||
from liger_kernel.ops.layer_norm import layer_norm_backward
|
||||
|
||||
x, w, dy = args
|
||||
eps = 1e-6
|
||||
mean, rstd = self.compute_mean_rstd(x, eps)
|
||||
M, N = x.shape
|
||||
liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda()
|
||||
liger_layernorm.weight.data.copy_(w)
|
||||
liger_layernorm.bias.data.copy_(
|
||||
torch.zeros(N, device="cuda", dtype=torch.float32)
|
||||
)
|
||||
y = liger_layernorm(x)
|
||||
return lambda: torch.autograd.grad(
|
||||
y, [x, liger_layernorm.weight], grad_outputs=dy, retain_graph=True
|
||||
)
|
||||
|
||||
return lambda: layer_norm_backward(dy, x, w, None, mean, rstd)[0:2]
|
||||
|
||||
def benchmark(self):
|
||||
for M, N in self.get_shapes():
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
@ -43,10 +44,11 @@ class Performance:
|
||||
|
||||
|
||||
class BenchmarkKernel:
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
def __init__(self, script_args):
|
||||
self.script_args = script_args
|
||||
self.name = self.__class__.__name__
|
||||
self.available_backends: list[str] = []
|
||||
self.compile_mode: str = compile_mode
|
||||
self.compile_mode: str = script_args.compile_mode
|
||||
|
||||
# mapping from backend to list of performance results
|
||||
self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list)
|
||||
@ -106,14 +108,21 @@ class BenchmarkKernel:
|
||||
args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
|
||||
res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
|
||||
gold = res["eager"]
|
||||
|
||||
tol = {}
|
||||
if self.script_args.tolerance:
|
||||
tol = {
|
||||
"atol": self.script_args.tolerance,
|
||||
"rtol": self.script_args.tolerance,
|
||||
}
|
||||
for backend in self.available_backends:
|
||||
if backend == "eager":
|
||||
continue
|
||||
try:
|
||||
torch.testing.assert_close(res[backend], gold)
|
||||
torch.testing.assert_close(res[backend], gold, **tol)
|
||||
for t, gold_t in zip(res[backend], gold):
|
||||
if t.requires_grad:
|
||||
torch.testing.assert_close(t.grad, gold_t.grad)
|
||||
torch.testing.assert_close(t.grad, gold_t.grad, **tol)
|
||||
print(
|
||||
f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel"
|
||||
)
|
||||
@ -121,6 +130,9 @@ class BenchmarkKernel:
|
||||
print(
|
||||
f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
|
||||
)
|
||||
if self.script_args.exit_on_accuracy_failure:
|
||||
print("Exit right away since --exit-on-accuracy-failure is set")
|
||||
sys.exit(1)
|
||||
|
||||
def benchmark_single_shape(
|
||||
self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
|
||||
|
||||
@ -43,6 +43,7 @@ tolerance:
|
||||
- doctr_reco_predictor
|
||||
- drq
|
||||
- phlippe_resnet
|
||||
- pytorch_CycleGAN_and_pix2pix
|
||||
|
||||
higher_bf16:
|
||||
- doctr_reco_predictor
|
||||
|
||||
@ -44,9 +44,17 @@ PyTorch,div_,div__M1_N1_K1_cpu_dtype_onetorch.float32_dtype_twotorch.float32,sho
|
||||
PyTorch,div_,div__M64_N64_K64_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.241161,0.000000
|
||||
PyTorch,div_,div__M64_N64_K128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.852816,0.000000
|
||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,57.006677,0.000000
|
||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,88.167000,0.000000
|
||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.519000,0.000000
|
||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,55.606088,0.000000
|
||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,86.551000,0.000000
|
||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.864088,0.000000
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,71.641000,0.000000
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,83.073000,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,67.570000,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.895000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
|
||||
|
||||
|
@ -25,7 +25,7 @@ binary_configs_broadcast = op_bench.config_list(
|
||||
],
|
||||
cross_product_configs={
|
||||
"device": ["cpu"],
|
||||
"dtype": [torch.float],
|
||||
"dtype": [torch.float, torch.bfloat16, torch.float64],
|
||||
},
|
||||
tags=["short"],
|
||||
)
|
||||
|
||||
@ -176,8 +176,8 @@ THIRD_PARTY_LIBS = {
|
||||
"omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
|
||||
"pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
|
||||
"psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
|
||||
"pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
|
||||
"pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
|
||||
"pthreadpool": ["fbsource//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
|
||||
"pthreadpool_header": ["fbsource//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
|
||||
"moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
|
||||
"pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
|
||||
"rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
|
||||
|
||||
@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [
|
||||
"torch/csrc/cuda/Stream.cpp",
|
||||
"torch/csrc/cuda/Graph.cpp",
|
||||
"torch/csrc/cuda/MemPool.cpp",
|
||||
"torch/csrc/cuda/GreenContext.cpp",
|
||||
"torch/csrc/cuda/shared/cudart.cpp",
|
||||
"torch/csrc/cuda/shared/nvtx.cpp",
|
||||
"torch/csrc/cuda/utils.cpp",
|
||||
|
||||
@ -51,6 +51,17 @@
|
||||
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
||||
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
|
||||
_(cuCtxFromGreenCtx, 12080) \
|
||||
_(cuCtxGetCurrent, 12080) \
|
||||
_(cuCtxPopCurrent, 12080) \
|
||||
_(cuCtxPushCurrent, 12080) \
|
||||
_(cuCtxSetCurrent, 12080) \
|
||||
_(cuGreenCtxCreate, 12080) \
|
||||
_(cuGreenCtxDestroy, 12080) \
|
||||
_(cuDevSmResourceSplitByCount, 12080) \
|
||||
_(cuDeviceGet, 12080) \
|
||||
_(cuDeviceGetDevResource, 12080) \
|
||||
_(cuDevResourceGenerateDesc, 12080) \
|
||||
_(cuMulticastAddDevice, 12030) \
|
||||
_(cuMulticastBindMem, 12030) \
|
||||
_(cuMulticastCreate, 12030) \
|
||||
|
||||
@ -328,6 +328,21 @@ struct pair {
|
||||
T2 second;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static T conj(T a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <>
|
||||
half2 conj(half2 a) {
|
||||
return half2(a.x, -a.y);
|
||||
}
|
||||
|
||||
template <>
|
||||
float2 conj(float2 a) {
|
||||
return float2(a.x, -a.y);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FOR_ALL_TYPES(MACRO) \
|
||||
MACRO(float); \
|
||||
MACRO(half); \
|
||||
|
||||
@ -607,6 +607,12 @@ if(USE_CUDA)
|
||||
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
|
||||
endif()
|
||||
endif()
|
||||
if(NOT WIN32)
|
||||
set_source_files_properties(
|
||||
${TORCH_ROOT}/aten/src/ATen/cuda/CUDAGreenContext.cpp
|
||||
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
|
||||
)
|
||||
endif()
|
||||
set_source_files_properties(
|
||||
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
||||
PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"
|
||||
|
||||
@ -1638,38 +1638,7 @@ if(USE_KINETO)
|
||||
message(STATUS " KINETO_LIBRARY_TYPE = ${KINETO_LIBRARY_TYPE}")
|
||||
|
||||
if(NOT LIBKINETO_NOCUPTI)
|
||||
set(CUDA_SOURCE_DIR "${CUDA_TOOLKIT_ROOT_DIR}" CACHE STRING "")
|
||||
message(STATUS " CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}")
|
||||
message(STATUS " CUDA_INCLUDE_DIRS = ${CUDA_INCLUDE_DIRS}")
|
||||
|
||||
if(NOT MSVC)
|
||||
if(USE_CUPTI_SO)
|
||||
set(CUPTI_LIB_NAME "libcupti.so")
|
||||
else()
|
||||
set(CUPTI_LIB_NAME "libcupti_static.a")
|
||||
endif()
|
||||
else()
|
||||
set(CUPTI_LIB_NAME "cupti.lib")
|
||||
endif()
|
||||
|
||||
find_library(CUPTI_LIBRARY_PATH ${CUPTI_LIB_NAME} PATHS
|
||||
${CUDA_SOURCE_DIR}
|
||||
${CUDA_SOURCE_DIR}/extras/CUPTI/lib64
|
||||
${CUDA_SOURCE_DIR}/lib
|
||||
${CUDA_SOURCE_DIR}/lib64
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_path(CUPTI_INCLUDE_DIR cupti.h PATHS
|
||||
${CUDA_SOURCE_DIR}/extras/CUPTI/include
|
||||
${CUDA_INCLUDE_DIRS}
|
||||
${CUDA_SOURCE_DIR}
|
||||
${CUDA_SOURCE_DIR}/include
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
if(CUPTI_LIBRARY_PATH AND CUPTI_INCLUDE_DIR)
|
||||
message(STATUS " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}")
|
||||
set(CUDA_cupti_LIBRARY ${CUPTI_LIBRARY_PATH})
|
||||
message(STATUS " CUDA_cupti_LIBRARY = ${CUDA_cupti_LIBRARY}")
|
||||
if(TARGET CUDA::cupti)
|
||||
message(STATUS "Found CUPTI")
|
||||
set(LIBKINETO_NOCUPTI OFF CACHE STRING "" FORCE)
|
||||
|
||||
@ -1682,7 +1651,7 @@ if(USE_KINETO)
|
||||
if(NOT APPLE)
|
||||
set(CMAKE_REQUIRED_LIBRARIES ${CMAKE_REQUIRED_LIBRARIES} "dl" "pthread")
|
||||
endif()
|
||||
set(CMAKE_REQUIRED_LINK_OPTIONS "-Wl,--whole-archive,${CUPTI_LIBRARY_PATH},--no-whole-archive")
|
||||
set(CMAKE_REQUIRED_LIBRARIES ${CMAKE_REQUIRED_LIBRARIES} $<LINK_LIBRARY:WHOLE_ARCHIVE,CUDA::cupti_static>)
|
||||
check_cxx_source_runs("#include <stdexcept>
|
||||
int main() {
|
||||
try {
|
||||
|
||||
@ -272,7 +272,7 @@ Here, we'll briefly introduce the implementation process of custom operators, fo
|
||||
* Name: `input`
|
||||
* Output Type: `Tensor`
|
||||
|
||||
2. **Register Operator&Autograd Fallback:**
|
||||
2. **Register Operator**
|
||||
|
||||
::::{tab-set}
|
||||
|
||||
@ -285,19 +285,11 @@ Here, we'll briefly introduce the implementation process of custom operators, fo
|
||||
:end-before: LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT
|
||||
:linenos:
|
||||
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK
|
||||
:end-before: LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK
|
||||
:emphasize-lines: 2
|
||||
:linenos:
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
::::
|
||||
|
||||
Use `TORCH_LIBRARY_IMPL` to register the `wrapper_custom_abs` implementation for the `custom_abs` operator in `PrivateUse1`. However, because `Autograd` is always enabled in PyTorch, PyTorch defaults to finding and executing the corresponding backward implementation even if only forward computation is required(will fallthrough in backward implementation). Therefore, we also need to register the corresponding implementation for `AutogradPrivateUse1` of the `custom_abs` operator. Fortunately, PyTorch also provides a general `Autograd Fallback` mechanism named `torch::autograd::autogradNotImplementedFallback`, if only forward computation is involved, it is equivalent to a fallthrough operation, selecting the next DispatchKey for computation; if backward computation is involved, an error is thrown.
|
||||
Use `TORCH_LIBRARY_IMPL` to register the `wrapper_custom_abs` implementation for the `custom_abs` operator in `PrivateUse1`. Because `Autograd` is always enabled in PyTorch, PyTorch defaults to finding and executing the corresponding backward implementation even if only forward computation is required (it will fall through in the backward implementation). Fortunately, PyTorch also provides a general `Autograd Fallback` for `PrivateUse1`: if only forward computation is involved, it is equivalent to a fallthrough operation, selecting the next DispatchKey for computation; if backward computation is involved, an error is thrown.
|
||||
|
||||
3. **Register Metadata(optional, but required by the graph mode, etc.):**
|
||||
|
||||
|
||||
@ -333,6 +333,11 @@ AArch64 CPU
|
||||
|
||||
- Sunita Nadampalli (`snadampal <https://github.com/snadampal>`__)
|
||||
|
||||
Out-of-tree Backend Integration (PrivateUse1)
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- Jiawei Li (`fffrog <https://github.com/fffrog>`__)
|
||||
|
||||
Docs / Tutorials
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@ -258,6 +258,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
|
||||
|
||||
```
|
||||
|
||||
## Green Contexts (experimental)
|
||||
|
||||
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
|
||||
to enable more general carveout of SM resources for CUDA kernels.
|
||||
|
||||
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
|
||||
|
||||
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.cuda.green_contexts
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
GreenContext
|
||||
```
|
||||
|
||||
|
||||
% This module needs to be documented. Adding here in the meantime
|
||||
|
||||
% for tracking purposes
|
||||
@ -270,6 +292,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
|
||||
.. py:module:: torch.cuda.gds
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. py:module:: torch.cuda.green_contexts
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. py:module:: torch.cuda.jiterator
|
||||
```
|
||||
|
||||
@ -41,6 +41,7 @@ torch.distributed.fsdp.fully_shard <distributed.fsdp.fully_shard>
|
||||
torch.distributed.tensor.parallel <distributed.tensor.parallel>
|
||||
torch.distributed.optim <distributed.optim>
|
||||
torch.distributed.pipelining <distributed.pipelining>
|
||||
torch.distributed._symmetric_memory <symmetric_memory>
|
||||
torch.distributed.checkpoint <distributed.checkpoint>
|
||||
torch.distributions <distributions>
|
||||
torch.compiler <torch.compiler>
|
||||
|
||||
380 docs/source/symmetric_memory.md (new file)
@ -0,0 +1,380 @@
|
||||
```{eval-rst}
|
||||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
```
|
||||
|
||||
# PyTorch Symmetric Memory
|
||||
|
||||
:::{note}
|
||||
`torch.distributed._symmetric_memory` is currently in alpha state and under
|
||||
development. API changes may be possible.
|
||||
:::
|
||||
|
||||
## Why Symmetric Memory?
|
||||
|
||||
With rapidly evolving parallelization techniques, existing frameworks and
|
||||
libraries often struggle to keep up, and developers increasingly rely on custom
|
||||
implementations directly scheduling communications and computations. In recent
|
||||
years we’ve witnessed a shift from primarily relying on one-dimensional
|
||||
data-parallelism techniques to multi-dimensional parallelism techniques. The latter
|
||||
have different latency requirements for different types of communications and
|
||||
thus require fine-grained overlapping of compute and communications.
|
||||
|
||||
To minimize compute interference, they also require the use of copy engines and
|
||||
network interface cards (NICs) to drive communication. Network transport
|
||||
protocols such as remote direct memory access (RDMA) enhance the performance by
|
||||
enabling direct, high-speed, and low-latency communication between processors
|
||||
and memory. This increase in variety indicates the need for finer-grained
|
||||
communication primitives than are offered today by high-level collective APIs,
|
||||
ones that would enable developers to implement specific algorithms tailored for
|
||||
their use cases, such as low-latency collectives, fine-grained
|
||||
compute-communications overlap, or custom fusions.
|
||||
|
||||
Furthermore, today’s advanced AI systems connect GPUs with high-bandwidth links
|
||||
(such as NVLinks, InfiniBand or RoCE), making GPU global memory directly
|
||||
accessible to peers. Such connections present a great opportunity for
|
||||
programmers to program the system as a single, gigantic GPU with vast accessible
|
||||
memory, instead of programming singular “GPU islands.”
|
||||
|
||||
In this document, we will show how you can use PyTorch Symmetric Memory to
|
||||
program modern GPU systems as a “single GPU” and achieve fine-grained remote
|
||||
access.
|
||||
|
||||
## What does PyTorch Symmetric Memory unlock?
|
||||
|
||||
PyTorch Symmetric Memory unlocks three new capabilities:
|
||||
|
||||
- **Customized communication patterns**: Increased flexibility in kernel writing
|
||||
allows developers to write custom kernels that implement their custom
|
||||
computations and communications, directly tailored to the needs of the
|
||||
application. It will also be straightforward to add support for new data types
|
||||
along with the special compute that those data types might require, even if it’s
|
||||
not present yet in the standard libraries.
|
||||
|
||||
- **In-kernel compute-comm fusion**: Device-initiated communication capability
|
||||
allows developers to write kernels with both computation and communication
|
||||
instructions, allowing for the fusion of computation and data movement in the
|
||||
smallest possible granularity.
|
||||
|
||||
- **Low-latency remote access**: Network transport protocols like RDMA enhance the
|
||||
performance of symmetric memory in networked environments by enabling direct,
|
||||
high-speed, and low-latency communication between processors and memory. RDMA
|
||||
eliminates the overhead associated with the traditional network stack and CPU
|
||||
involvement. It also offloads data transfer from the compute to the NICs,
|
||||
freeing up compute resources for computational tasks.
|
||||
|
||||
Next, we will show you how PyTorch Symmetric Memory (SymmMem) enables new
|
||||
applications with the above capabilities.
|
||||
|
||||
## A “Hello World” example
|
||||
|
||||
The PyTorch SymmMem programming model involves two key elements:
|
||||
|
||||
- creating symmetric tensors
|
||||
- creating SymmMem kernels
|
||||
|
||||
To create symmetric tensors, one can use the
|
||||
`torch.distributed._symmetric_memory` package:
|
||||
|
||||
```python
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
|
||||
t = symm_mem.empty(128, device=torch.device("cuda", rank))
|
||||
hdl = symm_mem.rendezvous(t, group)
|
||||
```
|
||||
|
||||
The `symm_mem.empty` function creates a tensor that is backed by a symmetric
|
||||
memory allocation. The `rendezvous` function establishes a rendezvous with peers
|
||||
in the group, and returns a handle to the symmetric memory allocation. The
|
||||
handle provides methods to access information related to the symmetric memory
allocation, such as pointers to the symmetric buffers on peer ranks, a multicast
pointer (if supported), and signal pads.
|
||||
|
||||
The `empty` and `rendezvous` functions must be called in the same order on all
|
||||
ranks in the group.
|
||||
|
||||
Then, collectives can be called on these tensors. For example, to perform a
|
||||
one-shot all-reduce:
|
||||
|
||||
```python
|
||||
# Most SymmMem ops are under the torch.ops.symm_mem namespace
|
||||
torch.ops.symm_mem.one_shot_all_reduce(t, "sum", group)
|
||||
```
|
||||
|
||||
Please note that `torch.ops.symm_mem` is an "op namespace" rather than a Python
module. Therefore, you cannot import it with `import torch.ops.symm_mem`, nor
can you import an op with `from torch.ops.symm_mem import one_shot_all_reduce`.
|
||||
You can call the op directly as in the example above.
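
Since ops are attributes on the namespace, they can also be looked up by name at
call time. A minimal sketch mirroring the call above (reusing `t` and `group` from
the earlier example):

```python
# Ops live as attributes on the torch.ops.symm_mem namespace, so they can be
# resolved dynamically; this mirrors the one_shot_all_reduce call above.
op = getattr(torch.ops.symm_mem, "one_shot_all_reduce")
op(t, "sum", group)  # `t` and `group` as created in the earlier example
```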
|
||||
|
||||
## Write your own kernel
|
||||
|
||||
To write your own kernel doing communications with symmetric memory, you’ll need
|
||||
access to the addresses of mapped peer buffers and access to signal pads that
|
||||
are required for synchronization. In the kernel you’ll also need to perform
|
||||
correct synchronizations to make sure that peers are ready for communication,
|
||||
and signal to them that this GPU is ready.
|
||||
|
||||
PyTorch Symmetric Memory provides CUDA Graph-compatible synchronization
|
||||
primitives that operate on the signal pad accompanying each symmetric memory
|
||||
allocation. Kernels using symmetric memory can be written both in CUDA and in
|
||||
Triton. Here’s an example allocating symmetric tensor and exchanging handles:
|
||||
|
||||
```python
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
|
||||
dist.init_process_group()
|
||||
rank = dist.get_rank()
|
||||
|
||||
# Allocate a tensor
|
||||
t = symm_mem.empty(4096, device=f"cuda:{rank}")
|
||||
# Establish symmetric memory and obtain the handle
|
||||
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
|
||||
```
|
||||
|
||||
Access to the buffer pointers, the multicast pointer, and the signal pads is provided via:
|
||||
|
||||
```python
|
||||
hdl.buffer_ptrs
|
||||
hdl.multicast_ptr
|
||||
hdl.signal_pad_ptrs
|
||||
```
|
||||
|
||||
Data pointed to by `buffer_ptrs` can be accessed just like regular local data,
|
||||
and any necessary compute can also be performed in the usual ways. As with local
|
||||
data, you can and should use vectorized accesses to improve efficiency.
|
||||
|
||||
Symmetric memory is especially convenient for writing kernels in Triton. Just as
Triton lowered the barrier to writing efficient CUDA code, communication can now
be added easily to Triton kernels. The kernel below demonstrates a low-latency
all-reduce written in Triton.
|
||||
|
||||
```python
|
||||
@triton.jit
|
||||
def one_shot_all_reduce_kernel(
|
||||
buf_tuple,
|
||||
signal_pad_ptrs,
|
||||
output_ptr,
|
||||
numel: tl.constexpr,
|
||||
rank: tl.constexpr,
|
||||
world_size: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
ptx_utils.symm_mem_sync(
|
||||
signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
|
||||
)
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
|
||||
while block_start < numel:
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < numel
|
||||
acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)
|
||||
|
||||
for i in tl.static_range(world_size):
|
||||
buffer_rank = buf_tuple[i]
|
||||
x = tl.load(buffer_rank + offsets, mask=mask)
|
||||
acc += x
|
||||
|
||||
tl.store(output_ptr + offsets, acc, mask=mask)
|
||||
block_start += tl.num_programs(axis=0) * BLOCK_SIZE
|
||||
|
||||
ptx_utils.symm_mem_sync(
|
||||
signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
|
||||
)
|
||||
```
|
||||
|
||||
Synchronizations at the beginning and the end of the kernel above guarantee that
|
||||
all the processes see consistent data. The bulk of the kernel is recognizable
|
||||
Triton code, and Triton will optimize it behind the scenes, making sure memory
|
||||
accesses are performed in an efficient way with vectorization and unrolling. As
|
||||
with all Triton kernels, it is easily modifiable to add extra computations or
|
||||
change the communication algorithm. Visit
|
||||
https://github.com/meta-pytorch/kraken/blob/main/kraken to see additional
|
||||
utilities and examples of using symmetric memory to implement common patterns in
|
||||
Triton.
|
||||
|
||||
## Scale out
|
||||
|
||||
Large language models distribute experts onto more than 8 GPUs, hence requiring
|
||||
multi-node access capability. RDMA-capable NICs address this need. In addition,
|
||||
software libraries such as NVSHMEM or rocSHMEM abstract away the programming
|
||||
difference between intra-node access and inter-node access with primitives that
|
||||
are slightly higher level than pointer access, such as put and get.
|
||||
|
||||
PyTorch provides NVSHMEM plugins to augment Triton kernels’ cross-node
|
||||
capabilities. As shown in the code snippet below, one can initiate a cross-node
|
||||
put command within the kernel.
|
||||
|
||||
```python
|
||||
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
|
||||
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem
|
||||
|
||||
@requires_nvshmem
|
||||
@triton.jit
|
||||
def my_put_kernel(
|
||||
dest,
|
||||
src,
|
||||
nelems,
|
||||
pe,
|
||||
):
|
||||
nvshmem.put(dest, src, nelems, pe)
|
||||
```
|
||||
|
||||
The `requires_nvshmem` decorator is used to indicate that the kernel requires
|
||||
the NVSHMEM device library as an external dependency. When Triton compiles the
|
||||
kernel, the decorator will search your system paths for the NVSHMEM device
|
||||
library. If it is available, Triton will include the necessary device assembly
|
||||
to use the NVSHMEM functions.
|
||||
|
||||
## API Reference
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.distributed._symmetric_memory
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: empty
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: rendezvous
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: is_nvshmem_available
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: set_backend
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: get_backend
|
||||
```
|
||||
|
||||
## Op Reference
|
||||
:::{note}
|
||||
The following ops are hosted in the `torch.ops.symm_mem` namespace. You can call
|
||||
them directly via `torch.ops.symm_mem.<op_name>`.
|
||||
:::
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.ops.symm_mem
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. py:function:: multimem_all_reduce_(input: Tensor, reduce_op: str, group_name: str) -> Tensor
|
||||
|
||||
Performs a multimem all-reduce operation on the input tensor. This operation
|
||||
requires hardware support for multimem operations. On NVIDIA GPUs, NVLink
|
||||
SHARP is required.
|
||||
|
||||
:param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
|
||||
:param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
|
||||
:param str group_name: Name of the group to perform all-reduce on.
|
||||
|
||||
|
||||
.. py:function:: multimem_all_gather_out(input: Tensor, group_name: str, out: Tensor) -> Tensor
|
||||
|
||||
Performs a multimem all-gather operation on the input tensor. This operation requires hardware support for multimem operations. On NVIDIA GPUs, NVLink SHARP is required.
|
||||
|
||||
:param Tensor input: Input tensor to perform all-gather on.
|
||||
:param str group_name: Name of the group to perform all-gather on.
|
||||
:param Tensor out: Output tensor to store the result of the all-gather operation. Must be symmetric.
|
||||
|
||||
|
||||
.. py:function:: one_shot_all_reduce(input: Tensor, reduce_op: str, group_name: str) -> Tensor
|
||||
|
||||
Performs a one-shot all-reduce operation on the input tensor.
|
||||
|
||||
:param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
|
||||
:param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
|
||||
:param str group_name: Name of the group to perform all-reduce on.
|
||||
|
||||
|
||||
.. py:function:: one_shot_all_reduce_out(input: Tensor, reduce_op: str, group_name: str, out: Tensor) -> Tensor
|
||||
|
||||
Performs a one-shot all-reduce operation based on the input tensor and writes the result to the output tensor.
|
||||
|
||||
:param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
|
||||
:param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
|
||||
:param str group_name: Name of the group to perform all-reduce on.
|
||||
:param Tensor out: Output tensor to store the result of the all-reduce operation. Can be a regular tensor.
|
||||
|
||||
|
||||
.. py:function:: two_shot_all_reduce_(input: Tensor, reduce_op: str, group_name: str) -> Tensor
|
||||
|
||||
Performs a two-shot all-reduce operation on the input tensor.
|
||||
|
||||
:param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
|
||||
:param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
|
||||
:param str group_name: Name of the group to perform all-reduce on.
|
||||
|
||||
|
||||
.. py:function:: all_to_all_vdev(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str) -> None
|
||||
|
||||
Performs an all-to-all-v operation using NVSHMEM, with split information provided on device.
|
||||
|
||||
:param Tensor input: Input tensor to perform all-to-all on. Must be symmetric.
|
||||
:param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric.
|
||||
:param Tensor in_splits: Tensor containing splits of data to send to each peer. Must be symmetric. Must be of size (group_size,). The splits are in the unit of elements in the 1st dimension.
|
||||
:param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size). The rows are (in order): output splits and output offsets.
|
||||
:param str group_name: Name of the group to perform all-to-all on.
|
||||
|
||||
|
||||
.. py:function:: all_to_all_vdev_2d(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str, [major_align: int = None]) -> None
|
||||
|
||||
Perform a 2D all-to-all-v operation using NVSHMEM, with split information provided on device. In Mixture of Experts models, this operation can be used to dispatch tokens.
|
||||
|
||||
:param Tensor input: Input tensor to perform all-to-all on. Must be symmetric.
|
||||
:param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric.
|
||||
:param Tensor in_splits: Tensor containing the splits of data to send to each expert. Must be symmetric. Must be of size (group_size * ne,), where ne is the number of experts per rank. The splits are in the unit of elements in the 1st dimension.
|
||||
:param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.
|
||||
:param str group_name: Name of the group to perform all-to-all on.
|
||||
:param int major_align: Optional alignment for the major dimension of the output chunk for each expert. If not provided, the alignment is assumed to be 1. Any alignment adjustment will be reflected in the output offsets.
|
||||
|
||||
A 2D AllToAllv shuffle is illustrated below:
|
||||
(world_size = 2, ne = 2, total number of experts = 4)::
|
||||
|
||||
Source: | Rank 0 | Rank 1 |
|
||||
| c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |
|
||||
|
||||
Dest : | Rank 0 | Rank 1 |
|
||||
| c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
|
||||
|
||||
where each `c_i` / `d_i` are slices of the `input` tensor, targeting expert
|
||||
`i`, with length indicated by input splits. That is, the 2D AllToAllv
|
||||
shuffle achieves a transpose from rank-major order at input to expert-major
|
||||
order at output.
|
||||
|
||||
If `major_align` is not 1, the output offsets of c1, c2, c3 will be
|
||||
up-aligned to this value. For example, if c0 has length 5 and d0 has
|
||||
length 7 (making a total of 12), and if the `major_align` is set to 16,
|
||||
the output offset of c1 will be 16. Similar for c2 and c3. This value has
|
||||
no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3.
|
||||
Note: since cutlass does not support empty bins, we set the aligned length
|
||||
to `major_align` if it is 0. See
|
||||
https://github.com/pytorch/pytorch/issues/152668.
|
||||
|
||||
|
||||
.. py:function:: all_to_all_vdev_2d_offset(Tensor input, Tensor out, Tensor in_splits_offsets, Tensor out_splits_offsets, str group_name) -> None
|
||||
|
||||
Perform a 2D AllToAllv shuffle operation, with input split and offset
|
||||
information provided on device. The input offsets are not required to be
|
||||
exact prefix sum of the input splits, i.e. paddings are allowed between the
|
||||
split chunks. The paddings, however, will not be transferred to peer
|
||||
ranks.
|
||||
|
||||
In Mixture of Experts models, this operation can be used to combine tokens
|
||||
processed by experts on parallel ranks. This operation can be viewed as an
|
||||
"reverse" operation to the `all_to_all_vdev_2d` operation (which shuffles
|
||||
tokens to experts).
|
||||
|
||||
:param Tensor input: Input tensor to perform all-to-all on. Must be symmetric.
|
||||
:param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric.
|
||||
:param Tensor in_splits_offsets: Tensor containing the splits and offsets of data to send to each expert. Must be symmetric. Must be of size (2, group_size * ne), where `ne` is the number of experts. The rows are (in order): input splits and input offsets. The splits are in the unit of elements in the 1st dimension.
|
||||
:param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.
|
||||
:param str group_name: Name of the group to perform all-to-all on.
|
||||
|
||||
```
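
As an illustrative sketch (not taken from the reference above), the snippet below
exercises `one_shot_all_reduce_out` following the documented signature. The
group-name lookup via `dist.group.WORLD.group_name` and the one-process-per-GPU
device mapping are assumptions; adapt them to your setup.

```python
# Hedged usage sketch for one_shot_all_reduce_out: symmetric input, regular output.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group()
rank = dist.get_rank()
torch.cuda.set_device(rank)  # assumes one process per local GPU

# The input must be symmetric and rendezvous'd; the output can be a regular tensor.
t = symm_mem.empty(4096, device=f"cuda:{rank}")
symm_mem.rendezvous(t, dist.group.WORLD)
t.fill_(float(rank))

out = torch.empty(4096, device=f"cuda:{rank}")
torch.ops.symm_mem.one_shot_all_reduce_out(
    t, "sum", dist.group.WORLD.group_name, out  # group_name lookup is an assumption
)
```

Run with one process per GPU (for example via `torchrun`), matching the earlier
examples.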
|
||||
@ -130,6 +130,7 @@ errors.bad-param-name-override = false
|
||||
# Mypy doesn't require that imports are explicitly imported, so be compatible with that.
|
||||
# Might be a good idea to turn this on in future.
|
||||
errors.implicit-import = false
|
||||
errors.deprecated = false # re-enable after we've fix import formatting
|
||||
permissive-ignores = true
|
||||
replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
|
||||
search-path = ["tools/experimental"]
|
||||
|
||||
@ -156,12 +156,6 @@ TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) {
|
||||
}
|
||||
// LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT
|
||||
|
||||
// LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK
|
||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
|
||||
m.fallback(torch::autograd::autogradNotImplementedFallback());
|
||||
}
|
||||
// LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK
|
||||
|
||||
// The rest is for testing purposes
|
||||
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
|
||||
/*
|
||||
|
||||
@ -257,7 +257,10 @@ class RecordWorkerEventsTest(unittest.TestCase):
|
||||
self.assertEqual(failed_event.metadata["state"], "FAILED")
|
||||
self.assertEqual(failed_event.metadata["global_rank"], 0)
|
||||
md = json.loads(failed_event.metadata["metadata"])
|
||||
self.assertEqual(failed_event.metadata["raw_error"], '{"message": "<NONE>"}')
|
||||
self.assertEqual(
|
||||
failed_event.metadata["raw_error"],
|
||||
'{"message": "<NONE>", "errorTraits": {"category": "system_terminated_error", "retryability": "False"}}',
|
||||
)
|
||||
self.assertEqual(md["exit_code"], [1])
|
||||
self.assertEqual(md["worker_pid"], [1000])
|
||||
|
||||
|
||||
@ -127,8 +127,9 @@ def echo1(msg: str, exitcode: int = 0) -> str:
|
||||
print(f"exit {exitcode} from {rank}", file=sys.stderr)
|
||||
sys.exit(exitcode)
|
||||
else:
|
||||
print(f"{msg} stdout from {rank}")
|
||||
print(f"{msg} stderr from {rank}", file=sys.stderr)
|
||||
for m in msg.split(","):
|
||||
print(f"{m} stdout from {rank}")
|
||||
print(f"{m} stderr from {rank}", file=sys.stderr)
|
||||
return f"{msg}_{rank}"
|
||||
|
||||
|
||||
@ -247,6 +248,13 @@ class _StartProcessesTest(TestCase):
|
||||
for line in expected:
|
||||
self.assertIn(line, actual)
|
||||
|
||||
def assert_not_in_file(self, lines: list[str], filename: str) -> None:
|
||||
lines = [f"{line.rstrip()}\n" for line in lines]
|
||||
with open(filename) as fp:
|
||||
actual = fp.readlines()
|
||||
for line in lines:
|
||||
self.assertNotIn(line, actual)
|
||||
|
||||
def assert_pids_noexist(self, pids: dict[int, int]):
|
||||
for local_rank, pid in pids.items():
|
||||
with self.assertRaises(
|
||||
@ -360,8 +368,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
|
||||
self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
|
||||
self.assertIsNotNone(pc.wait(period=0.1))
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_pcontext_wait_on_a_child_thread(self):
|
||||
asyncio.run(asyncio.to_thread(self.test_pcontext_wait))
|
||||
@ -379,8 +387,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
pids = pc.pids()
|
||||
pc.close()
|
||||
self.assert_pids_noexist(pids)
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_function_with_tensor(self):
|
||||
for start_method in self._start_methods:
|
||||
@ -482,8 +490,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
int(error_file_data["message"]["extraInfo"]["timestamp"]),
|
||||
int(failure.timestamp),
|
||||
)
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_wait_for_all_child_procs_to_exit(self):
|
||||
"""
|
||||
@ -580,8 +588,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
self.assert_in_file([], results.stdouts[0])
|
||||
self.assertFalse(results.stderrs[1])
|
||||
self.assertFalse(results.stdouts[1])
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
failure = results.failures[1]
|
||||
self.assertEqual(-15, failure.exitcode)
|
||||
@ -731,8 +739,37 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
|
||||
self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
|
||||
self.assertFalse(pc.stdouts[1])
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_binary_duplicate_log_filters(self):
|
||||
pc = start_processes(
|
||||
name="trainer",
|
||||
entrypoint=bin("echo1.py"),
|
||||
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
|
||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||
logs_specs=DefaultLogsSpecs(
|
||||
log_dir=self.log_dir(),
|
||||
redirects={0: Std.ERR, 1: Std.NONE},
|
||||
tee={0: Std.OUT, 1: Std.ERR},
|
||||
),
|
||||
log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"},
|
||||
duplicate_stdout_filters=["helloA"],
|
||||
duplicate_stderr_filters=["worldA", "B"],
|
||||
start_method="spawn",
|
||||
)
|
||||
|
||||
result = pc.wait()
|
||||
|
||||
self.assertFalse(result.is_failed())
|
||||
self.assert_in_file(["[rank0]:helloA stdout from 0"], pc.filtered_stdout)
|
||||
self.assert_not_in_file(
|
||||
["[rank0]:helloB stdout from 0"], pc.filtered_stdout
|
||||
)
|
||||
self.assert_in_file(["[rank1]:worldA stderr from 1"], pc.filtered_stderr)
|
||||
self.assert_in_file(["[rank1]:worldB stderr from 1"], pc.filtered_stderr)
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
|
||||
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
|
||||
@ -794,8 +831,44 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
||||
self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
|
||||
self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
|
||||
self.assertFalse(pc.stdouts[1])
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_function_duplicate_log_filters(self):
|
||||
for start_method in self._start_methods:
|
||||
with self.subTest(start_method=start_method):
|
||||
pc = start_processes(
|
||||
name="trainer",
|
||||
entrypoint=echo1,
|
||||
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
|
||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||
logs_specs=DefaultLogsSpecs(
|
||||
log_dir=self.log_dir(),
|
||||
redirects={0: Std.ERR, 1: Std.NONE},
|
||||
tee={0: Std.OUT, 1: Std.ERR},
|
||||
),
|
||||
duplicate_stdout_filters=["helloA"],
|
||||
duplicate_stderr_filters=["worldA", "B"],
|
||||
start_method="spawn",
|
||||
)
|
||||
|
||||
result = pc.wait()
|
||||
|
||||
self.assertFalse(result.is_failed())
|
||||
self.assert_in_file(
|
||||
["[trainer0]:helloA stdout from 0"], pc.filtered_stdout
|
||||
)
|
||||
self.assert_not_in_file(
|
||||
["[trainer0]:helloB stdout from 0"], pc.filtered_stdout
|
||||
)
|
||||
self.assert_in_file(
|
||||
["[trainer1]:worldA stderr from 1"], pc.filtered_stderr
|
||||
)
|
||||
self.assert_in_file(
|
||||
["[trainer1]:worldB stderr from 1"], pc.filtered_stderr
|
||||
)
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_function(self):
|
||||
for start_method, redirs in product(self._start_methods, redirects_all()):
|
||||
@ -880,8 +953,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
||||
self.assertFalse(results.stdouts[0])
|
||||
self.assertFalse(results.stderrs[1])
|
||||
self.assertFalse(results.stdouts[1])
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_no_zombie_process_function(self):
|
||||
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
|
||||
|
||||
@ -23,5 +23,6 @@ if __name__ == "__main__":
|
||||
print(f"exit {exitcode} from {rank}", file=sys.stderr)
|
||||
sys.exit(exitcode)
|
||||
else:
|
||||
print(f"{args.msg} stdout from {rank}")
|
||||
print(f"{args.msg} stderr from {rank}", file=sys.stderr)
|
||||
for msg in args.msg.split(","):
|
||||
print(f"{msg} stdout from {rank}")
|
||||
print(f"{msg} stderr from {rank}", file=sys.stderr)
|
||||
|
||||
@ -84,6 +84,53 @@ class TailLogTest(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(tail.stopped())
|
||||
|
||||
def test_tail_write_to_dst_file(self):
|
||||
"""
|
||||
writer() writes 0 - max (on number on each line) to a log file.
|
||||
Run nprocs such writers and tail the log files into a temp file
|
||||
and validate that all lines are accounted for.
|
||||
"""
|
||||
nprocs = 32
|
||||
max = 1000
|
||||
interval_sec = 0.0001
|
||||
|
||||
log_files = {
|
||||
local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
|
||||
for local_rank in range(nprocs)
|
||||
}
|
||||
|
||||
dst = os.path.join(self.test_dir, "tailed_stdout.log")
|
||||
tail = TailLog(
|
||||
name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec
|
||||
).start()
|
||||
# sleep here is intentional to ensure that the log tail
|
||||
# can gracefully handle and wait for non-existent log files
|
||||
time.sleep(interval_sec * 10)
|
||||
|
||||
futs = []
|
||||
for local_rank, file in log_files.items():
|
||||
f = self.threadpool.submit(
|
||||
write, max=max, sleep=interval_sec * local_rank, file=file
|
||||
)
|
||||
futs.append(f)
|
||||
|
||||
wait(futs, return_when=ALL_COMPLETED)
|
||||
self.assertFalse(tail.stopped())
|
||||
tail.stop()
|
||||
|
||||
actual: dict[int, set[int]] = {}
|
||||
with open(dst) as dst_file:
|
||||
for line in dst_file:
|
||||
header, num = line.split(":")
|
||||
nums = actual.setdefault(header, set())
|
||||
nums.add(int(num))
|
||||
|
||||
self.assertEqual(nprocs, len(actual))
|
||||
self.assertEqual(
|
||||
{f"[writer{i}]": set(range(max)) for i in range(nprocs)}, actual
|
||||
)
|
||||
self.assertTrue(tail.stopped())
|
||||
|
||||
def test_tail_with_custom_prefix(self):
|
||||
"""
|
||||
writer() writes 0 - max (on number on each line) to a log file.
|
||||
@ -131,6 +178,52 @@ class TailLogTest(unittest.TestCase):
|
||||
self.assertIn(f"[worker{i}][{i}]", headers)
|
||||
self.assertTrue(tail.stopped())
|
||||
|
||||
def test_tail_with_custom_filter(self):
|
||||
"""
|
||||
writer() writes 0 - max (on number on each line) to a log file.
|
||||
Run nprocs such writers and tail the log files into an IOString
|
||||
and validate that all lines are accounted for.
|
||||
"""
|
||||
nprocs = 3
|
||||
max = 20
|
||||
interval_sec = 0.0001
|
||||
|
||||
log_files = {
|
||||
local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
|
||||
for local_rank in range(nprocs)
|
||||
}
|
||||
|
||||
dst = io.StringIO()
|
||||
tail = TailLog(
|
||||
"writer",
|
||||
log_files,
|
||||
dst,
|
||||
interval_sec=interval_sec,
|
||||
log_line_filter=lambda line: "2" in line, # only print lines containing '2'
|
||||
).start()
|
||||
# sleep here is intentional to ensure that the log tail
|
||||
# can gracefully handle and wait for non-existent log files
|
||||
time.sleep(interval_sec * 10)
|
||||
futs = []
|
||||
for local_rank, file in log_files.items():
|
||||
f = self.threadpool.submit(
|
||||
write, max=max, sleep=interval_sec * local_rank, file=file
|
||||
)
|
||||
futs.append(f)
|
||||
wait(futs, return_when=ALL_COMPLETED)
|
||||
self.assertFalse(tail.stopped())
|
||||
tail.stop()
|
||||
dst.seek(0)
|
||||
|
||||
actual: dict[int, set[int]] = {}
|
||||
for line in dst.readlines():
|
||||
header, num = line.split(":")
|
||||
nums = actual.setdefault(header, set())
|
||||
nums.add(int(num))
|
||||
self.assertEqual(nprocs, len(actual))
|
||||
self.assertEqual({f"[writer{i}]": {2, 12} for i in range(nprocs)}, actual)
|
||||
self.assertTrue(tail.stopped())
|
||||
|
||||
def test_tail_no_files(self):
|
||||
"""
|
||||
Ensures that the log tail can gracefully handle no log files
|
||||
|
||||
@ -55,9 +55,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Remove environment variable if it exists to test default behavior
|
||||
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
||||
@ -84,8 +85,8 @@ class SignalHandlingTest(TestCase):
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
# Verify _stdout_tail.start() and _stderr_tail.start() were called
|
||||
mock_pcontext._stdout_tail.start.assert_called_once()
|
||||
mock_pcontext._stderr_tail.start.assert_called_once()
|
||||
mock_stdout_tail.start.assert_called_once()
|
||||
mock_stderr_tail.start.assert_called_once()
|
||||
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.threading")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.signal")
|
||||
@ -99,9 +100,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Set custom signals in the environment variable
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGUSR1,SIGUSR2"
|
||||
@ -139,9 +141,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Set invalid signals in the environment variable
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,INVALID_SIGNAL"
|
||||
@ -180,9 +183,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Set signals including ones not supported on Windows
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGHUP,SIGUSR1"
|
||||
@ -234,9 +238,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.current_thread.return_value = MagicMock() # Not the main thread
|
||||
mock_threading.main_thread.return_value = MagicMock()
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Call the start method
|
||||
PContext.start(mock_pcontext)
|
||||
@ -262,9 +267,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Set environment variable to include SIGUSR1 and SIGUSR2
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGUSR1,SIGUSR2"
|
||||
@ -323,8 +329,8 @@ class SignalHandlingTest(TestCase):
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
# Verify _stdout_tail.start() and _stderr_tail.start() were called
|
||||
mock_pcontext._stdout_tail.start.assert_called_once()
|
||||
mock_pcontext._stderr_tail.start.assert_called_once()
|
||||
mock_stdout_tail.start.assert_called_once()
|
||||
mock_stderr_tail.start.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
from torch.utils._debug_mode import DebugMode
|
||||
from torch.utils._debug_mode import _OpCall, _RedistributeCall, DebugMode
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
|
||||
@ -60,6 +60,10 @@ class TestDTensorDebugMode(TestCase):
|
||||
aten::sum(t: f32[1, 32])""",
|
||||
)
|
||||
|
||||
self.assertTrue(isinstance(debug_mode.operators[0], _OpCall))
|
||||
self.assertTrue(isinstance(debug_mode.operators[2], _RedistributeCall))
|
||||
self.assertEqual(next(iter(debug_mode.operators[1])), torch.ops.aten.mm.default)
|
||||
|
||||
def test_debug_string_inside_context(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
@ -330,6 +334,46 @@ class TestDTensorDebugMode(TestCase):
|
||||
f(x)
|
||||
self.assertEqual(len(debug_mode.debug_string()), 0)
|
||||
|
||||
def test_nn_module(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l1 = torch.nn.Linear(4, 4)
|
||||
self.l2 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.l2(self.l1(x))
|
||||
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.abc = Foo()
|
||||
self.xyz = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.xyz(self.abc(x))
|
||||
|
||||
mod = Bar()
|
||||
inp = torch.randn(4, 4)
|
||||
with DebugMode(record_nn_module=True) as debug_mode:
|
||||
_ = mod(inp)
|
||||
|
||||
self.assertExpectedInline(
|
||||
debug_mode.debug_string(),
|
||||
"""\
|
||||
[nn.Mod] Bar
|
||||
[nn.Mod] Bar.abc
|
||||
[nn.Mod] Bar.abc.l1
|
||||
aten::t(t: f32[4, 4])
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
|
||||
[nn.Mod] Bar.abc.l2
|
||||
aten::t(t: f32[4, 4])
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
|
||||
[nn.Mod] Bar.xyz
|
||||
aten::t(t: f32[4, 4])
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
|
||||
)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDTensorDebugMode)
|
||||
|
||||
|
||||
@ -6,7 +6,10 @@ import unittest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
||||
from torch._dynamo.functional_export import (
|
||||
_dynamo_graph_capture_for_export,
|
||||
dynamo_graph_capture_for_export,
|
||||
)
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._guards import tracing, TracingContext
@@ -96,6 +99,13 @@ def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
return aot_export_joint_with_descriptors_alone(ep.module(), inputs)


def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
gm = dynamo_graph_capture_for_export(model)(inputs)
fake_mode = gm.meta.get("fake_mode", None)
with tracing(TracingContext(fake_mode)):
return aot_export_joint_with_descriptors_alone(gm, inputs)


def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
@@ -288,6 +298,7 @@ class DTensorExportTest(TestCase):
@parametrize(
"export_fn",
[
graph_capture_and_aot_export_joint_with_descriptors_v2,
graph_capture_and_aot_export_joint_with_descriptors,
aot_export_joint_with_descriptors_alone,
],
@@ -307,7 +318,21 @@ class DTensorExportTest(TestCase):
def test_annotate_aot_export_joint_with_descriptors_alone(self):
self._run_test(aot_export_joint_with_descriptors_alone, True)

def test_dynamic_shapes(self):
@parametrize(
"export_fn_with_answer",
[
(
graph_capture_and_aot_export_joint_with_descriptors_v2,
"[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
),
(
graph_capture_and_aot_export_joint_with_descriptors,
"[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
),
],
)
def test_dynamic_shapes(self, export_fn_with_answer):
export_fn, answer = export_fn_with_answer
dp_degree = 2
tp_degree = self.world_size // dp_degree

@@ -331,7 +356,7 @@ class DTensorExportTest(TestCase):
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)

joint_gm = graph_capture_and_aot_export_joint_with_descriptors(tp_model, inputs)
joint_gm = export_fn(tp_model, inputs)

res = []
for node in joint_gm.graph.nodes:
@@ -341,12 +366,16 @@ class DTensorExportTest(TestCase):
if isinstance(fake_val, torch._subclasses.fake_tensor.FakeTensor):
res.append(list(fake_val.shape))

self.assertExpectedInline(
str(res),
"""[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
)
self.assertEqual(str(res), answer)

def test_einsum_dtensor_export(self):
@parametrize(
"export_fn",
[
dynamo_graph_capture_for_export,
_dynamo_graph_capture_for_export,
],
)
def test_einsum_dtensor_export(self, export_fn):
"""Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
world_size = 4
# Create device mesh
@@ -366,9 +395,7 @@ class DTensorExportTest(TestCase):
output = model(x_dtensor, y_dtensor, z_dtensor)
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
gm = _dynamo_graph_capture_for_export(model)(
x_dtensor, y_dtensor, z_dtensor
)
gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
self.assertEqual(output, output_gm)

@@ -13,6 +13,7 @@ from torch.distributed.tensor import (
distribute_module,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
@@ -649,6 +650,41 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(len(comm_counts), 1)
self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 1)

@with_comms
def test_vector_norm(self):
device_mesh = self.build_device_mesh()

grad = torch.randn(12, 8)

sharded_grad = distribute_tensor(grad, device_mesh, [Shard(0)])

# non-sharded op
out = torch.ops.aten.linalg_vector_norm(grad, 2)

# sharded op
sharded_out = torch.ops.aten.linalg_vector_norm(sharded_grad, 2)

self.assertEqual(sharded_out.full_tensor(), out)

@with_comms
def test_vector_norm_partial(self):
device_mesh = self.build_device_mesh()

rank = device_mesh.get_local_rank()
all_ranks = list(range(self.world_size))

local_grad = torch.tensor([rank, 1], dtype=torch.float32)
full_grad = torch.tensor([sum(all_ranks), self.world_size], dtype=torch.float32)

partial_grad = DTensor.from_local(local_grad, device_mesh, [Partial()])

# full result
out = torch.ops.aten.linalg_vector_norm(full_grad, 2)

# partial result
partial_out = torch.ops.aten.linalg_vector_norm(partial_grad, 2)
self.assertEqual(partial_out.full_tensor(), out)

@with_comms
def test_foreach_norm(self):
device_mesh = self.build_device_mesh()
@@ -668,6 +704,33 @@ class DistMathOpsTest(DTensorTestBase):
for o, so in zip(out, sharded_out):
self.assertEqual(so.full_tensor(), o)

@with_comms
def test_foreach_norm_partial(self):
device_mesh = self.build_device_mesh()

rank = device_mesh.get_local_rank()
all_ranks = list(range(self.world_size))

local_grad0 = torch.tensor([rank, 1], dtype=torch.float32)
local_grad1 = torch.tensor([rank + 1, 2], dtype=torch.float32)

grad0 = torch.tensor([sum(all_ranks), self.world_size], dtype=torch.float32)
grad1 = torch.tensor(
[sum(all_ranks) + self.world_size, 2 * self.world_size], dtype=torch.float32
)

partial_grad0 = DTensor.from_local(local_grad0, device_mesh, [Partial()])
partial_grad1 = DTensor.from_local(local_grad1, device_mesh, [Partial()])

# full result
out = torch.ops.aten._foreach_norm([grad0, grad1], 2)

# partial result
partial_out = torch.ops.aten._foreach_norm([partial_grad0, partial_grad1], 2)

for o, po in zip(out, partial_out):
self.assertEqual(po.full_tensor(), o)

@with_comms
def test_foreach_norm_different_mesh(self):
mesh_shape = (2, self.world_size // 2)

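The two `*_norm_partial` tests above hinge on the semantics of the `Partial()` placement: each rank holds an unreduced addend of the logical tensor, so the norm has to be taken after the reduction rather than per shard. A minimal single-process sketch of the underlying arithmetic (plain tensors stand in for the per-rank local values, which are illustrative assumptions here):

```python
import torch

# Stand-ins for what rank 0 and rank 1 would hold under a Partial() placement
# (illustrative values; the real tests derive them from the rank id).
local_rank0 = torch.tensor([0.0, 1.0])
local_rank1 = torch.tensor([1.0, 1.0])
full = local_rank0 + local_rank1  # the logical (reduced) tensor

reduced_first = torch.linalg.vector_norm(full, 2)
per_shard_sum = (torch.linalg.vector_norm(local_rank0, 2)
                 + torch.linalg.vector_norm(local_rank1, 2))
print(reduced_first, per_shard_sum)  # differ: norms do not distribute over +
```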
@@ -7,6 +7,10 @@ import itertools
import unittest

import torch
from torch.distributed._local_tensor import (
maybe_disable_local_tensor_mode,
maybe_run_for_local_tensor,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
@@ -29,7 +33,9 @@ from torch.testing._internal.common_utils import (
TEST_HPU,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
map_local_tensor_for_rank,
with_comms,
)
from torch.utils._debug_mode import DebugMode
@@ -163,7 +169,9 @@ class RedistributeTest(DTensorTestBase):
)

# make local tensor as the element of the corresponding chunked list
local_tensor = splitted_list[self.rank]
local_tensor = map_local_tensor_for_rank(
splitted_list, self.rank, lambda tl, r: tl[r]
)
replica_tensor = distribute_tensor(local_replica, device_mesh, replica_spec)
with comm_mode:
reshard_tensor = replica_tensor.redistribute(device_mesh, shard_spec)
@@ -407,7 +415,7 @@ class RedistributeTest(DTensorTestBase):
def test_partial_to_shard(self, dtype):
device_mesh = self.build_device_mesh()
partial_spec = [Partial()]
my_rank = device_mesh.get_rank()
my_rank = self.rank

input_sizes_and_shard_dim = [
((self.world_size * 3, 3), 0),
@@ -440,8 +448,13 @@ class RedistributeTest(DTensorTestBase):
for idx in range(self.world_size)
]

local_shape = list(input_size)
local_shape[shard_dim] = chunk_sizes[my_rank]
@maybe_run_for_local_tensor
def _compute_local_shape(rank) -> list[int]:
local_shape = list(input_size)
local_shape[shard_dim] = chunk_sizes[rank]
return local_shape

local_shape = _compute_local_shape(my_rank)

# test partial to shard, trigger reduce_scatter
with comm_mode:
@@ -534,10 +547,12 @@ class RedistributeTest(DTensorTestBase):
1,
)
else:
self.assertEqual(
comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
1,
)
# TODO: Integrate local tensor with CommDebugMode
if not self.is_local_tensor_enabled:
self.assertEqual(
comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
1,
)

# test 2d device mesh
mesh_2d = DeviceMesh(
@@ -586,7 +601,8 @@ class RedistributeTest(DTensorTestBase):
out_dt = sharded_dt.redistribute(mesh_2d, dst)

self.assertEqual(out_dt.placements, expected_dt.placements)
self.assertEqual(comm_mode.get_total_counts(), comm_counts_2d[idx])
if not self.is_local_tensor_enabled:
self.assertEqual(comm_mode.get_total_counts(), comm_counts_2d[idx])

local_out_dt = out_dt.to_local()
local_expected_dt = expected_dt.to_local()
@@ -1027,23 +1043,27 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
def test_ordered_distribute_all_combination(self):
"""Exhaustively test all possible sharding combinations and verify correctness"""
torch.manual_seed(21)
mesh = init_device_mesh(self.device_type, (2, 2, 2))
input_tensor_shape = [
# even sharding
(16, 8),
(8, 16, 32),
(8, 32, 16, 16),
# uneven sharding with padding
(17, 5),
(13, 2, 13),
(33, 16, 8, 1),
]

with maybe_disable_local_tensor_mode():
mesh = init_device_mesh(self.device_type, (2, 2, 2))
input_tensor_shape = [
# even sharding
(16, 8),
(8, 16, 32),
(8, 32, 16, 16),
# uneven sharding with padding
(17, 5),
(13, 2, 13),
(33, 16, 8, 1),
]

# 1. Verify correctness of distribute_tensor from Tensor to DTensor.
for tensor_shape in input_tensor_shape:
input_data = torch.randn(tensor_shape, device=self.device_type)
tensor_rank = input_data.ndim
for shard_order in self.generate_shard_orders(mesh, tensor_rank):
with maybe_disable_local_tensor_mode():
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
sharded_dt = self.distribute_tensor(
input_data.clone(), mesh, placements=None, shard_order=shard_order
)
@@ -1057,7 +1077,9 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
input_data = torch.randn(tensor_shape, device=self.device_type)
tensor_rank = input_data.ndim
prev_sharded_dt = None
for shard_order in self.generate_shard_orders(mesh, tensor_rank):
with maybe_disable_local_tensor_mode():
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
if prev_sharded_dt is None:
prev_sharded_dt = self.distribute_tensor(
input_data.clone(),
@@ -1077,26 +1099,27 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
"""Test mixing Partial in the original placements and do redistribute."""
# This test takes 226s to complete on 8XA100...
torch.manual_seed(21)
mesh = init_device_mesh(self.device_type, (2, 2, 2))
input_tensor_shape = [
# even sharding
(16, 8),
(8, 16, 32),
# uneven sharding with padding
(17, 5),
(13, 2, 13),
(33, 16, 8, 1),
]
placement_choice = [
Shard(0),
Shard(1),
Shard(2),
Partial("sum"),
Partial("min"),
Replicate(),
]
# pick 3 for the 3D mesh
partial_placement_comb = list(itertools.combinations(placement_choice, 3))
with maybe_disable_local_tensor_mode():
mesh = init_device_mesh(self.device_type, (2, 2, 2))
input_tensor_shape = [
# even sharding
(16, 8),
(8, 16, 32),
# uneven sharding with padding
(17, 5),
(13, 2, 13),
(33, 16, 8, 1),
]
placement_choice = [
Shard(0),
Shard(1),
Shard(2),
Partial("sum"),
Partial("min"),
Replicate(),
]
# pick 3 for the 3D mesh
partial_placement_comb = list(itertools.combinations(placement_choice, 3))

def _is_valid_placement(placements, tensor_rank):
# Check if placements is valid for tensor with rank `tensor_rank`
@@ -1112,7 +1135,9 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
continue
local_tensor = torch.randn(shape, device=self.device_type)
full_tensor = DTensor.from_local(local_tensor, mesh, placements)
for shard_order in self.generate_shard_orders(mesh, len(shape)):
with maybe_disable_local_tensor_mode():
shard_orders = self.generate_shard_orders(mesh, len(shape))
for shard_order in shard_orders:
sharded_dt = self.redistribute(
full_tensor, mesh, placements=None, shard_order=shard_order
)
@@ -1163,5 +1188,18 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())


RedistributeTestWithLocalTensor = create_local_tensor_test_class(
RedistributeTest,
)

MultiDimRedistributeTestWithLocalTensor = create_local_tensor_test_class(
MultiDimRedistributeTest,
skipped_tests=["test_multi_dim_mesh"],
)

DistributeWithDeviceOrderTestWithLocalTensor = create_local_tensor_test_class(
DistributeWithDeviceOrderTest,
)

if __name__ == "__main__":
run_tests()

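For context on the `_compute_local_shape` helper introduced above: with an uneven `Shard(dim)` placement, each rank's local size along the sharded dim comes from the per-rank chunk sizes. A standalone sketch of that computation, with the chunk sizes derived here via `torch.chunk` (an assumption about the splitting rule; the test obtains them from the actual sharding):

```python
import torch

world_size = 4
input_size = (17, 5)   # uneven along dim 0
shard_dim = 0

chunks = torch.empty(input_size).chunk(world_size, dim=shard_dim)
chunk_sizes = [c.shape[shard_dim] for c in chunks]  # e.g. [5, 5, 5, 2]

def compute_local_shape(rank: int) -> list[int]:
    # Same shape as the full tensor, except along the sharded dimension.
    local_shape = list(input_size)
    local_shape[shard_dim] = chunk_sizes[rank]
    return local_shape

print([compute_local_shape(r) for r in range(world_size)])
```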
@@ -44,9 +44,22 @@ device_type = str(get_devtype())

def apply_reordering_and_get_graph(graph, out_li) -> None:
gm = graph.owning_module
from torch._inductor.config import aten_distributed_optimizations as dist_opts
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing

schedule_overlap_bucketing(gm)
# Read config values, only pass non-None values to use function defaults
kwargs: dict[str, object] = {}
config_keys = (
"collective_bucketing",
"max_compute_pre_fetch",
"custom_runtime_estimation",
"insert_overlap_deps",
)
for key in config_keys:
if (val := getattr(dist_opts, key)) is not None:
kwargs[key] = val

schedule_overlap_bucketing(gm, **kwargs)
gm.graph.lint()
out_li.append(str(gm.graph))

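The rewrite above only forwards config values that were explicitly set, so `schedule_overlap_bucketing` keeps its own defaults for anything left at `None`. A generic restatement of that pattern (the names below are illustrative, not part of the PyTorch API):

```python
from types import SimpleNamespace

def call_with_overrides(fn, cfg, keys):
    # Forward only the options that are explicitly set (non-None).
    kwargs = {k: v for k in keys if (v := getattr(cfg, k, None)) is not None}
    return fn(**kwargs)

def schedule(collective_bucketing=False, max_compute_pre_fetch=8):
    return collective_bucketing, max_compute_pre_fetch

cfg = SimpleNamespace(collective_bucketing=True, max_compute_pre_fetch=None)
print(call_with_overrides(schedule, cfg,
                          ("collective_bucketing", "max_compute_pre_fetch")))
# -> (True, 8): the unset option falls back to the callee's default
```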
@@ -62,14 +75,14 @@ def run_and_get_aten_graph(fn, *inputs):

def get_patches():
return {
"test_configs.estimate_aten_runtime": estimate_aten_runtime,
"aten_distributed_optimizations.custom_runtime_estimation": estimate_aten_runtime,
"reorder_for_locality": False,
"triton.native_matmul": False,
"reorder_for_compute_comm_overlap_passes": [],
"compile_threads": 1,
"force_disable_caches": True,
# Messes up existing test strings
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
"aten_distributed_optimizations.insert_overlap_deps": False,
# interferes with testing, / custom estimation
"test_configs.assume_bucketing_reduces_latency": False,
}
@@ -351,21 +364,56 @@ graph():
# these have no overlap opportunities
self.assertEqual(counters["inductor"]["overlap_scheduling_bad_exposed"], 0)

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_overlap_scheduling_via_config(self):
"""Test overlap scheduling enabled via config in post_grad pass."""

def func(a):
ar = _functional_collectives.all_reduce(a, "sum", "0")
b = torch.matmul(a, a)
return torch.matmul(ar, b)

patches = {
**get_patches(),
"aten_distributed_optimizations.enable_overlap_scheduling": True,
}

with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank

with torch._inductor.config.patch(patches):
compiled_func = torch.compile(func)
out, code = run_and_get_code(compiled_func, inputs)

# Verify that wait_tensor is sinked below matmul
FileCheck().check("all_reduce").check("mm").check("wait_tensor").check(
"mm"
).run(code[0])

correct = func(inputs)
self.assertTrue(same(out, correct))
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)

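The `FileCheck` chain in the test above asserts ordering: each `.check()` must match after the previous one, which is how the test verifies that the `wait_tensor` is sunk between the two matmuls. A minimal sketch of that API run against a plain string:

```python
from torch.testing import FileCheck

# Hypothetical generated code; the test uses the real Inductor output.
generated = "all_reduce(...)\nmm(buf0)\nwait_tensor(buf1)\nmm(buf2)\n"

# Passes only if the four patterns appear in this order.
FileCheck().check("all_reduce").check("mm").check("wait_tensor").check("mm").run(generated)
```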
def get_bucket_patches(compute_multiplier=1.0):
estimate_aten_runtime_part = functools.partial(
estimate_aten_runtime, compute_multiplier=compute_multiplier
)
return {
"test_configs.estimate_aten_runtime": estimate_aten_runtime_part,
"test_configs.aten_fx_overlap_preserving_bucketing": True,
"aten_distributed_optimizations.custom_runtime_estimation": estimate_aten_runtime_part,
"aten_distributed_optimizations.collective_bucketing": True,
"reorder_for_locality": False,
"triton.native_matmul": False,
"reorder_for_compute_comm_overlap_passes": [],
"compile_threads": 1,
"force_disable_caches": True,
# messes up test strings
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
"aten_distributed_optimizations.insert_overlap_deps": False,
# interferes with testing, / custom estimation
"test_configs.assume_bucketing_reduces_latency": False,
}
@@ -806,7 +854,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
fake_pg=not at_least_x_gpu(2),
),
torch._inductor.config.patch(
"test_configs.aten_fx_overlap_insert_overlap_deps", True
"aten_distributed_optimizations.insert_overlap_deps", True
),
torch._inductor.config.patch(post_grad_custom_post_pass=apply),
):

@@ -3817,27 +3817,6 @@ class NcclProcessGroupWithDispatchedCollectivesTests(
dist.all_gather_into_tensor(output_tensor, tensor)
self.assertEqual(output_tensor, tensor)

@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_allgather_noncontig(self):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
device = "cuda"
tensor = (
torch.arange(0, 16, device=torch.device(device))
.view(2, 2, 2, 2)
.to(memory_format=torch.channels_last)
)
tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
dist.all_gather(tensor_list, tensor)
for o in tensor_list:
self.assertEqual(o, tensor)

@requires_nccl()
@skip_if_lt_x_gpu(1)
@parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])

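Background for the `test_allgather_noncontig` block above (dropped in this hunk, judging by the line counts): converting a 4D tensor to `torch.channels_last` keeps the data but changes the stride order, which is exactly the non-contiguous input the collective was exercised with. A quick single-process check:

```python
import torch

t = torch.arange(0, 16).view(2, 2, 2, 2).to(memory_format=torch.channels_last)
print(t.is_contiguous())                                   # False (default format)
print(t.is_contiguous(memory_format=torch.channels_last))  # True
```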
@@ -462,7 +462,9 @@ class DeviceMeshTestNDim(DTensorTestBase):
ep_mesh_2 = DeviceMesh(self.device_type, mesh_group_2)
ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2
# ep_mesh is considered different from mesh_2d["TP"]
self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list)
self.assertEqual(
mesh_2d["TP"].mesh.flatten().tolist(), ep_mesh.mesh.flatten().tolist()
)
self.assertEqual(mesh_2d["TP"]._layout, ep_mesh._layout)
self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape)
self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type)
@@ -477,7 +479,7 @@ class DeviceMeshTestNDim(DTensorTestBase):
another_mesh_1 if self.rank < self.world_size // 2 else another_mesh_2
)
# another_mesh is considered the same as ep_mesh
self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list)
self.assertEqual(ep_mesh._flatten_rank_map, another_mesh._flatten_rank_map)
self.assertEqual(ep_mesh._layout, another_mesh._layout)
self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
self.assertEqual(ep_mesh.device_type, another_mesh.device_type)
@@ -1049,6 +1051,34 @@ class TestDeviceMeshGetItem(DTensorTestBase):
)
w.wait()

@with_comms
def test_concatenate_2d(self):
mesh_shape = (2, 4)
mesh_dim_names = ("dp", "tp")
mesh_2d = init_device_mesh(
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
concatenated_mesh = DeviceMesh._concatenate([mesh_2d["dp"], mesh_2d["tp"]])
self.assertEqual(concatenated_mesh.mesh, mesh_2d.mesh)
self.assertEqual(concatenated_mesh.get_group("dp"), mesh_2d.get_group("dp"))
self.assertEqual(concatenated_mesh.get_group("tp"), mesh_2d.get_group("tp"))

@with_comms
def test_concatenate_3d(self):
mesh_shape = (2, 2, 2)
mesh_dim_names = ("pp", "dp", "tp")
mesh_3d = init_device_mesh(
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
concatenated_mesh = DeviceMesh._concatenate([mesh_3d["dp"], mesh_3d["tp"]])
dp_tp_mesh = mesh_3d["dp", "tp"]
self.assertEqual(concatenated_mesh.mesh, dp_tp_mesh.mesh)
self.assertEqual(concatenated_mesh.get_group("dp"), dp_tp_mesh.get_group("dp"))
self.assertEqual(concatenated_mesh.get_group("tp"), dp_tp_mesh.get_group("tp"))
self.assertEqual(
mesh_3d, DeviceMesh._concatenate([mesh_3d["pp", "dp"], mesh_3d["tp"]])
)

@with_comms
def test_reconstruct_mesh_with_flatten_dim(self):
mesh_3d = init_device_mesh(

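The `_concatenate` tests above build on named mesh dimensions and mesh slicing. A hedged single-process sketch of the public part of that API (world size 1, gloo backend; the address/port values are placeholders, and the tests themselves use 2x4 and 2x2x2 meshes across 8 ranks):

```python
import os
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# A 1x1 mesh so the sketch runs in one process.
mesh_2d = init_device_mesh("cpu", (1, 1), mesh_dim_names=("dp", "tp"))
print(mesh_2d["dp"].mesh)        # 1D sub-mesh for the "dp" dimension
print(mesh_2d["dp", "tp"].mesh)  # slicing several names yields a smaller N-D mesh

dist.destroy_process_group()
```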
@@ -471,6 +471,67 @@ from user code:
assert hasattr(backend_result.compiled_fn, "serialize")
self.assertIsNotNone(backend_result.compiled_fn.serialize)

def test_fullgraph_capture_with_pytree_module(self):
from torch._dynamo.functional_export import dynamo_graph_capture_for_export

class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
self.linear1 = torch.nn.Linear(3, 3)
self.linear2 = torch.nn.Linear(3, 3)
self.linear3 = torch.nn.Linear(3, 3)

def forward(self, x):
return {
"y": self.linear2(x[2] + 1),
"z": self.linear3(x[1] - 1),
"w": self.linear(x[0]["b"] + 2),
"v": self.linear1(x[0]["a"] - 2),
}

mod = Module()
compiled_mod = dynamo_graph_capture_for_export(mod)(
(
{"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
torch.randn(3, 3),
torch.randn(3, 3),
)
)

inputs = (
{"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
torch.randn(3, 3),
torch.randn(3, 3),
)
self.assertEqual(compiled_mod(inputs), mod(inputs))

def test_fullgraph_capture_with_pytree_func(self):
from torch._dynamo.functional_export import dynamo_graph_capture_for_export

def foo(x):
return {
"y": x[2] + 1,
"z": x[1] - 1,
"w": x[0]["b"] + 2,
"v": x[0]["a"] - 2,
}

compiled_foo = dynamo_graph_capture_for_export(foo)(
(
{"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
torch.randn(2, 3),
torch.randn(3, 4),
)
)

inputs = (
{"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
torch.randn(2, 3),
torch.randn(3, 4),
)
self.assertEqual(compiled_foo(inputs), foo(inputs))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

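Context for the pytree tests above: nested dict/tuple inputs are flattened into a list of tensor leaves plus a structure spec, and the outputs are rebuilt from leaves the same way. A minimal sketch with the pytree utilities:

```python
import torch
from torch.utils import _pytree as pytree

inputs = ({"a": torch.randn(3), "b": torch.randn(3)}, torch.randn(3))
leaves, spec = pytree.tree_flatten(inputs)
assert len(leaves) == 3                      # three tensor leaves
rebuilt = pytree.tree_unflatten(leaves, spec)
assert torch.equal(rebuilt[0]["a"], inputs[0]["a"])
```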
@@ -1076,7 +1076,22 @@ class DictTests(torch._dynamo.test_case.TestCase):
def test_newly_constructed_default_dict(self):
def f(x):
d = defaultdict(list)
d[0] = 42
d[0] = [
42,
]
return x + 1, d

x = torch.ones(2)
ref = f(x)
res = torch.compile(f, backend="eager", fullgraph=True)(x)

self.assertEqual(ref, res)

@unittest.expectedFailure
def test_newly_constructed_default_dict_with_dict(self):
def f(x):
d = defaultdict(dict, {2: {"a": 1}})
d[0] = {"b": 2}
return x + 1, d

x = torch.ones(2)

@@ -1,7 +1,6 @@
# Owner(s): ["module: dynamo"]

import dataclasses
import importlib
import pickle
import sys
import tempfile
@@ -748,7 +747,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
):
self._test_serialization("NN_MODULE", fn, m, x)

def test_function_match(self):
def test_class_match(self):
def fn(x):
# usage of this context manager installs a FUNCTION_MATCH guard
with torch.no_grad():
@@ -760,9 +759,9 @@ class TestGuardSerialization(TestGuardSerializationBase):
# we don't support FUNCTION_MATCH because it adds an ID_MATCH guard, and we don't
# support that in serialization
with self.assertRaisesRegex(
PackageError, "FUNCTION_MATCH guard cannot be serialized."
PackageError, "CLASS_MATCH guard cannot be serialized."
):
self._test_serialization("FUNCTION_MATCH", fn, x)
self._test_serialization("CLASS_MATCH", fn, x)

def test_closure_match(self):
def fn(x):
@@ -958,12 +957,12 @@ class TestGuardSerialization(TestGuardSerializationBase):
self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True)

def fn(x):
# usage of this context manager installs a FUNCTION_MATCH guard
# usage of this context manager installs a CLASS_MATCH guard
with torch.no_grad():
y = x * 2
return y

ref, loaded = self._test_serialization("FUNCTION_MATCH", fn, torch.randn(3))
ref, loaded = self._test_serialization("CLASS_MATCH", fn, torch.randn(3))
self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True)

def test_dispatch_key_set_match(self):
@@ -983,23 +982,6 @@ class TestGuardSerialization(TestGuardSerializationBase):
dks = torch._C._dispatch_keys(x)
self._test_check_fn(ref, loaded, {"x": x, "dks": dks}, False)

def test_name_match(self):
def fn(x, y):
return torch.cond(x, lambda x: y + 1, lambda x: y - 1, (y,))

x = torch.tensor(True)
y = torch.randn(3)
ref, loaded = self._test_serialization("NAME_MATCH", fn, x, y)

self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)

op = importlib.import_module("torch._higher_order_ops.cond").cond_op
prev, op.__name__ = op.__name__, ""
try:
self._test_check_fn(ref, loaded, {"x": x, "y": y}, False)
finally:
op.__name__ = prev

def test_dual_level(self):
def fn(x):
with torch.autograd.forward_ad.dual_level():
@@ -1485,7 +1467,8 @@ class TestGuardSerialization(TestGuardSerializationBase):
torch._dynamo.optimize(
package=package,
guard_filter_fn=lambda gs: [
x.guard_type not in ("CLOSURE_MATCH", "ID_MATCH") for x in gs
x.guard_type not in ("CLOSURE_MATCH", "ID_MATCH", "CLASS_MATCH")
for x in gs
],
)(foo)(ddp_model, x)
self.assertEqual(len(package._codes[foo.__code__].guarded_codes), 1)

@@ -245,8 +245,7 @@ due to:
Traceback (most recent call last):
File "test_logging.py", line N, in throw
raise AssertionError
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError:
torch._inductor.exc.InductorError: LoweringException: AssertionError:
target: aten.round.default
args[0]: TensorBox(StorageBox(
InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1]))

@@ -658,6 +658,31 @@ graph():
fn = torch.compile(f, backend="eager", dynamic=True, fullgraph=True)
fn(torch.tensor([5]), 5)

@torch._dynamo.config.patch(capture_scalar_outputs=True)
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
def test_cond_runtime_assert_generation(self):
def fn(x):
y = x.nonzero() # unbacked binding u0
torch._check(y.shape[0] % 4 == 0)

return torch.randn(y.shape[0])

@torch.compile(dynamic=True, backend="aot_eager")
def foo(x):
b = torch.cond(
pred=(x.shape[0] % 4 == 0),
true_fn=lambda: fn(x),
false_fn=lambda: fn(x),
)

return b

foo(torch.randn(4, 4))
with self.assertRaisesRegex(
RuntimeError, "Runtime assertion failed for expression Eq(Mod(u1, 4), 0)*"
):
foo(torch.randn(5, 5))

def test_tensor_setattr_getset_descriptor(self):
# Tensor attribute `real` has special getter/setter for complex dtype.
def f(x):

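For context on `test_cond_runtime_assert_generation`: in eager mode `torch._check` simply raises when its condition is false, while under compilation with unbacked sizes (here produced by `nonzero()`) it becomes a runtime assert such as the `Eq(Mod(u1, 4), 0)` expression the test matches against. A tiny eager illustration:

```python
import torch

torch._check(torch.randn(8).shape[0] % 4 == 0)      # passes silently
try:
    torch._check(torch.randn(5).shape[0] % 4 == 0)  # 5 % 4 != 0
except RuntimeError as exc:
    print("check failed:", exc)
```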
@@ -326,7 +326,8 @@ def add(x, y):

def guard_filter_fn(guards):
return [
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
guard.guard_type
not in ("CLOSURE_MATCH", "FUNCTION_MATCH", "MODULE_MATCH")
for guard in guards
]


@@ -7,11 +7,17 @@ import torch._inductor.test_case
import torch.fx.traceback as fx_traceback
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._guards import detect_fake_mode
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.fx._graph_pickler import GraphPickler
from torch.fx.passes.regional_inductor import regional_inductor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfTorchDynamo,
)
from torch.testing._internal.triton_utils import requires_cuda_and_triton


@@ -36,7 +42,29 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
# f) disallow nested regional compile


def aot_eager_regional_inductor():
def aot_eager_regional_inductor(serialize=False):
if serialize:

def regional_inductor_pickle(gm, *example_args):
result = regional_inductor(gm, *example_args)
serialized = GraphPickler.dumps(result)

fake_mode = detect_fake_mode(example_args)
assert fake_mode is not None
# Serialize and deserialize the result to confirm pickling works
# Use a fresh tracing context on the new process
context = torch._guards.TracingContext(fake_mode)
with torch._guards.tracing(context):
result = GraphPickler.loads(serialized, fake_mode)
assert isinstance(result, torch.fx.GraphModule)
result.recompile()
return result

return aot_autograd(
fw_compiler=regional_inductor_pickle,
bw_compiler=regional_inductor_pickle,
)

return aot_autograd(
fw_compiler=regional_inductor,
bw_compiler=regional_inductor,
@@ -44,8 +72,10 @@ def aot_eager_regional_inductor():


@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
@instantiate_parametrized_tests
class RegionalInductorTests(torch._inductor.test_case.TestCase):
def test_simple(self):
@parametrize("serialize", [False, True])
def test_simple(self, serialize):
def fn(x, y):
sin = torch.sin(x)

@@ -56,7 +86,7 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
return torch.sin(add)

opt_fn = torch.compile(
fn, backend=aot_eager_regional_inductor(), fullgraph=True
fn, backend=aot_eager_regional_inductor(serialize=serialize), fullgraph=True
)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
@@ -65,7 +95,8 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x, y))
self.assertEqual(len(codes), 2)

def test_repeated_blocks(self):
@parametrize("serialize", [False, True])
def test_repeated_blocks(self, serialize):
def fn(x, y):
sin = torch.sin(x)

@@ -86,7 +117,9 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
mod = Mod()

opt_mod = torch.compile(
mod, backend=aot_eager_regional_inductor(), fullgraph=True
mod,
backend=aot_eager_regional_inductor(serialize=serialize),
fullgraph=True,
)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
@@ -96,7 +129,8 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
_, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
self.assertEqual(len(codes), 4)

def test_invoke_subgraph(self):
@parametrize("serialize", [False, True])
def test_invoke_subgraph(self, serialize):
# Checks that get_attr nodes custom metadata is propagated
@torch.compiler.nested_compile_region
def gn(x):
@@ -109,15 +143,17 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
return torch.sigmoid(z)

opt_fn = torch.compile(
fn, backend=aot_eager_regional_inductor(), fullgraph=True
fn, backend=aot_eager_regional_inductor(serialize=serialize), fullgraph=True
)
x = torch.randn(10, requires_grad=True)

_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
self.assertEqual(len(codes), 2)

def test_invoke_subgraph_inner(self):
@parametrize("serialize", [False, True])
def test_invoke_subgraph_inner(self, serialize):
# Checks that the inductor regions are searched recursively.

@torch.compiler.nested_compile_region
def gn(x):
with fx_traceback.annotate({"compile_with_inductor": 0}):
@@ -131,7 +167,7 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
return torch.sigmoid(x)

opt_fn = torch.compile(
fn, backend=aot_eager_regional_inductor(), fullgraph=True
fn, backend=aot_eager_regional_inductor(serialize=serialize), fullgraph=True
)
x = torch.randn(10, requires_grad=True)

@@ -141,7 +177,8 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
self.assertEqual(len(codes), 2)

@requires_cuda_and_triton
def test_flex_attention(self):
@parametrize("serialize", [False, True])
def test_flex_attention(self, serialize):
def _squared(score, b, h, m, n):
return score * score

@@ -170,7 +207,7 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):

opt_fn = torch.compile(
fn,
backend=aot_eager_regional_inductor(),
backend=aot_eager_regional_inductor(serialize),
fullgraph=True,
)

@@ -179,7 +216,8 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
self.assertEqual(len(codes), 2)

@requires_cuda_and_triton
def test_selective_ac_flex(self):
@parametrize("serialize", [False, True])
def test_selective_ac_flex(self, serialize):
class FlexAttentionModule(torch.nn.Module):
def __init__(self, hidden_size, num_heads):
super().__init__()

@@ -8101,14 +8101,6 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
res = gm(x, y)
self.assertEqual(res, ref)

def test_current_accelerator(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
torch.accelerator.current_accelerator()
return x + 1

self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1)


instantiate_parametrized_tests(ReproTests)


@@ -55,7 +55,7 @@ torch.fx.node.Node.append(self, x: 'Node') -> None
torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str]
torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
torch.fx.node.Node.prepend(self, x: 'Node') -> None
torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = <function <lambda>>, propagate_meta: bool = False) -> List[Node]
torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Optional[Callable[[Node], bool]] = None, propagate_meta: bool = False) -> List[Node]
torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None
torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None

@@ -402,6 +402,43 @@ def forward(self, x):

self.assertEqual(res_export, res_eager)

def test_dynamo_graph_capture(self):
from torch._dynamo.functional_export import dynamo_graph_capture_for_export

class Foo(torch.nn.Module):
def forward(self, dct, lst, bleh):
x = dct["a"] * lst[1][0]
y = dct["b"] * lst[0]
out_dict = {}

# Mutate and get a new entry in there
lst_copy = lst.copy()
lst_copy.append(lst[0])
out_dict["a"] = x
out_dict["b"] = y
return (
dct["a"],
out_dict["b"],
bleh,
lst_copy[-1],
out_dict["a"],
[5, 6],
)

foo = Foo()

def make_inputs():
return (
{"a": torch.randn(2, 3), "b": torch.randn(2, 3)},
[torch.randn(2, 3), (torch.randn(2, 3),)],
torch.randn(2, 3),
)

trace_inputs = make_inputs()
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
test_inputs = make_inputs()
self.assertEqual(gm(*test_inputs), foo(*test_inputs))


if __name__ == "__main__":
run_tests()

@@ -1934,22 +1934,13 @@ graph():
# TODO (tmanlaibaatar) this kinda sucks but today there is no good way to get
# good source name. We should have an util that post processes dynamo source names
# to be more readable.
if is_strict_v2_test(self._testMethodName) or is_inline_and_install_strict_test(
self._testMethodName
with self.assertWarnsRegex(
UserWarning,
r"(L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank"
r"|L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank_dict"
r"|L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[0\]\.cell_contents)",
):
with self.assertWarnsRegex(
UserWarning,
r"(L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank"
r"|L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank_dict"
r"|L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[0\]\.cell_contents)",
):
ref(torch.randn(4, 4), torch.randn(4, 4))
else:
with self.assertWarnsRegex(
UserWarning,
r"(L\['global_list'\]|L\['self'\]\.bank|L\['self'\]\.bank_dict)",
):
ref(torch.randn(4, 4), torch.randn(4, 4))
ref(torch.randn(4, 4), torch.randn(4, 4))

def test_mask_nonzero_static(self):
class TestModule(torch.nn.Module):
@@ -17262,10 +17253,17 @@ def forward(self, x):
lengths=torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]),
offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]),
)
with self.assertWarnsRegex(
UserWarning,
"While exporting, we found certain side effects happened in the model.forward. "
"Here are the list of potential sources you can double check: \[\"L\['jt'\]\"\]",
# TODO tmanlaibaatar
# because we call unflatten in the flat tracer, it creates a new JaggedTensor
# and it gets pruned as it is not reachable. Not sure what the right way to fix
# is but since it is just warning, probably ok to xfail it for now.
with (
self.assertWarnsRegex(
UserWarning,
"While exporting, we found certain side effects happened in the model.forward. "
"Here are the list of potential sources you can double check: \[\"L\['jt'\]\"\]",
),
torch._export.config.patch(use_new_tracer_experimental=False),
):
_ = torch.export.export(foo, (jt,), strict=True)


@@ -1060,6 +1060,27 @@ def forward(self, x):
inp = (torch.randn(3), None)
self.assertTrue(torch.allclose(unf(*inp), M1()(*inp)))

def test_unflatten_root_module_type(self) -> None:
class M(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + x

class M1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m = M()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.m(x)

inp = (torch.randn(3),)
ep = torch.export.export(M1(), inp)
unf = torch.export.unflatten(ep)
self.assertIsNotNone(unf.type_name())
self.assertEqual(unf.type_name().split(".")[-1], "M1")
self.assertEqual(unf.m.type_name().split(".")[-1], "M")
self.assertTrue(torch.allclose(unf(*inp), M1()(*inp)))


if __name__ == "__main__":
run_tests()

Some files were not shown because too many files have changed in this diff.