Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-26 16:44:54 +08:00)

Compare commits: cpp-docs-d ... trunk/42bd (120 commits)
| Author | SHA1 | Date | |
|---|---|---|---|
| 42bd210fff | |||
| 1d13c314b3 | |||
| 0c9763a5a0 | |||
| 79a4a9c02e | |||
| 9d0b77f4cd | |||
| d486eee234 | |||
| cddd5f74ab | |||
| dfdb68e51f | |||
| 98c818320a | |||
| cc20b7ad72 | |||
| bc11a42b3f | |||
| 4fc06f2e0a | |||
| 82473c3d59 | |||
| b6a4236e5d | |||
| b04173be9b | |||
| 32ac38f85d | |||
| c9b49e506e | |||
| 6038e476e8 | |||
| 2c851c16e5 | |||
| 31584f2d91 | |||
| 0442125362 | |||
| fdcf402d82 | |||
| 13cda9b89e | |||
| fa6d911dda | |||
| 0db6bcc015 | |||
| 60ac039998 | |||
| 380d440d1c | |||
| 9038a30cee | |||
| 690c8c13b9 | |||
| 28ee6b62ed | |||
| 81577bdb3f | |||
| e67e3d95f3 | |||
| 27af8480ea | |||
| 6494cdc40c | |||
| ac7074efa2 | |||
| 263901cec4 | |||
| c12293dcbe | |||
| 5a4997dcae | |||
| 47f638eae7 | |||
| 882b834082 | |||
| b146ea411e | |||
| 8625ffbd45 | |||
| 0977cc4474 | |||
| d9a55faccc | |||
| 75b8295868 | |||
| defb6a80d8 | |||
| f8fccb1e48 | |||
| 5aac4cfce4 | |||
| baf91bbbfc | |||
| cbcb4f7768 | |||
| 2b93d5b450 | |||
| 6b7cd48e7e | |||
| bf5aa9e42e | |||
| b1eb6dede5 | |||
| 673060beae | |||
| 2e8e9a59a8 | |||
| fb277a5916 | |||
| 73fa0d0c63 | |||
| 36c21cc84e | |||
| 0b68814b44 | |||
| e64a814ae7 | |||
| 0b58d87aec | |||
| 757975ad50 | |||
| 291712026b | |||
| 3e77a2b478 | |||
| 82ef1b5db3 | |||
| 5f370f5c42 | |||
| 05b2e02cb4 | |||
| 12f742941d | |||
| 35180fafee | |||
| c746feb86a | |||
| c5f26db5bf | |||
| 18e99b6d45 | |||
| ab9e466928 | |||
| af4ba78543 | |||
| 282f39a4bc | |||
| a479769488 | |||
| 26c7375477 | |||
| d01f15152c | |||
| 4fae6968b1 | |||
| f9953e0f61 | |||
| 34ed7a8f0d | |||
| 2fde10d914 | |||
| 0a93295da0 | |||
| 4b898b51b9 | |||
| 550e3e6efb | |||
| 715449ca76 | |||
| 84d8d06fc3 | |||
| 60992d98b2 | |||
| 59e015e3a1 | |||
| 8904a5a7c9 | |||
| f5df9ca03a | |||
| 2998abd777 | |||
| e13580e41c | |||
| f3b8e15f20 | |||
| 5211f4c108 | |||
| ad9027b80d | |||
| a1005427bf | |||
| 35153d0846 | |||
| 7773a22cdb | |||
| 7cb467a169 | |||
| 12aac12b8d | |||
| 2b748d0a56 | |||
| 16745a882a | |||
| 8daef35cf1 | |||
| 51319ca090 | |||
| d311a3d1dc | |||
| 04adfe5ba9 | |||
| 4be1e3bf92 | |||
| e7592f4005 | |||
| d334c3649d | |||
| 9f82535c5a | |||
| 5b35fc8777 | |||
| 2f38eece7c | |||
| 830e789a55 | |||
| ad4dc52bf6 | |||
| dac9ed9790 | |||
| 1c7fe8f861 | |||
| 4e643422f6 | |||
| 3c3b278872 |
@@ -19,7 +19,7 @@ pip_install \
   transformers==4.36.2

 pip_install coloredlogs packaging
-pip_install onnxruntime==1.23.0
+pip_install onnxruntime==1.23.1
 pip_install onnxscript==0.5.4

 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
@@ -334,12 +334,12 @@ sympy==1.13.3
 #Pinned versions:
 #test that import:

-onnx==1.18.0
+onnx==1.19.1
 #Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
 #Pinned versions:
 #test that import:

-onnxscript==0.5.3
+onnxscript==0.5.4
 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
 #Pinned versions:
 #test that import:
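Once the bumped pins land, a quick local sanity check is to import the packages and print their versions. This is a hedged sketch, not part of the change itself; it assumes both packages are installed and expose `__version__`, which may not hold for every release.

```python
# Quick sanity check for the bumped ONNX pins (assumes the packages are
# installed and expose __version__).
import onnx
import onnxscript

print(onnx.__version__)        # expected 1.19.1 after this change
print(onnxscript.__version__)  # expected 0.5.4 after this change
```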
@@ -6,7 +6,7 @@ dependencies = [
     "GitPython==3.1.45",
     "docker==7.1.0",
     "pytest==7.3.2",
-    "uv==0.8.6"
+    "uv==0.9.5"
 ]

 [tool.setuptools]
.claude/skills/pytorch-docstring.md (new file, 354 lines)

@@ -0,0 +1,354 @@
# PyTorch Docstring Writing Guide

This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.

## General Principles

- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
- Follow **Sphinx/reStructuredText** (reST) format for documentation
- Be **concise but complete** - include all essential information
- Always include **examples** when possible
- Use **cross-references** to related functions/classes
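Put together, these principles give docstrings shaped roughly like the sketch below (a hypothetical `scale_values` helper, not an actual PyTorch API). The sections that follow cover each part in detail.

```python
import torch


def scale_values(input, factor=1.0):
    r"""scale_values(input, factor=1.0) -> Tensor

    Multiplies each element of :attr:`input` by :attr:`factor`.

    .. math::
        \text{out}_i = \text{factor} \cdot \text{input}_i

    Args:
        input (Tensor): the input tensor
        factor (float, optional): multiplier applied elementwise. Default: 1.0

    Examples::

        >>> x = torch.ones(3)
        >>> scale_values(x, factor=2.0)
        tensor([2., 2., 2.])
    """
    return input * factor
```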
## Docstring Structure

### 1. Function Signature (First Line)

Start with the function signature showing all parameters:

```python
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
```

**Notes:**
- Include the function name
- Show positional and keyword-only arguments (use `*` separator)
- Include default values
- Show return type annotation
- This line should NOT end with a period

### 2. Brief Description

Provide a one-line description of what the function does:

```python
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 2D convolution over an input image composed of several input
planes.
```

### 3. Mathematical Formulas (if applicable)

Use Sphinx math directives for mathematical expressions:

```python
.. math::
    \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
```

Or inline math: `:math:\`x^2\``

### 4. Cross-References

Link to related classes and functions using Sphinx roles:

- `:class:\`~torch.nn.ModuleName\`` - Link to a class
- `:func:\`torch.function_name\`` - Link to a function
- `:meth:\`~Tensor.method_name\`` - Link to a method
- `:attr:\`attribute_name\`` - Reference an attribute
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)

**Example:**
```python
See :class:`~torch.nn.Conv2d` for details and output shape.
```

### 5. Notes and Warnings

Use admonitions for important information:

```python
.. note::
    This function doesn't work directly with NLLLoss,
    which expects the Log to be computed between the Softmax and itself.
    Use log_softmax instead (it's faster and has better numerical properties).

.. warning::
    :func:`new_tensor` always copies :attr:`data`. If you have a Tensor
    ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
    or :func:`torch.Tensor.detach`.
```
### 6. Args Section

Document all parameters with type annotations and descriptions:

```python
Args:
    input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)`
    bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
      tuple `(sH, sW)`. Default: 1
```

**Formatting rules:**
- Parameter name in **lowercase**
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
- Description follows the type
- For optional parameters, include "Default: ``value``" at the end
- Use double backticks for inline code: ``` ``None`` ```
- Indent continuation lines by 2 spaces

### 7. Keyword Args Section (if applicable)

Sometimes keyword arguments are documented separately:

```python
Keyword args:
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
        Default: if None, same :class:`torch.dtype` as this tensor.
    device (:class:`torch.device`, optional): the desired device of returned tensor.
        Default: if None, same :class:`torch.device` as this tensor.
    requires_grad (bool, optional): If autograd should record operations on the
        returned tensor. Default: ``False``.
```

### 8. Returns Section (if needed)

Document the return value:

```python
Returns:
    Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
        If ``hard=True``, the returned samples will be one-hot, otherwise they will
        be probability distributions that sum to 1 across `dim`.
```

Or simply include it in the function signature line if obvious from context.

### 9. Examples Section

Always include examples when possible:

```python
Examples::

    >>> inputs = torch.randn(33, 16, 30)
    >>> filters = torch.randn(20, 16, 5)
    >>> F.conv1d(inputs, filters)

    >>> # With square kernels and equal stride
    >>> filters = torch.randn(8, 4, 3, 3)
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> F.conv2d(inputs, filters, padding=1)
```

**Formatting rules:**
- Use `Examples::` with double colon
- Use `>>>` prompt for Python code
- Include comments with `#` when helpful
- Show actual output when it helps understanding (indent without `>>>`)
### 10. External References

Link to papers or external documentation:

```python
.. _Link Name:
    https://arxiv.org/abs/1611.00712
```

Reference them in text: ```See `Link Name`_```

## Method Types

### Native Python Functions

For regular Python functions, use a standard docstring:

```python
def relu(input: Tensor, inplace: bool = False) -> Tensor:
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise. See
    :class:`~torch.nn.ReLU` for more details.
    """
    # implementation
```

### C-Bound Functions (using add_docstr)

For C-bound functions, use `_add_docstr`:

```python
conv1d = _add_docstr(
    torch.conv1d,
    r"""
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 1D convolution over an input signal composed of several input
planes.

See :class:`~torch.nn.Conv1d` for details and output shape.

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , kW)`
    ...
""",
)
```

### In-Place Variants

For in-place operations (ending with `_`), reference the original:

```python
add_docstr_all(
    "abs_",
    r"""
abs_() -> Tensor

In-place version of :meth:`~Tensor.abs`
""",
)
```

### Alias Functions

For aliases, simply reference the original:

```python
add_docstr_all(
    "absolute",
    r"""
absolute() -> Tensor

Alias for :func:`abs`
""",
)
```
## Common Patterns

### Shape Documentation

Use LaTeX math notation for tensor shapes:

```python
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
```

### Reusable Argument Definitions

For commonly used arguments, define them once and reuse:

```python
common_args = parse_kwargs(
    """
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
        Default: if None, same as this tensor.
"""
)

# Then use with .format():
r"""
...

Keyword args:
    {dtype}
    {device}
""".format(**common_args)
```

### Template Insertion

Insert reproducibility notes or other common text:

```python
r"""
{tf32_note}

{cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)
```
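For context, `reproducibility_notes` and `tf32_notes` are plain string dictionaries defined elsewhere in the codebase and reused across many docstrings. The names and contents below are hypothetical stand-ins, shown only to illustrate how the substitution behaves:

```python
# Hypothetical stand-ins for the shared note dictionaries; the real ones are
# defined once in the codebase and reused by many operator docstrings.
reproducibility_notes = {
    "cudnn_reproducibility_note": "This operation may be nondeterministic on CUDA."
}
tf32_notes = {"tf32_note": "This operation may use TF32 on Ampere or newer GPUs."}

docstring = r"""
{tf32_note}

{cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)

print(docstring)  # both placeholders are replaced by the shared note text
```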
## Complete Example

Here's a complete example showing all elements:

```python
def gumbel_softmax(
    logits: Tensor,
    tau: float = 1,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
) -> Tensor:
    r"""
    Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits (Tensor): `[..., num_features]` unnormalized log probabilities
        tau (float): non-negative scalar temperature
        hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
            but will be differentiated as if it is the soft sample in autograd. Default: ``False``
        dim (int): A dimension along which softmax will be computed. Default: -1

    Returns:
        Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
            If ``hard=True``, the returned samples will be one-hot, otherwise they will
            be probability distributions that sum to 1 across `dim`.

    .. note::
        This function is here for legacy reasons, may be removed from nn.Functional in the future.

    Examples::
        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)

    .. _Link 1:
        https://arxiv.org/abs/1611.00712
    """
    # implementation
```

## Quick Checklist

When writing a PyTorch docstring, ensure:

- [ ] Use raw string (`r"""`)
- [ ] Include function signature on first line
- [ ] Provide brief description
- [ ] Document all parameters in Args section with types
- [ ] Include default values for optional parameters
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
- [ ] Add mathematical formulas if applicable
- [ ] Include at least one example in Examples section
- [ ] Add warnings/notes for important caveats
- [ ] Link to related module class with `:class:`
- [ ] Use proper math notation for tensor shapes
- [ ] Follow consistent formatting and indentation
## Common Sphinx Roles Reference

- `:class:\`~torch.nn.Module\`` - Class reference
- `:func:\`torch.function\`` - Function reference
- `:meth:\`~Tensor.method\`` - Method reference
- `:attr:\`attribute\`` - Attribute reference
- `:math:\`equation\`` - Inline math
- `:ref:\`label\`` - Internal reference
- ``` ``code`` ``` - Inline code (use double backticks)
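Several of these roles usually appear together in a single paragraph. A small illustrative excerpt (hypothetical wording, not copied from an existing PyTorch docstring):

```python
r"""
Computes the same result as :func:`torch.matmul`, writing it into
:attr:`out`. See :class:`~torch.nn.Linear` and :meth:`~Tensor.matmul`
for details; the result has shape :math:`(n, m)` and ``out`` must be
preallocated.
"""
```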
## Additional Notes

- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
- **Line length**: Try to keep lines under 100 characters when possible
- **Periods**: End sentences with periods, but not the signature line
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.
.github/actions/setup-rocm/action.yml (vendored, 7 changed lines)
@@ -124,3 +124,10 @@ runs:
     id: login-ecr
     continue-on-error: true
     uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
+
+  - name: Preserve github env variables for use in docker
+    shell: bash
+    run: |
+      env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
+      env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
+      env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
.github/ci_commit_pins/vision.txt (vendored, 2 changed lines)
@@ -1 +1 @@
-faffd5cf673615583da6517275e361cb3dbc77e6
+1752fe6809b74921644866275ab80244b96e80bc
.github/ci_configs/vllm/Dockerfile (vendored, 5 changed lines)
@ -283,6 +283,9 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
|
||||
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
|
||||
fi
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --pre apache-tvm-ffi==0.1.0b15
|
||||
|
||||
# Install the vllm wheel from previous stage
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system /wheels/vllm/*.whl --verbose
|
||||
@ -295,6 +298,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
|
||||
# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
|
||||
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.14.post1"
|
||||
|
||||
|
||||
.github/label_to_label.yml (vendored, 9 changed lines)
@ -15,6 +15,11 @@
|
||||
- "module: reinplacing"
|
||||
then:
|
||||
- "module: pt2-dispatcher"
|
||||
- any:
|
||||
- "vllm-compile"
|
||||
then:
|
||||
- "module: vllm"
|
||||
- "oncall: pt2"
|
||||
- any:
|
||||
- "module: vmap"
|
||||
then:
|
||||
@ -27,10 +32,6 @@
|
||||
- "module: pt2 optimizer"
|
||||
then:
|
||||
- "module: dynamo"
|
||||
- any:
|
||||
- "module: flex attention"
|
||||
then:
|
||||
- "module: higher order operators"
|
||||
- any:
|
||||
- "module: aotinductor"
|
||||
then:
|
||||
|
||||
.github/workflows/inductor-periodic.yml (vendored, 1 changed line)
@ -88,7 +88,6 @@ jobs:
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3_10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
|
||||
.github/workflows/periodic.yml (vendored, 15 changed lines)
@ -147,15 +147,16 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
|
||||
cuda-arch-list: 8.9
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
.github/workflows/pull.yml (vendored, 3 changed lines)
@ -347,7 +347,8 @@ jobs:
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
sync-tag: linux-xpu-n-build
|
||||
# This should sync with the build in xpu.yml but xpu uses a larger runner
|
||||
# sync-tag: linux-xpu-n-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
||||
|
||||
.github/workflows/rocm-mi300.yml (vendored, 1 changed line)
@ -45,7 +45,6 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-noble-rocm-py3.12-mi300
|
||||
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
|
||||
.github/workflows/rocm-mi355.yml (vendored, 1 changed line)
@ -42,7 +42,6 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-noble-rocm-py3.12-mi355
|
||||
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
|
||||
|
||||
.github/workflows/rocm-navi31.yml (vendored, 12 changed lines)
@ -26,11 +26,23 @@ jobs:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
|
||||
.github/workflows/rocm.yml (vendored, 12 changed lines)
@ -26,11 +26,23 @@ jobs:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
|
||||
.github/workflows/trunk-tagging.yml (vendored, 147 changed lines)
@ -58,8 +58,10 @@ jobs:
|
||||
else
|
||||
COMMIT_SHA="${{ github.sha }}"
|
||||
fi
|
||||
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
|
||||
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
|
||||
{
|
||||
echo "sha=${COMMIT_SHA}"
|
||||
echo "tag_name=trunk/${COMMIT_SHA}"
|
||||
} >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Validate commit SHA
|
||||
run: |
|
||||
@ -87,7 +89,7 @@ jobs:
|
||||
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
|
||||
fi
|
||||
|
||||
- name: Create and push tag with retry
|
||||
- name: Create and push tag(s) with retry
|
||||
id: check_tag
|
||||
env:
|
||||
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
|
||||
@ -112,14 +114,23 @@ jobs:
|
||||
return 1
|
||||
}
|
||||
|
||||
# Exit early if tag already exists
|
||||
if check_tag_exists; then
|
||||
echo "✅ Tag already exists - no action needed"
|
||||
echo "exists=true" >> "${GITHUB_OUTPUT}"
|
||||
exit 0
|
||||
fi
|
||||
# Counters for summary reporting
|
||||
created_count=0
|
||||
skipped_count=0
|
||||
failed_count=0
|
||||
|
||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||
# Always write outputs once on exit
|
||||
finish() {
|
||||
set +e
|
||||
if [ -n "${GITHUB_OUTPUT:-}" ]; then
|
||||
{
|
||||
echo "created_count=${created_count}"
|
||||
echo "skipped_count=${skipped_count}"
|
||||
echo "failed_count=${failed_count}"
|
||||
} >> "${GITHUB_OUTPUT}"
|
||||
fi
|
||||
}
|
||||
trap finish EXIT
|
||||
|
||||
# Retry configuration
|
||||
MAX_RETRIES=5
|
||||
@ -194,31 +205,111 @@ jobs:
|
||||
}
|
||||
}
|
||||
|
||||
# Execute with retry
|
||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||
echo "exists=false" >> "${GITHUB_OUTPUT}"
|
||||
# New behavior for push events: enumerate commits in the push and tag each one.
|
||||
# For workflow_dispatch, retain existing single-SHA behavior.
|
||||
|
||||
# Always fetch tags once up front to improve idempotency in loops
|
||||
git fetch origin --tags --quiet || true
|
||||
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
BEFORE_SHA="${{ github.event.before }}"
|
||||
AFTER_SHA="${{ github.sha }}" # same as event.after
|
||||
|
||||
# List commits introduced by this push (old..new), oldest first for stable ordering
|
||||
commits_file="$(mktemp)"
|
||||
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
|
||||
|
||||
if [ ! -s "${commits_file}" ]; then
|
||||
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
|
||||
rm -f "${commits_file}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
|
||||
echo "Found ${commit_count} commit(s) to tag for push:"
|
||||
while IFS= read -r sha; do
|
||||
printf ' %s\n' "${sha}"
|
||||
done < "${commits_file}"
|
||||
|
||||
while IFS= read -r sha; do
|
||||
TAG_NAME="trunk/${sha}"
|
||||
COMMIT_SHA="${sha}"
|
||||
|
||||
# If tag already exists locally or remotely, skip (idempotent)
|
||||
if check_tag_exists; then
|
||||
echo "✅ Tag ${TAG_NAME} already exists - skipping"
|
||||
skipped_count=$((skipped_count + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||
|
||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||
created_count=$((created_count + 1))
|
||||
else
|
||||
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
|
||||
failed_count=$((failed_count + 1))
|
||||
fi
|
||||
done < "${commits_file}"
|
||||
|
||||
rm -f "${commits_file}"
|
||||
|
||||
if [ "${failed_count}" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
exit 0
|
||||
else
|
||||
echo "Tag creation failed after all retry attempts"
|
||||
exit 1
|
||||
# workflow_dispatch path (single SHA tagging preserved)
|
||||
|
||||
# Exit early if tag already exists
|
||||
if check_tag_exists; then
|
||||
echo "✅ Tag already exists - no action needed"
|
||||
skipped_count=1
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||
|
||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||
created_count=1
|
||||
exit 0
|
||||
else
|
||||
echo "Tag creation failed after all retry attempts"
|
||||
failed_count=1
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Tag creation summary
|
||||
if: always()
|
||||
run: |
|
||||
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
|
||||
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
|
||||
elif [ "${{ job.status }}" = "success" ]; then
|
||||
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
echo "Trigger: push on main"
|
||||
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
|
||||
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
|
||||
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
|
||||
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
|
||||
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
|
||||
else
|
||||
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
|
||||
fi
|
||||
else
|
||||
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
fi
|
||||
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
|
||||
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
|
||||
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
|
||||
else
|
||||
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
fi
|
||||
else
|
||||
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Tag details:"
|
||||
echo " Name: ${{ steps.commit.outputs.tag_name }}"
|
||||
echo " Commit: ${{ steps.commit.outputs.sha }}"
|
||||
echo " Trigger: ${{ github.event_name }}"
|
||||
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
|
||||
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
|
||||
echo ""
|
||||
echo "Tag details:"
|
||||
echo " Name: ${{ steps.commit.outputs.tag_name }}"
|
||||
echo " Commit: ${{ steps.commit.outputs.sha }}"
|
||||
echo " Trigger: ${{ github.event_name }}"
|
||||
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
|
||||
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -1138,11 +1138,8 @@ command = [
|
||||
[[linter]]
|
||||
code = 'WORKFLOWSYNC'
|
||||
include_patterns = [
|
||||
'.github/workflows/pull.yml',
|
||||
'.github/workflows/trunk.yml',
|
||||
'.github/workflows/periodic.yml',
|
||||
'.github/workflows/mac-mps.yml',
|
||||
'.github/workflows/slow.yml',
|
||||
'.github/workflows/*.yml',
|
||||
'.github/workflows/*.yaml',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
||||
@ -289,14 +289,15 @@ IF(USE_FBGEMM_GENAI)
|
||||
|
||||
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
set(fbgemm_genai_mx8mx8bf16_grouped
|
||||
set(fbgemm_genai_cuh
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
|
||||
"${FBGEMM_GENAI_SRCS}/"
|
||||
)
|
||||
|
||||
target_include_directories(fbgemm_genai PRIVATE
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||
${fbgemm_genai_mx8mx8bf16_grouped}
|
||||
${fbgemm_genai_cuh}
|
||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||
)
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include <ATen/detail/MPSHooksInterface.h>
|
||||
#include <ATen/detail/MTIAHooksInterface.h>
|
||||
#include <ATen/detail/PrivateUse1HooksInterface.h>
|
||||
#include <ATen/detail/XLAHooksInterface.h>
|
||||
#include <ATen/detail/XPUHooksInterface.h>
|
||||
#include <c10/core/QEngine.h>
|
||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
||||
@ -88,6 +89,8 @@ class TORCH_API Context {
|
||||
return at::detail::getHIPHooks();
|
||||
} else if (opt_device_type == at::kHPU) {
|
||||
return at::detail::getHPUHooks();
|
||||
} else if (opt_device_type == at::kXLA) {
|
||||
return at::detail::getXLAHooks();
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
@ -196,7 +199,7 @@ class TORCH_API Context {
|
||||
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
|
||||
}
|
||||
static bool hasXLA() {
|
||||
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
|
||||
return detail::getXLAHooks().hasXLA();
|
||||
}
|
||||
static bool hasXPU() {
|
||||
return detail::getXPUHooks().hasXPU();
|
||||
|
||||
@ -59,9 +59,7 @@ struct TORCH_API Generator {
|
||||
|
||||
explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
|
||||
: impl_(std::move(gen_impl)) {
|
||||
if (impl_.get() == nullptr) {
|
||||
throw std::runtime_error("GeneratorImpl with nullptr is not supported");
|
||||
}
|
||||
TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
|
||||
}
|
||||
|
||||
bool operator==(const Generator& rhs) const {
|
||||
|
||||
@ -111,9 +111,7 @@ class TORCH_API TensorBase {
|
||||
explicit TensorBase(
|
||||
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
|
||||
: impl_(std::move(tensor_impl)) {
|
||||
if (impl_.get() == nullptr) {
|
||||
throw std::runtime_error("TensorImpl with nullptr is not supported");
|
||||
}
|
||||
TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
|
||||
}
|
||||
TensorBase(const TensorBase&) = default;
|
||||
TensorBase(TensorBase&&) noexcept = default;
|
||||
|
||||
@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
|
||||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
|
||||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
#undef AUTOGRAD_FALLBACK
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
|
||||
|
||||
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
|
||||
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
|
||||
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
|
||||
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
|
||||
TORCH_CHECK(
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
|
||||
backendFallbackKernels_[idx].debug, ", new registration ", debug
|
||||
);
|
||||
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ",
|
||||
dispatchKey,
|
||||
"; previous registration ",
|
||||
backendFallbackKernels_[idx].debug,
|
||||
", new registration ",
|
||||
debug);
|
||||
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
|
||||
// cannot be unboxed
|
||||
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
|
||||
|
||||
@ -68,11 +68,7 @@ Symbol InternedStrings::_symbol(const std::string& s) {
|
||||
return it->second;
|
||||
|
||||
auto pos = s.find("::");
|
||||
if (pos == std::string::npos) {
|
||||
std::stringstream ss;
|
||||
ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
|
||||
Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
|
||||
|
||||
Symbol sym(sym_to_info_.size());
|
||||
@ -121,12 +117,7 @@ std::string Symbol::domainString() const {
|
||||
}
|
||||
|
||||
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
|
||||
if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
|
||||
std::ostringstream ss;
|
||||
ss << "Symbol: domain string is expected to be prefixed with '"
|
||||
<< domain_prefix() << "', e.g. 'org.pytorch.aten'";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
|
||||
std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
|
||||
return fromQualString(qualString);
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <c10/util/hash.h>
|
||||
#include <c10/util/irange.h>
|
||||
@ -412,7 +413,7 @@ size_t IValue::hash(const IValue& v) {
|
||||
case Tag::Enum:
|
||||
case Tag::Stream:
|
||||
case Tag::Uninitialized:
|
||||
throw std::runtime_error(
|
||||
TORCH_CHECK(false,
|
||||
"unhashable type: '" + v.type()->repr_str() + "'");
|
||||
}
|
||||
// the above switch should be exhaustive
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/util/TypeList.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <optional>
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/core/SymBool.h>
|
||||
@ -116,10 +117,8 @@ struct SingleElementType : public SharedType {
|
||||
|
||||
protected:
|
||||
SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
|
||||
if (!this->elem) {
|
||||
throw std::runtime_error(c10::str(
|
||||
TORCH_CHECK(this->elem, c10::str(
|
||||
"Can not create ", typeKindToString(Kind), " with None type"));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -416,16 +415,12 @@ struct TORCH_API SymbolicShape {
|
||||
}
|
||||
|
||||
ShapeSymbol operator[](size_t i) const {
|
||||
if (!dims_) {
|
||||
throw std::runtime_error("Rank isn't fixed");
|
||||
}
|
||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
||||
return (*dims_).at(i);
|
||||
}
|
||||
|
||||
ShapeSymbol at(size_t i) const {
|
||||
if (!dims_) {
|
||||
throw std::runtime_error("Rank isn't fixed");
|
||||
}
|
||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
||||
return (*dims_).at(i);
|
||||
}
|
||||
|
||||
@ -520,9 +515,7 @@ struct VaryingShape {
|
||||
}
|
||||
|
||||
const std::optional<T> &operator[](size_t i) const {
|
||||
if (!dims_) {
|
||||
throw std::runtime_error("Rank isn't fixed");
|
||||
}
|
||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
||||
return (*dims_).at(i);
|
||||
}
|
||||
|
||||
@ -957,9 +950,7 @@ struct TORCH_API DictType : public SharedType {
|
||||
|
||||
TypePtr createWithContained(
|
||||
std::vector<TypePtr> contained_types) const override {
|
||||
if (contained_types.size() != 2) {
|
||||
throw std::runtime_error("Expected 2 contained types");
|
||||
}
|
||||
TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
|
||||
return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
|
||||
}
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/env.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <array>
|
||||
@ -826,9 +827,7 @@ TupleType::TupleType(
|
||||
: NamedType(TypeKind::TupleType, std::move(name)),
|
||||
elements_(std::move(elements)),
|
||||
has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
|
||||
if (!v) {
|
||||
throw std::runtime_error("Can not create tuple with None type");
|
||||
}
|
||||
TORCH_CHECK(v, "Can not create tuple with None type");
|
||||
return v->hasFreeVariables();
|
||||
})), schema_(std::move(schema)) {
|
||||
|
||||
|
||||
@ -6,9 +6,11 @@
|
||||
#ifdef __aarch64__
|
||||
#if !defined(CPU_CAPABILITY_SVE)
|
||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
|
||||
|
||||
@ -354,9 +354,47 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
|
||||
Vectorized frac() const;
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
Vectorized<c10::BFloat16> neg() const {
|
||||
return -values;
|
||||
}
|
||||
Vectorized<c10::BFloat16> reciprocal() const {
|
||||
return 1.0f / values;
|
||||
}
|
||||
Vectorized<c10::BFloat16> operator==(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values == other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator!=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values != other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator<(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values < other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator<=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values <= other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator>(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values > other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator>=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values >= other.values;
|
||||
}
|
||||
#else
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
|
||||
@ -364,6 +402,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
|
||||
#endif
|
||||
|
||||
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||
@ -412,28 +451,52 @@ template <>
|
||||
Vectorized<c10::BFloat16> inline operator+(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x + y;
|
||||
#else
|
||||
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator-(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x - y;
|
||||
#else
|
||||
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator*(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x * y;
|
||||
#else
|
||||
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator/(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x / y;
|
||||
#else
|
||||
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
// frac. Implement this here so we can use subtraction
|
||||
@ -544,12 +607,19 @@ Vectorized<c10::BFloat16> inline fmadd(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return x * y + z;
|
||||
#else
|
||||
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
|
||||
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
|
||||
// elements, not the bottom and top half, so they don't seem
|
||||
// particularly useful here. Ideally we would include dot product in
|
||||
// the Vectorized interface...
|
||||
return a * b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -557,8 +627,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return (-x) * y + z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return -a * b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -566,8 +643,15 @@ Vectorized<c10::BFloat16> inline fmsub(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return x * y - z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return a * b - c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -575,8 +659,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return (-x) * y - z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return -a * b - c;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
|
||||
|
||||
aten/src/ATen/cpu/vec/vec128/vec128_double_neon.h (new file, 586 lines)
@ -0,0 +1,586 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <cmath>
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
template <>
|
||||
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
|
||||
|
||||
template <>
|
||||
class Vectorized<double> {
|
||||
private:
|
||||
float64x2_t values;
|
||||
|
||||
public:
|
||||
using value_type = double;
|
||||
using size_type = int;
|
||||
static constexpr size_type size() {
|
||||
return 2;
|
||||
}
|
||||
Vectorized() {
|
||||
values = vdupq_n_f64(0.0);
|
||||
}
|
||||
Vectorized(float64x2_t v) : values(v) {}
|
||||
Vectorized(double val) {
|
||||
values = vdupq_n_f64(val);
|
||||
}
|
||||
template <
|
||||
typename... Args,
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>>
|
||||
Vectorized(Args... vals) {
|
||||
__at_align__ double buffer[size()] = {vals...};
|
||||
values = vld1q_f64(buffer);
|
||||
}
|
||||
operator float64x2_t() const {
|
||||
return values;
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<double> blend(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint64x2_t maskArray = {
|
||||
(mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0,
|
||||
(mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_f64(maskArray, b.values, a.values);
|
||||
}
|
||||
static Vectorized<double> blendv(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& mask_) {
|
||||
return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values);
|
||||
}
|
||||
template <typename step_t>
|
||||
static Vectorized<double> arange(
|
||||
double base = 0.,
|
||||
step_t step = static_cast<step_t>(1)) {
|
||||
return {base, base + static_cast<double>(step)};
|
||||
}
|
||||
static inline Vectorized<double> set(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
int64_t count = size()) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 2) {
|
||||
return b;
|
||||
} else {
|
||||
float64x2_t c = {b.values[0], a.values[1]};
|
||||
return c;
|
||||
}
|
||||
}
|
||||
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
|
||||
if (count == size()) {
|
||||
return vld1q_f64(reinterpret_cast<const double*>(ptr));
|
||||
} else if (count == 1) {
|
||||
float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr));
|
||||
float64x1_t z = {0.0};
|
||||
return vcombine_f64(x, z);
|
||||
} else {
|
||||
return vdupq_n_f64(0.0);
|
||||
}
|
||||
}
|
||||
void store(void* ptr, int64_t count = size()) const {
|
||||
if (count == size()) {
|
||||
vst1q_f64(reinterpret_cast<double*>(ptr), values);
|
||||
} else if (count == 1) {
|
||||
vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values));
|
||||
}
|
||||
}
|
||||
const double& operator[](int idx) const = delete;
|
||||
double& operator[](int idx) = delete;
|
||||
int64_t zero_mask() const {
|
||||
// returns an integer mask where all zero elements are translated to 1-bit
|
||||
// and others are translated to 0-bit
|
||||
uint64x2_t cmpReg = vceqzq_f64(values);
|
||||
uint64x2_t mask = {1, 2};
|
||||
uint64x2_t res = vandq_u64(cmpReg, mask);
|
||||
return res[0] | res[1];
|
||||
}
|
||||
Vectorized<double> isnan() const {
|
||||
// NaN check
|
||||
return vreinterpretq_f64_u32(
|
||||
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values))));
|
||||
}
|
||||
bool has_inf_nan() const {
|
||||
Vectorized<double> x = vsubq_f64(values, values);
|
||||
float64x2_t r = x.isnan();
|
||||
uint64x2_t u = vreinterpretq_u64_f64(r);
|
||||
return u[0] | u[1];
|
||||
}
|
||||
Vectorized<double> map(double (*f)(double)) const {
|
||||
float64x2_t result;
|
||||
result[0] = f(values[0]);
|
||||
result[1] = f(values[1]);
|
||||
return result;
|
||||
}
|
||||
Vectorized<double> map2(
|
||||
const Vectorized<double>& second,
|
||||
double (*const f)(double, double)) const {
|
||||
float64x2_t result;
|
||||
result[0] = f(values[0], second.values[0]);
|
||||
result[1] = f(values[1], second.values[1]);
|
||||
return result;
|
||||
}
|
||||
Vectorized<double> abs() const {
|
||||
return vabsq_f64(values);
|
||||
}
|
||||
Vectorized<double> angle() const {
|
||||
auto zero = Vectorized<double>(0.0);
|
||||
auto pi = Vectorized<double>(c10::pi<double>);
|
||||
auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values)));
|
||||
return blendv(tmp, *this, isnan());
|
||||
}
|
||||
Vectorized<double> real() const {
|
||||
return *this;
|
||||
}
|
||||
Vectorized<double> imag() const {
|
||||
return Vectorized<double>(0.0);
|
||||
}
|
||||
Vectorized<double> conj() const {
|
||||
return *this;
|
||||
}
|
||||
Vectorized<double> acos() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos));
|
||||
}
|
||||
Vectorized<double> acosh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh));
|
||||
}
|
||||
Vectorized<double> asin() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin));
|
||||
}
|
||||
Vectorized<double> asinh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh));
|
||||
}
|
||||
Vectorized<double> atan() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan));
|
||||
}
|
||||
Vectorized<double> atanh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh));
|
||||
}
|
||||
Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_atan2d2_u10(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> copysign(const Vectorized<double>& sign) const {
|
||||
USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_copysignd2(values, sign)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_sign[size()];
|
||||
store(tmp);
|
||||
sign.store(tmp_sign);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})} Vectorized<double> erf() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf));
|
||||
}
|
||||
Vectorized<double> erfc() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc));
|
||||
}
|
||||
Vectorized<double> exp() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp));
|
||||
}
|
||||
Vectorized<double> exp2() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2));
|
||||
}
|
||||
Vectorized<double> expm1() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1));
|
||||
}
|
||||
Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_fmodd2(values, q)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_q[size()];
|
||||
store(tmp);
|
||||
q.store(tmp_q);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})
}
Vectorized<double> hypot(const Vectorized<double>& b) const {
|
||||
USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_hypotd2_u05(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})
}
Vectorized<double> i0() const {
|
||||
return map(calc_i0);
|
||||
}
|
||||
Vectorized<double> nextafter(const Vectorized<double>& b) const {
USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_nextafterd2(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); ++i) {
|
||||
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})
}
Vectorized<double> log() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_logd2_u10(values)), map(std::log));
|
||||
}
|
||||
Vectorized<double> log2() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2));
|
||||
}
|
||||
Vectorized<double> log10() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10));
|
||||
}
|
||||
Vectorized<double> log1p() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p));
|
||||
}
|
||||
Vectorized<double> frac() const;
|
||||
Vectorized<double> sin() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin));
|
||||
}
|
||||
Vectorized<double> sinh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh));
|
||||
}
|
||||
Vectorized<double> cos() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos));
|
||||
}
|
||||
Vectorized<double> cosh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh));
|
||||
}
|
||||
Vectorized<double> pow(const Vectorized<double>& b) const {
USE_SLEEF(
|
||||
{ return Vectorized<double>(Sleef_powd2_u10(values, b)); },
|
||||
{
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_b[size()];
|
||||
store(tmp);
|
||||
b.store(tmp_b);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = std::pow(tmp[i], tmp_b[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
})
}
// Comparison using the _CMP_**_OQ predicate.
|
||||
// `O`: get false if an operand is NaN
|
||||
// `Q`: do not raise if an operand is NaN
|
||||
Vectorized<double> tan() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan));
|
||||
}
|
||||
Vectorized<double> tanh() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh));
|
||||
}
|
||||
Vectorized<double> lgamma() const {
|
||||
return USE_SLEEF(
|
||||
Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma));
|
||||
}
|
||||
Vectorized<double> erfinv() const {
|
||||
return map(calc_erfinv);
|
||||
}
|
||||
Vectorized<double> exp_u20() const {
|
||||
return exp();
|
||||
}
|
||||
Vectorized<double> fexp_u20() const {
|
||||
return exp();
|
||||
}
|
||||
Vectorized<double> i0e() const {
|
||||
return map(calc_i0e);
|
||||
}
|
||||
Vectorized<double> digamma() const {
|
||||
return map(calc_digamma);
|
||||
}
|
||||
Vectorized<double> igamma(const Vectorized<double>& x) const {
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_x[size()];
|
||||
store(tmp);
|
||||
x.store(tmp_x);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
}
|
||||
Vectorized<double> igammac(const Vectorized<double>& x) const {
|
||||
__at_align__ double tmp[size()];
|
||||
__at_align__ double tmp_x[size()];
|
||||
store(tmp);
|
||||
x.store(tmp_x);
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
}
|
||||
Vectorized<double> ceil() const {
|
||||
return vrndpq_f64(values);
|
||||
}
|
||||
Vectorized<double> floor() const {
|
||||
return vrndmq_f64(values);
|
||||
}
|
||||
Vectorized<double> neg() const {
|
||||
return vnegq_f64(values);
|
||||
}
|
||||
Vectorized<double> round() const {
|
||||
return vrndiq_f64(values);
|
||||
}
|
||||
Vectorized<double> trunc() const {
|
||||
return vrndq_f64(values);
|
||||
}
|
||||
Vectorized<double> sqrt() const {
|
||||
return vsqrtq_f64(values);
|
||||
}
|
||||
Vectorized<double> reciprocal() const {
|
||||
return vdivq_f64(vdupq_n_f64(1.0), values);
|
||||
}
|
||||
Vectorized<double> rsqrt() const {
|
||||
return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values));
|
||||
}
|
||||
double reduce_add() const {
|
||||
return vaddvq_f64(values);
|
||||
}
|
||||
double reduce_max() const {
|
||||
return vmaxvq_f64(values);
|
||||
}
|
||||
Vectorized<double> operator==(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vceqq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator!=(const Vectorized<double>& other) const {
|
||||
float64x2_t r0 = vreinterpretq_f64_u32(
|
||||
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values))));
|
||||
return Vectorized<double>(r0);
|
||||
}
|
||||
|
||||
Vectorized<double> operator<(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcltq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator<=(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcleq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator>(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcgtq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> operator>=(const Vectorized<double>& other) const {
|
||||
return Vectorized<double>(
|
||||
vreinterpretq_f64_u64(vcgeq_f64(values, other.values)));
|
||||
}
|
||||
|
||||
Vectorized<double> eq(const Vectorized<double>& other) const;
|
||||
Vectorized<double> ne(const Vectorized<double>& other) const;
|
||||
Vectorized<double> gt(const Vectorized<double>& other) const;
|
||||
Vectorized<double> ge(const Vectorized<double>& other) const;
|
||||
Vectorized<double> lt(const Vectorized<double>& other) const;
|
||||
Vectorized<double> le(const Vectorized<double>& other) const;
|
||||
};
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator+(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vaddq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator-(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vsubq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator*(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vmulq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator/(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vdivq_f64(a, b);
|
||||
}
|
||||
|
||||
// frac. Implement this here so we can use subtraction
|
||||
Vectorized<double> inline Vectorized<double>::frac() const {
|
||||
return *this - this->trunc();
|
||||
}
|
||||
|
||||
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
|
||||
// either input is a NaN.
|
||||
template <>
|
||||
Vectorized<double> inline maximum(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vmaxq_f64(a, b);
|
||||
}
|
||||
|
||||
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
|
||||
// either input is a NaN.
|
||||
template <>
|
||||
Vectorized<double> inline minimum(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vminq_f64(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline clamp(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& min,
|
||||
const Vectorized<double>& max) {
|
||||
return vminq_f64(max, vmaxq_f64(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline clamp_max(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& max) {
|
||||
return vminq_f64(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline clamp_min(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& min) {
|
||||
return vmaxq_f64(min, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator&(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vreinterpretq_f64_u64(
|
||||
vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator|(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vreinterpretq_f64_u64(
|
||||
vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline operator^(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
return vreinterpretq_f64_u64(
|
||||
veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::eq(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this == other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::ne(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this != other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::gt(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this > other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::ge(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this >= other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::lt(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this < other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
inline Vectorized<double> Vectorized<double>::le(
|
||||
const Vectorized<double>& other) const {
|
||||
return (*this <= other) & Vectorized<double>(1.0);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fmadd(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmaq_f64(c, a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fnmadd(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmsq_f64(c, a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fmsub(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmaq_f64(vnegq_f64(c), a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<double> inline fnmsub(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& c) {
|
||||
return vfmsq_f64(vnegq_f64(c), a, b);
|
||||
}
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
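For orientation, here is a minimal usage sketch of the NEON Vectorized<double> specialization above (loadu/store plus lane-wise arithmetic and a transcendental). It is illustrative only; it assumes an aarch64 build where this header is selected by the CPU_CAPABILITY dispatch, and the buffer and function names are made up for the example.

```cpp
// Sketch: exercising Vectorized<double> (2 double lanes per NEON register).
#include <ATen/cpu/vec/vec.h>

void vectorized_double_demo() {
  using Vec = at::vec::Vectorized<double>;
  alignas(16) double in[Vec::size()] = {0.25, 4.0};  // Vec::size() == 2
  alignas(16) double out[Vec::size()];

  Vec a = Vec::loadu(in);           // load both lanes
  Vec b = a * Vec(2.0) + Vec(1.0);  // lane-wise multiply and add
  Vec c = b.sqrt().log1p();         // lane-wise sqrt, then log1p
  c.store(out);                     // write the two results back
}
```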
378 aten/src/ATen/cpu/vec/vec128/vec128_uint_aarch64.h Normal file
@ -0,0 +1,378 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
|
||||
template <> \
|
||||
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
|
||||
\
|
||||
template <> \
|
||||
class Vectorized<uint##bit##_t> { \
|
||||
using neon_type = uint##bit##x##vl##_t; \
|
||||
\
|
||||
private: \
|
||||
neon_type values; \
|
||||
\
|
||||
public: \
|
||||
using value_type = uint##bit##_t; \
|
||||
using size_type = int; \
|
||||
static constexpr size_type size() { \
|
||||
return vl; \
|
||||
} \
|
||||
Vectorized() { \
|
||||
values = vdupq_n_u##bit(0); \
|
||||
} \
|
||||
Vectorized(neon_type v) : values(v) {} \
|
||||
Vectorized(uint##bit##_t val); \
|
||||
template < \
|
||||
typename... Args, \
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
|
||||
Vectorized(Args... vals) { \
|
||||
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
|
||||
values = vld1q_u##bit(buffer); \
|
||||
} \
|
||||
operator neon_type() const { \
|
||||
return values; \
|
||||
} \
|
||||
static Vectorized<uint##bit##_t> loadu( \
|
||||
const void* ptr, \
|
||||
uint64_t count = size()); \
|
||||
void store(void* ptr, uint64_t count = size()) const; \
|
||||
template <uint64_t mask> \
|
||||
static Vectorized<uint##bit##_t> blend( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b); \
|
||||
static Vectorized<uint##bit##_t> blendv( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b, \
|
||||
const Vectorized<uint##bit##_t>& mask_) { \
|
||||
return vbslq_u##bit(mask_.values, b, a); \
|
||||
} \
|
||||
template <typename step_t> \
|
||||
static Vectorized<uint##bit##_t> arange( \
|
||||
value_type base = 0, \
|
||||
step_t step = static_cast<step_t>(1)); \
|
||||
static Vectorized<uint##bit##_t> set( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b, \
|
||||
uint64_t count = size()); \
|
||||
const uint##bit##_t& operator[](uint idx) const = delete; \
|
||||
uint##bit##_t& operator[](uint idx) = delete; \
|
||||
Vectorized<uint##bit##_t> abs() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> real() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> imag() const { \
|
||||
return vdupq_n_u##bit(0); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> conj() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> neg() const { \
|
||||
return vreinterpretq_u##bit##_s##bit( \
|
||||
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
|
||||
} \
|
||||
uint##bit##_t reduce_add() const { \
|
||||
return vaddvq_u##bit(values); \
|
||||
} \
|
||||
uint##bit##_t reduce_max() const; \
|
||||
Vectorized<uint##bit##_t> operator==( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator!=( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> operator<( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator<=( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator>( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator>=( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> eq( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> ne( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> gt( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> ge( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> lt( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> le( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
}; \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator+( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vaddq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator-( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vsubq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator&( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vandq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator|( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vorrq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator^( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return veorq_u##bit(a, b); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this == other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this != other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this > other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this < other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
|
||||
}
|
||||
|
||||
VEC_UINT_NEON_TEMPLATE(16, 8)
|
||||
|
||||
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
|
||||
return vmaxvq_u8(values);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator*(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vmulq_u8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
|
||||
return vmvnq_u8(a);
|
||||
}
|
||||
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
|
||||
const Vectorized<uint8_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline minimum(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vminq_u8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline maximum(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vmaxq_u8(a, b);
|
||||
}
|
||||
|
||||
template <uint64_t mask>
|
||||
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
// Build an array of per-lane flags: element i is 0xFF if bit i of 'mask'
// is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
(mask & 1LL) ? 0xFF : 0,
|
||||
(mask & 2LL) ? 0xFF : 0,
|
||||
(mask & 4LL) ? 0xFF : 0,
|
||||
(mask & 8LL) ? 0xFF : 0,
|
||||
(mask & 16LL) ? 0xFF : 0,
|
||||
(mask & 32LL) ? 0xFF : 0,
|
||||
(mask & 64LL) ? 0xFF : 0,
|
||||
(mask & 128LL) ? 0xFF : 0,
|
||||
(mask & 256LL) ? 0xFF : 0,
|
||||
(mask & 512LL) ? 0xFF : 0,
|
||||
(mask & 1024LL) ? 0xFF : 0,
|
||||
(mask & 2048LL) ? 0xFF : 0,
|
||||
(mask & 4096LL) ? 0xFF : 0,
|
||||
(mask & 8192LL) ? 0xFF : 0,
|
||||
(mask & 16384LL) ? 0xFF : 0,
|
||||
(mask & 32768LL) ? 0xFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_u8(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
#define VEC_UINT_NEON_OPS(vl, bit) \
|
||||
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
|
||||
values = vdupq_n_u##bit(val); \
|
||||
} \
|
||||
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
|
||||
const void* ptr, uint64_t count) { \
|
||||
if (count == size()) { \
|
||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
|
||||
} else { \
|
||||
__at_align__ uint##bit##_t tmp_values[size()]; \
|
||||
for (const auto i : c10::irange(size())) { \
|
||||
tmp_values[i] = 0; \
|
||||
} \
|
||||
std::memcpy( \
|
||||
tmp_values, \
|
||||
reinterpret_cast<const uint##bit##_t*>(ptr), \
|
||||
count * sizeof(uint##bit##_t)); \
|
||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
|
||||
} \
|
||||
} \
|
||||
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
|
||||
const { \
|
||||
if (count == size()) { \
|
||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
|
||||
} else { \
|
||||
uint##bit##_t tmp_values[size()]; \
|
||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
|
||||
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
|
||||
} \
|
||||
}
|
||||
|
||||
VEC_UINT_NEON_OPS(16, 8)
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
|
||||
uint8_t base,
|
||||
step_t step) {
|
||||
const Vectorized<uint8_t> base_vec(base);
|
||||
const Vectorized<uint8_t> step_vec(step);
|
||||
const uint8x16_t step_sizes = {
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
||||
return vmlaq_u8(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator>>(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t x = a;
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
uint8x16_t z = vminq_u8(b, bound);
|
||||
return x >> z;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator<<(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
uint8x16_t z = vminq_u8(b, bound);
|
||||
return vshlq_u8(a, vreinterpretq_s8_u8(z));
|
||||
}
|
||||
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b,
|
||||
uint64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 16) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of per-lane flags: element i is 0xFF if i < count,
// 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
|
||||
0};
|
||||
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_u8(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator/(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t x = a;
|
||||
uint8x16_t y = b;
|
||||
return x / y;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& min,
|
||||
const Vectorized<uint8_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp_max(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp_min(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
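A similar sketch for the new uint8 specialization, touching the arange, clamp, blend, and store paths defined in this file. Again illustrative only, assuming the same aarch64/NEON build; the buffer and function names are placeholders.

```cpp
// Sketch: exercising Vectorized<uint8_t> (16 lanes per NEON register).
#include <ATen/cpu/vec/vec.h>

void vectorized_uint8_demo() {
  using Vec = at::vec::Vectorized<uint8_t>;
  uint8_t buf[Vec::size()] = {0};

  Vec ramp = Vec::arange(/*base=*/0, /*step=*/1);       // 0, 1, ..., 15
  Vec clamped = at::vec::clamp(ramp, Vec(2), Vec(10));  // clamp to [2, 10]
  // blend<mask>: lane i comes from the second operand when bit i of mask is set
  Vec mixed = Vec::blend<0x00FF>(clamped, Vec(42));
  mixed.store(buf, /*count=*/Vec::size());
}
```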
@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||
|
||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||
at::vec::Vectorized<uint8_t> src) {
|
||||
auto u8x8 = vld1_u8(src.operator const uint8_t*());
|
||||
auto u8x8 = vget_low_u8(src);
|
||||
auto u16x8 = vmovl_u8(u8x8);
|
||||
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
|
||||
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
||||
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
|
||||
|
||||
Vectorized<float> inline convert_int8_half_register_to_float(
|
||||
at::vec::Vectorized<uint8_t> src) {
|
||||
auto u8x8 = vld1_u8(src.operator const uint8_t*());
|
||||
auto u8x8 = vget_low_u8(src);
|
||||
auto u16x8 = vmovl_u8(u8x8);
|
||||
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
||||
|
||||
|
||||
192 aten/src/ATen/cuda/CUDAGreenContext.cpp Normal file
@ -0,0 +1,192 @@
|
||||
#include <ATen/cuda/CUDAGreenContext.h>
|
||||
|
||||
namespace at::cuda {
|
||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int driver_version;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||
TORCH_CHECK(
|
||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||
CUcontext pctx = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
TORCH_WARN(
|
||||
"Attempted to create a green context but"
|
||||
" there was no primary context! Creating a primary context...");
|
||||
|
||||
cudaFree(0);
|
||||
}
|
||||
|
||||
CUdevice device;
|
||||
device_id_ = device_id;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||
|
||||
// Get device resources
|
||||
CUdevResource device_resource;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||
|
||||
// Split resources
|
||||
std::vector<CUdevResource> result(1);
|
||||
auto result_data = result.data();
|
||||
unsigned int nb_groups = 1;
|
||||
CUdevResource remaining;
|
||||
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||
result_data,
|
||||
&nb_groups,
|
||||
&device_resource,
|
||||
&remaining,
|
||||
0, // default flags
|
||||
num_sms));
|
||||
|
||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||
|
||||
// Generate resource descriptor
|
||||
CUdevResourceDesc desc;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||
&desc, result_data, 1));
|
||||
|
||||
// Create green context
|
||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||
|
||||
// Convert to regular context
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::unique_ptr<GreenContext> GreenContext::create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
if (!device_id.has_value()) {
|
||||
device_id = at::cuda::current_device();
|
||||
}
|
||||
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Implement move operations
|
||||
GreenContext::GreenContext(GreenContext&& other) noexcept {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
if (this != &other) {
|
||||
// Clean up current resources
|
||||
if (green_ctx_) {
|
||||
CUcontext current = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
||||
if (current == context_) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"attempting to overwrite current green ctx "
|
||||
"when it is active!");
|
||||
}
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||
}
|
||||
|
||||
// Take ownership of other's resources
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
||||
}
|
||||
return *this;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
GreenContext::~GreenContext() noexcept {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext GreenContext::getContext() const {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
return context_;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx GreenContext::getGreenContext() const {
|
||||
return green_ctx_;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void GreenContext::setContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
||||
parent_stream_ = current_stream.stream();
|
||||
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(current_stream);
|
||||
|
||||
CUcontext current = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
||||
if (!current) {
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
|
||||
} else {
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
|
||||
}
|
||||
// currently hardcodes the new green context to use the default stream
|
||||
// TODO(eqy): consider creating a new stream if e.g., it allows interop
|
||||
// with CUDA Graph captures etc.
|
||||
auto default_stream = c10::cuda::getDefaultCUDAStream();
|
||||
ev.block(default_stream);
|
||||
c10::cuda::setCurrentCUDAStream(default_stream);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
void GreenContext::popContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
// see above note about stream being hardcoded to the default stream
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(c10::cuda::getCurrentCUDAStream());
|
||||
CUcontext popped;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
popped == context_, "expected popped context to be the current ctx");
|
||||
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
} // namespace at::cuda
|
||||
53 aten/src/ATen/cuda/CUDAGreenContext.h Normal file
@ -0,0 +1,53 @@
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define CUDA_HAS_GREEN_CONTEXT 1
|
||||
#else
|
||||
#define CUDA_HAS_GREEN_CONTEXT 0
|
||||
#endif
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
class TORCH_CUDA_CPP_API GreenContext {
|
||||
public:
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
|
||||
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
||||
|
||||
// Delete copy constructor and assignment
|
||||
GreenContext(const GreenContext&) = delete;
|
||||
GreenContext& operator=(const GreenContext&) = delete;
|
||||
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
~GreenContext() noexcept;
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext getContext() const;
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx getGreenContext() const;
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void setContext();
|
||||
|
||||
void popContext();
|
||||
|
||||
private:
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int32_t device_id_ = -1;
|
||||
CUgreenCtx green_ctx_ = nullptr;
|
||||
CUcontext context_ = nullptr;
|
||||
cudaStream_t parent_stream_ = nullptr;
|
||||
#endif
|
||||
};
|
||||
} // namespace at::cuda
|
||||
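To make the intent of the new class concrete, here is a hedged usage sketch based only on the API declared above: create a green context limited to a number of SMs, make it current around some work, then pop back. It assumes a CUDA 12.8+ build with driver API support; otherwise these calls fail with the TORCH_CHECK messages shown above.

```cpp
// Sketch: running work on an SM-limited green context.
#include <ATen/cuda/CUDAGreenContext.h>

void run_on_sm_partition() {
  // Request a partition of 16 SMs on the current device.
  auto gctx = at::cuda::GreenContext::create(/*num_sms=*/16, /*device_id=*/std::nullopt);

  gctx->setContext();   // make the green context current; work targets its default stream
  // ... launch kernels here ...
  gctx->popContext();   // restore the parent context and stream
}
```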
@ -70,11 +70,7 @@
|
||||
#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
|
||||
#endif
|
||||
|
||||
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
namespace at_cuda_detail {
|
||||
#endif
|
||||
#if defined(USE_ROCM)
|
||||
|
||||
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
|
||||
|
||||
@ -96,10 +92,6 @@ template <>
|
||||
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
|
||||
ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
} // namespace at_cuda_detail
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
@ -121,7 +113,7 @@ struct cuda_type<c10::Half> {
|
||||
using type = __half;
|
||||
};
|
||||
|
||||
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
|
||||
#if !defined(USE_ROCM)
|
||||
|
||||
template<>
|
||||
struct cuda_type<c10::BFloat16> {
|
||||
@ -203,36 +195,6 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera
|
||||
*out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
|
||||
}
|
||||
|
||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
|
||||
template<typename ValueT, typename InputIteratorT>
|
||||
struct chained_iterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using value_type = ValueT;
|
||||
using pointer = ValueT*;
|
||||
using reference = ValueT&;
|
||||
|
||||
InputIteratorT iter;
|
||||
ValueT *first;
|
||||
difference_type offset = 0;
|
||||
|
||||
__device__ ValueT operator[](difference_type i) {
|
||||
i += offset;
|
||||
if (i == 0) {
|
||||
return *first;
|
||||
} else {
|
||||
return ValueT(iter[i - 1]);
|
||||
}
|
||||
}
|
||||
__device__ chained_iterator operator+(difference_type i) {
|
||||
return chained_iterator{iter, first, i};
|
||||
}
|
||||
__device__ ValueT operator*() {
|
||||
return (*this)[0];
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
|
||||
// so split at int_max/2
|
||||
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
|
||||
@ -277,25 +239,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
first_elem_ptr,
|
||||
scan_op);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
|
||||
using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
|
||||
using tuple = typename ArgIndexInputIterator::value_type;
|
||||
auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
|
||||
if (x.key == 0) {
|
||||
return *first_elem_ptr;
|
||||
} else {
|
||||
return x.value;
|
||||
}
|
||||
};
|
||||
auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)(
|
||||
ArgIndexInputIterator(input + i), input_iter_transform);
|
||||
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
|
||||
input_,
|
||||
output + i,
|
||||
scan_op,
|
||||
size_cub,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#else
|
||||
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
|
||||
input + i + 1,
|
||||
output + i,
|
||||
@ -303,7 +246,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
|
||||
size_cub,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -555,16 +497,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
first_elem_ptr,
|
||||
scan_op);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
#if !CUB_SUPPORTS_FUTURE_VALUE()
|
||||
auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
|
||||
input + i, first_elem_ptr};
|
||||
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
|
||||
input_,
|
||||
output + i,
|
||||
scan_op,
|
||||
size_cub,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#else
|
||||
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
|
||||
input + i,
|
||||
output + i,
|
||||
@ -572,7 +504,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
|
||||
size_cub,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -10,14 +10,6 @@
|
||||
#define CUB_VERSION 200001
|
||||
#endif
|
||||
|
||||
// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
|
||||
// https://github.com/NVIDIA/cub/pull/306
|
||||
#if CUB_VERSION >= 101300
|
||||
#define CUB_SUPPORTS_NV_BFLOAT16() true
|
||||
#else
|
||||
#define CUB_SUPPORTS_NV_BFLOAT16() false
|
||||
#endif
|
||||
|
||||
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
|
||||
// https://github.com/NVIDIA/cub/pull/326
|
||||
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
|
||||
@ -28,14 +20,6 @@
|
||||
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
|
||||
#endif
|
||||
|
||||
// cub support for cub::FutureValue is added to cub 1.15 in:
|
||||
// https://github.com/NVIDIA/cub/pull/305
|
||||
#if CUB_VERSION >= 101500
|
||||
#define CUB_SUPPORTS_FUTURE_VALUE() true
|
||||
#else
|
||||
#define CUB_SUPPORTS_FUTURE_VALUE() false
|
||||
#endif
|
||||
|
||||
// There were many bc-breaking changes in major version release of CCCL v3.0.0
|
||||
// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
|
||||
#if CUB_VERSION >= 200800
|
||||
|
||||
23 aten/src/ATen/detail/XLAHooksInterface.cpp Normal file
@ -0,0 +1,23 @@
|
||||
#include <ATen/detail/XLAHooksInterface.h>
|
||||
|
||||
namespace at {
|
||||
namespace detail {
|
||||
|
||||
const XLAHooksInterface& getXLAHooks() {
|
||||
auto create_impl = [] {
|
||||
// Create XLA hooks using the registry
|
||||
auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{});
|
||||
if (hooks) {
|
||||
return hooks;
|
||||
}
|
||||
// If hooks creation fails, fall back to default implementation
|
||||
return std::make_unique<XLAHooksInterface>();
|
||||
};
|
||||
static auto hooks = create_impl();
|
||||
return *hooks;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs)
|
||||
|
||||
} // namespace at
|
||||
79 aten/src/ATen/detail/XLAHooksInterface.h Normal file
@ -0,0 +1,79 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Registry.h>
|
||||
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
|
||||
|
||||
namespace at {
|
||||
|
||||
constexpr const char* XLA_HELP =
|
||||
"This error has occurred because you are trying "
|
||||
"to use some XLA functionality, but the XLA library has not been "
|
||||
"loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
|
||||
|
||||
struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
|
||||
~XLAHooksInterface() override = default;
|
||||
|
||||
void init() const override {
|
||||
TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
virtual bool hasXLA() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual std::string showConfig() const {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Cannot query detailed XLA version without torch_xla library. ",
|
||||
XLA_HELP);
|
||||
}
|
||||
|
||||
const Generator& getDefaultGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(
|
||||
false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
Generator getNewGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
virtual DeviceIndex getCurrentDevice() const override {
|
||||
TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
Device getDeviceFromPtr(void* /*data*/) const override {
|
||||
TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
Allocator* getPinnedMemoryAllocator() const override {
|
||||
TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
bool isPinnedPtr(const void* data) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool hasPrimaryContext(DeviceIndex device_index) const override {
|
||||
TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct TORCH_API XLAHooksArgs {};
|
||||
|
||||
TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
|
||||
#define REGISTER_XLA_HOOKS(clsname) \
|
||||
C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
|
||||
|
||||
namespace detail {
|
||||
TORCH_API const XLAHooksInterface& getXLAHooks();
|
||||
} // namespace detail
|
||||
} // namespace at
|
||||
C10_DIAGNOSTIC_POP()
|
||||
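A short consumer-side sketch of the hooks interface above. Without torch_xla loaded, getXLAHooks() returns the stub, so hasXLA() reports false; a real backend would register a subclass through REGISTER_XLA_HOOKS under the name the registry looks up. The function below is only an illustration.

```cpp
// Sketch: querying XLA availability through the hooks registry.
#include <ATen/detail/XLAHooksInterface.h>

bool xla_available() {
  // Returns the registered torch_xla hooks if present, otherwise the default stub.
  const at::XLAHooksInterface& hooks = at::detail::getXLAHooks();
  return hooks.hasXLA();  // false for the stub implementation
}
```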
@ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
|
||||
"pixel_shuffle expects a positive upscale_factor, but got ",
|
||||
upscale_factor);
|
||||
int64_t c = self.size(-3);
|
||||
TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
|
||||
"upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
|
||||
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
|
||||
TORCH_CHECK(c % upscale_factor_squared == 0,
|
||||
"pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
|
||||
|
||||
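The pixel_shuffle hunk above guards the upscale_factor squared computation by dividing the type's maximum by the factor before multiplying, which is the standard way to detect int64 overflow without triggering it. A standalone version of that check, for illustration only:

```cpp
// Sketch of the overflow guard: true iff n * n fits in int64_t (for n > 0).
#include <cstdint>
#include <limits>

bool square_fits_int64(std::int64_t n) {
  return n > 0 && n <= std::numeric_limits<std::int64_t>::max() / n;
}
```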
@ -259,11 +259,20 @@ inline void winograd_f2k3_input_transform_inplace__rvv(
|
||||
const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
|
||||
const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
|
||||
const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);
|
||||
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
|
||||
/* GCC 14.2 (RISC-V RVV) ICE workaround:
|
||||
* Avoid single-statement read-modify-write on MEM_REF like:
|
||||
* *input_tile_val =
|
||||
* __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
|
||||
* This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin)
|
||||
* with -march=rv64gcv. Use a temporary then write back.
|
||||
* Do NOT refactor into the single-statement form. Clang is unaffected.
|
||||
*/
|
||||
vfloat32m1x4_t tmp_input_tile_val = *input_tile_val;
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0);
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1);
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2);
|
||||
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3);
|
||||
*input_tile_val = tmp_input_tile_val;
|
||||
}
|
||||
|
||||
inline void winograd_f2k3_output_transform_inplace__rvv(
|
||||
@ -277,9 +286,15 @@ inline void winograd_f2k3_output_transform_inplace__rvv(
|
||||
const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
|
||||
const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
|
||||
const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);
|
||||
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
|
||||
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
|
||||
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
|
||||
* Keep the temporary + write-back pattern to avoid ICE.
|
||||
* Do NOT rewrite into:
|
||||
* *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
|
||||
*/
|
||||
vfloat32m1x4_t tmp_output_tile_val = *input_tile_val;
|
||||
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0);
|
||||
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1);
|
||||
*input_tile_val = tmp_output_tile_val;
|
||||
}
|
||||
|
||||
inline vfloat32m1_t
|
||||
@ -300,11 +315,17 @@ inline void winograd_f2k3_kernel_transform__rvv(
|
||||
const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
|
||||
const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
|
||||
vfloat32m1_t half_g0_plus_g2 = __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);
|
||||
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
|
||||
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
|
||||
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
|
||||
* Keep the temporary + write-back pattern to avoid ICE.
|
||||
* Do NOT rewrite into:
|
||||
* *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val);
|
||||
*/
|
||||
vfloat32m1x4_t tmp_transform = *transform;
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0);
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
|
||||
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2);
|
||||
*transform = tmp_transform;
|
||||
}
|
||||
|
||||
inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {
|
||||
|
||||
@ -272,28 +272,110 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
|
||||
}
|
||||
}
|
||||
|
||||
static bool getDisableAddmmCudaLt() {
|
||||
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
|
||||
if (env_value == "1") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
/*
|
||||
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
|
||||
* Additionally, for ROCm we check whether the architecture supports hipBLASLt.
|
||||
*/
|
||||
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
|
||||
// When hipBLASLt is not supported on the architecture, return true
|
||||
#ifdef USE_ROCM
|
||||
static const std::vector<std::string> archs = {
|
||||
"gfx90a", "gfx942",
|
||||
#if ROCM_VERSION >= 60300
|
||||
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
|
||||
#endif
|
||||
#if ROCM_VERSION >= 70000
|
||||
"gfx950", "gfx1150", "gfx1151"
|
||||
#endif
|
||||
};
|
||||
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
|
||||
if (!is_hipblas_lt_arch_supported) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Check whether it is disabled in the env
|
||||
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
|
||||
if (is_addmm_cuda_lt_disabled == "1") {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool isSupportedHipLtROCmArch(int index) {
|
||||
static const std::vector<std::string> archs = {
|
||||
"gfx90a", "gfx942",
|
||||
#if ROCM_VERSION >= 60300
|
||||
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
|
||||
#endif
|
||||
#if ROCM_VERSION >= 70000
|
||||
"gfx950", "gfx1150", "gfx1151"
|
||||
#endif
|
||||
};
|
||||
return at::detail::getCUDAHooks().isGPUArch(archs, index);
|
||||
/*
|
||||
* Check whether, for the given inputs, we want to enable the Lt interface
|
||||
*/
|
||||
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
|
||||
// Implies a 2D bias, which we currently do not send through the Lt path.
|
||||
// TODO: this check is done pre col-major input preparation,
|
||||
// so this condition can be relaxed in cases where a col-major
|
||||
// copy of result is needed.
|
||||
if (result.is_same(self)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
|
||||
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
||||
const auto args = cublasCommonArgs(mat1, mat2, result);
|
||||
if (args.transa == 't' && args.transb == 't') {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
const auto mat1_sizes = mat1.sizes();
|
||||
const auto mat2_sizes = mat2.sizes();
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
const auto scalar_type = mat1.scalar_type();
|
||||
return (beta.toComplexDouble() == 1.0
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
// is to use lt interface only when self is bias.
|
||||
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
|
||||
&& result.dim() == 2 && result.is_contiguous()
|
||||
&& ( // some dtype restrictions
|
||||
#ifndef USE_ROCM
|
||||
scalar_type == at::ScalarType::Double ||
|
||||
#endif
|
||||
scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16
|
||||
)
|
||||
&& ( // some shape/stride restrictions
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
|
||||
// their row-/col-majorness.
|
||||
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
|
||||
// The last condition is to skip 16b transA and non-trans-B having
|
||||
// leading dim >> rows when they are sliced from a large tensor
|
||||
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
|
||||
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
|
||||
// Related to avoiding the leading stride >> leading dim problematic case
|
||||
// with 16b dtypes described above. For such dtypes we only allow inputs
|
||||
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
|
||||
// In that case the leading stride will be equal to the outer dim len.
|
||||
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
|
||||
// does not modify inputs as long as there is a stride of length 1
|
||||
// and the leading stride is at least max(1, other dim length), so we might
|
||||
// end up with contiguous cols but not rows (i.e. holes between different rows)
|
||||
// and vice versa.
|
||||
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32
|
||||
&& (
|
||||
// filter by dtype
|
||||
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
|
||||
// check mat1/mat2 is row-/col-major
|
||||
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
|
||||
)
|
||||
#endif
|
||||
)
|
||||
);
|
||||
#endif
|
||||
|
||||
// no compliance by default
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename scalar_t>
|
||||
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
|
||||
@ -335,7 +417,70 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename res_scalar_t = scalar_t>
|
||||
bool launchGemmAndBiasCublasLt(
|
||||
// args contains result which is modified
|
||||
cublasCommonArgs& args,
|
||||
const Tensor& self,
|
||||
const Scalar& alpha,
|
||||
Activation activation = Activation::None
|
||||
) {
|
||||
const auto* self_ptr = self.const_data_ptr<scalar_t>();
|
||||
|
||||
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
// TODO: maybe also return some success state?
|
||||
launchTunableGemmAndBias<scalar_t>(
|
||||
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self_ptr,
|
||||
args.result->data_ptr<res_scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
}
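// Note: the return value reports whether cublasLt accepted the problem; when
// this helper returns false, addmm_out_cuda_impl below re-invokes itself with
// the Lt path forced off and falls back to the plain GEMM helper.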
|
||||
|
||||
template <typename scalar_t, typename res_scalar_t = scalar_t>
|
||||
bool launchGemmCublas(
|
||||
// args contains result which is modified
|
||||
cublasCommonArgs& args,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta
|
||||
) {
|
||||
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
beta.to<at::opmath_type<scalar_t>>(),
|
||||
args.result->data_ptr<res_scalar_t>(),
|
||||
args.result_ld
|
||||
);
|
||||
return true; // success!
|
||||
}
|
||||
|
||||
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
|
||||
// Shape checks {
|
||||
// Make sure to keep addmm_cuda below in sync with this code; it
|
||||
// preflights a check to try to avoid actually needing to call
|
||||
// expand().
|
||||
@ -345,105 +490,62 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
|
||||
)
|
||||
|
||||
if (result.is_same(self)) {
|
||||
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
|
||||
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
|
||||
}
|
||||
// } Shape checks
|
||||
|
||||
// NOLINTNEXTLINE(*c-array*)
|
||||
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
|
||||
checkAllSameGPU(__func__, targs);
|
||||
|
||||
IntArrayRef mat1_sizes = mat1.sizes();
|
||||
IntArrayRef mat2_sizes = mat2.sizes();
|
||||
IntArrayRef self__sizes;
|
||||
bool useLtInterface = false;
|
||||
#if defined(USE_ROCM)
|
||||
// When hipBLASLt is not supported on the architecture,
|
||||
// disable_addmm_cuda_lt will always be set to true
|
||||
static bool disable_addmm_cuda_lt =
|
||||
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
|
||||
#else
|
||||
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
|
||||
#endif
|
||||
// Handle whether to use the Lt interface {
|
||||
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
|
||||
// if lt path fails, we recurse back into this function here and force the lt path to off
|
||||
// we cannot update the variable disable_addmm_cuda_lt from above since it is static and would be permanent
|
||||
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
|
||||
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
||||
cublasCommonArgs _args(mat1, mat2, result);
|
||||
if (_args.transa == 't' && _args.transb == 't') {
|
||||
disable_addmm_cuda_lt_final = true;
|
||||
}
|
||||
#endif
|
||||
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#ifdef USE_ROCM
|
||||
// Conditioned on the device index, which is not persistent
|
||||
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
|
||||
#endif
|
||||
// Condition on the input
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
|
||||
// }
|
||||
|
||||
at::ScalarType scalar_type = mat1.scalar_type();
|
||||
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
|
||||
c10::MaybeOwned<Tensor> self_;
|
||||
if (&result != &self) {
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
// is to use lt interface only when self is bias.
|
||||
// for cuda 11.4, cublasLtMatmul is activated
|
||||
// the last two conditions are to skip 16b transA and non-trans-B having
|
||||
// leading dim >> rows when they are sliced from a large tensor
|
||||
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
|
||||
if (!disable_addmm_cuda_lt_final) {
|
||||
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
|
||||
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
|
||||
self.is_contiguous() && result.is_contiguous() &&
|
||||
#ifdef USE_ROCM
|
||||
(scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16) &&
|
||||
#else
|
||||
(scalar_type == at::ScalarType::Double ||
|
||||
scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16) &&
|
||||
#endif
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
|
||||
#else
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
|
||||
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
// avoid leading dim >> rows bugs
|
||||
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
|
||||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
|
||||
(scalar_type != at::ScalarType::Half &&
|
||||
scalar_type != at::ScalarType::BFloat16)) &&
|
||||
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
|
||||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
|
||||
(scalar_type != at::ScalarType::Half &&
|
||||
scalar_type != at::ScalarType::BFloat16));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
if (!useLtInterface) {
|
||||
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
|
||||
}
|
||||
self__sizes = self_->sizes();
|
||||
} else {
|
||||
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
self__sizes = self_->sizes();
|
||||
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
|
||||
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
|
||||
}
|
||||
|
||||
if (&result != &self) {
|
||||
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
|
||||
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
|
||||
at::native::copy_(result, *self_);
|
||||
// Handle result/self shapes
|
||||
if (!result.is_same(self)) {
|
||||
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
|
||||
|
||||
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
|
||||
if (disable_addmm_cuda_lt) {
|
||||
// When in non-Lt path we do expand self even before
|
||||
// check for beta != 0.0 to make sure that
|
||||
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
|
||||
// runs green.
|
||||
return expand_size(self, result.sizes(), "addmm");
|
||||
}
|
||||
// copy next, should broadcast
|
||||
return c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
}();
|
||||
// We copy bias when in the non-Lt path
|
||||
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
|
||||
// NOTE: self should broadcast over result
|
||||
at::native::copy_(result, *self_maybe_expanded);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
IntArrayRef result_sizes = result.sizes();
|
||||
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
|
||||
// Short circuit on empty result
|
||||
if (result.numel() == 0) {
|
||||
return result;
|
||||
}
|
||||
|
||||
cublasCommonArgs args(mat1, mat2, result);
|
||||
|
||||
if (mat1.numel() == 0) {
|
||||
// Short circuit if the reduction dim is empty
|
||||
if (mat1.sizes()[1] == 0) {
|
||||
// By definition, when beta==0, values in self should be ignored. nans and infs
|
||||
// should not propagate
|
||||
if (beta.toComplexDouble() == 0.) {
|
||||
@ -455,158 +557,64 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
result,
|
||||
self.expand(result.sizes()),
|
||||
at::native::scalar_tensor(
|
||||
beta,
|
||||
self.scalar_type(),
|
||||
std::nullopt /* layout */,
|
||||
at::kCPU,
|
||||
std::nullopt /* pin_memory */));
|
||||
beta,
|
||||
self.scalar_type(),
|
||||
std::nullopt /* layout */,
|
||||
at::kCPU,
|
||||
std::nullopt /* pin_memory */
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
cublasCommonArgs args(mat1, mat2, result);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
||||
|
||||
if (useLtInterface) {
|
||||
#if defined(USE_ROCM)
|
||||
bool okay = true;
|
||||
// The Lt path
|
||||
if (!disable_addmm_cuda_lt) {
|
||||
bool lt_success = false;
|
||||
if (is_float_output_with_half_input) {
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
launchTunableGemmAndBias<scalar_t>(
|
||||
args,
|
||||
alpha,
|
||||
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
|
||||
activation_to_gemm_and_blas_arg(activation));
|
||||
} else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
// This condition is needed for mm case on ROCm for hipblasLt path.
|
||||
// Passing the bias ptr as null to avoid accuracy issues for mm case.
|
||||
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
|
||||
args.result->data_ptr<scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
if (!okay) {
|
||||
// lt path failed; recurse but disable lt path
|
||||
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
|
||||
}
|
||||
#else
|
||||
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
|
||||
bool okay = true;
|
||||
if (is_float_output_with_half_input) {
|
||||
#else
|
||||
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
|
||||
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
|
||||
}
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
|
||||
}
|
||||
else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
args.result->data_ptr<float>(),
|
||||
args.result_ld,
|
||||
activation_epilogue
|
||||
);
|
||||
}});
|
||||
);
|
||||
#endif
|
||||
} else {
|
||||
// !is_float_output_with_half_input
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
launchTunableGemmAndBias<scalar_t>(
|
||||
args,
|
||||
alpha,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
activation_epilogue);
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
|
||||
}
|
||||
else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
args.result->data_ptr<scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_epilogue
|
||||
);
|
||||
}});
|
||||
}
|
||||
if (!okay) {
|
||||
// lt path failed; recurse but disable lt path
|
||||
);
|
||||
} // end is_float_output_with_half_input
|
||||
|
||||
if (!lt_success) {
|
||||
// lt path failed; recurse but disable lt path
|
||||
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
|
||||
}
|
||||
#endif
|
||||
} else
|
||||
{
|
||||
// end Lt path
|
||||
} else {
|
||||
// No Lt, we use a GEMM instead
|
||||
if (is_float_output_with_half_input) {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(
|
||||
scalar_type,
|
||||
"addmm_cuda",
|
||||
[&] {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
opmath_t alpha_val = alpha.to<opmath_t>();
|
||||
opmath_t beta_val = beta.to<opmath_t>();
|
||||
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
|
||||
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
|
||||
|
||||
float* result_ptr = args.result->mutable_data_ptr<float>();
|
||||
at::cuda::blas::gemm<scalar_t, float>(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha_val,
|
||||
mat1_ptr,
|
||||
args.lda,
|
||||
mat2_ptr,
|
||||
args.ldb,
|
||||
beta_val,
|
||||
result_ptr,
|
||||
args.result_ld);
|
||||
});
|
||||
launchGemmCublas<scalar_t, float>(args, alpha, beta);
|
||||
}
|
||||
);
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
@ -614,28 +622,12 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
scalar_type,
|
||||
"addmm_cuda",
|
||||
[&] {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
opmath_t alpha_val = alpha.to<opmath_t>();
|
||||
opmath_t beta_val = beta.to<opmath_t>();
|
||||
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
|
||||
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
|
||||
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
|
||||
at::cuda::blas::gemm<scalar_t>(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha_val,
|
||||
mat1_ptr,
|
||||
args.lda,
|
||||
mat2_ptr,
|
||||
args.ldb,
|
||||
beta_val,
|
||||
result_ptr,
|
||||
args.result_ld);
|
||||
});
|
||||
launchGemmCublas<scalar_t>(args, alpha, beta);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// Apply epilogue
|
||||
switch (activation) {
|
||||
case Activation::RELU:
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
@ -647,14 +639,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
} // end GEMM path
|
||||
|
||||
// Preprocessor gate here needs to match the inverse of the check
|
||||
// gating activation_to_gemm_and_blas_arg above; here we are manually
|
||||
// performing a post-GELU because we weren't able to use the GELU
|
||||
// epilogue above.
|
||||
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
|
||||
if (useLtInterface && activation == Activation::GELU) {
|
||||
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
|
||||
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -1,18 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/cuda/detail/OffsetCalculator.cuh>
|
||||
#include <ATen/detail/FunctionTraits.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/TensorIteratorDynamicCasting.h>
|
||||
#include <ATen/cuda/detail/OffsetCalculator.cuh>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/native/cuda/thread_constants.h>
|
||||
|
||||
#include <thrust/tuple.h>
|
||||
|
||||
#include <ATen/native/cuda/MemoryAccess.cuh>
|
||||
|
||||
#include <tuple>
|
||||
|
||||
|
||||
|
||||
namespace at::native {
|
||||
|
||||
template<int N>
|
||||
@ -62,7 +61,11 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < elems_per_thread; i++) {
|
||||
if (policy.check_inbounds(i)) {
|
||||
#if defined(__HIP__)
|
||||
results[i] = c10::guts::apply(f, args[i]);
|
||||
#else
|
||||
results[i] = std::apply(f, args[i]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ namespace at::native {
|
||||
|
||||
// The maximum number of threads in a block
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int MAX_BLOCK_SIZE = 256;
|
||||
constexpr int MAX_BLOCK_SIZE = 1024;
|
||||
#else
|
||||
constexpr int MAX_BLOCK_SIZE = 512;
|
||||
#endif
|
||||
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
|
||||
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
|
||||
static int getNumThreads(int nElem) {
|
||||
#if defined(USE_ROCM)
|
||||
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
|
||||
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
|
||||
#else
|
||||
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
|
||||
#endif
|
||||
@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
|
||||
// first the reductions each thread does separately
|
||||
scalar_t sum = static_cast<scalar_t>(0);
|
||||
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int UNRL = 4; // load deserialize factor
|
||||
scalar_t tmp[UNRL];
|
||||
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
tmp[u] = op(batch, plane, std::min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
if (x+u*blockDim.x < tensor.size(2))
|
||||
sum += tmp[u];
|
||||
}
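// The clamped gather above always reads a valid element; the bounds check in
// the second unrolled loop drops the clamped duplicates so they are not
// accumulated twice.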
|
||||
#else
|
||||
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
|
||||
sum += op(batch, plane, x);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
__shared__ scalar_t shared[C10_WARP_SIZE];
|
||||
SumReduceOp<scalar_t> reduce_op;
|
||||
@ -292,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel(
|
||||
stat_accscalar_t var_n = 0;
|
||||
int n = 0;
|
||||
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int UNRL = 4;
|
||||
stat_accscalar_t v_[UNRL];
|
||||
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
|
||||
for (int u = 0; u < UNRL; u++) {
|
||||
if (x+u*blockDim.x < input.size(2)) {
|
||||
stat_accscalar_t d1 = v_[u] - avg;
|
||||
n++;
|
||||
avg += d1 / n;
|
||||
var_n += d1 * (v_[u] - avg);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
|
||||
stat_accscalar_t v = input[batch][plane][x];
|
||||
stat_accscalar_t d1 = v - avg;
|
||||
@ -299,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel(
|
||||
avg += d1 / n;
|
||||
var_n += d1 * (v - avg);
|
||||
}
|
||||
#endif
|
||||
}
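// Note: both branches above implement Welford's online update
//   d1 = v - avg; avg += d1 / n; var_n += d1 * (v - avg);
// so var_n / n recovers the (biased) variance once the loop finishes.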
|
||||
|
||||
// first warpSum to get one value per thread to
|
||||
|
||||
@ -92,6 +92,16 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
|
||||
output_offset + output_y * output_dim_x + output_x);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
|
||||
const int64_t two = (len - 1) * 2;
|
||||
if (two <= 0) {
|
||||
return 0;
|
||||
}
|
||||
int64_t m = x % two;
|
||||
if (m < 0) m += two;
|
||||
return (m < len) ? m : (two - m);
|
||||
}
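// Example: for len == 5 the reflected index has period 8, e.g.
//   reflect_index(-2, 5) == 2, reflect_index(-1, 5) == 1,
//   reflect_index(5, 5) == 3, reflect_index(8, 5) == 0.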
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void reflection_pad1d_out_kernel(
|
||||
const scalar_t * input, scalar_t * output,
|
||||
@ -106,6 +116,28 @@ __global__ void reflection_pad1d_out_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void reflection_pad1d_flat(
|
||||
const scalar_t* __restrict__ input,
|
||||
scalar_t* __restrict__ output,
|
||||
int64_t input_w, int64_t pad_l, int64_t pad_r,
|
||||
int64_t out_w, int64_t plane_count) {
|
||||
|
||||
const int64_t bx = blockDim.x;
|
||||
const int64_t tx = threadIdx.x;
|
||||
|
||||
const int64_t total = plane_count * out_w;
|
||||
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
|
||||
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
|
||||
|
||||
for (; linear < total; linear += grid_stride) {
|
||||
const int64_t plane = linear / out_w;
|
||||
const int64_t x = linear - plane * out_w;
|
||||
const int64_t j = reflect_index(x - pad_l, input_w);
|
||||
output[plane * out_w + x] = input[plane * input_w + j];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void reflection_pad1d_backward_out_kernel(
|
||||
scalar_t * grad_input, const scalar_t * grad_output,
|
||||
@ -710,25 +742,44 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
|
||||
int64_t input_w = input_.size(dim_w);
|
||||
int64_t output_w = input_w + pad_l + pad_r;
|
||||
|
||||
dim3 block_size(output_w > 256 ? 256 : output_w);
|
||||
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
|
||||
|
||||
Tensor input = input_.contiguous();
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
|
||||
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
|
||||
reflection_pad1d_out_kernel<<<
|
||||
grid_size,
|
||||
block_size,
|
||||
0,
|
||||
at::cuda::getCurrentCUDAStream()>>>(
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
output.mutable_data_ptr<scalar_t>(),
|
||||
input_w,
|
||||
pad_l,
|
||||
pad_r);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
|
||||
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
const int max_x = prop->maxGridSize[0];
|
||||
const int max_y = prop->maxGridSize[1];
|
||||
const int max_z = prop->maxGridSize[2];
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
|
||||
|
||||
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
|
||||
|
||||
if (fits3d) {
|
||||
dim3 block(block_x, 1, 1);
|
||||
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
|
||||
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
output.mutable_data_ptr<scalar_t>(),
|
||||
input_w, pad_l, pad_r);
|
||||
} else {
|
||||
dim3 block(block_x, 1, 1);
|
||||
const int64_t plane_count = nplane * nbatch;
|
||||
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
|
||||
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
|
||||
dim3 grid(grid_x, 1, 1);
|
||||
|
||||
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
output.mutable_data_ptr<scalar_t>(),
|
||||
input_w, pad_l, pad_r, output_w, plane_count);
|
||||
}
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
|
||||
|
||||
@ -43,6 +43,12 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
|
||||
TORCH_CHECK(k >= 1 && k <= slicesize,
|
||||
"kthvalue(): selected number k out of range for dimension ", dim);
|
||||
|
||||
TORCH_CHECK(
|
||||
slicesize <= std::numeric_limits<int32_t>::max(),
|
||||
"kthvalue(): dimension ", dim, " is too large (", slicesize,
|
||||
"). The current CUDA implementation supports dimension sizes up to ",
|
||||
std::numeric_limits<int32_t>::max());
|
||||
|
||||
at::assert_no_overlap(self, values);
|
||||
|
||||
_reduction_with_indices_allocate_or_resize_output(
|
||||
@ -163,10 +169,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
|
||||
bool keepdim,
|
||||
Tensor& values,
|
||||
Tensor& indices) {
|
||||
// See note [Writing Nondeterministic Operations]
|
||||
// If there are duplicate elements of the kth value, the procedure for choosing which
|
||||
// of the duplicates to use for the indices output is nondeterministic.
|
||||
at::globalContext().alertNotDeterministic("kthvalue CUDA");
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
// `kthvalue_out_impl_cuda` expects contiguous in input `self`.
|
||||
|
||||
@ -65,25 +65,34 @@ __global__ void gatherKthValue(
|
||||
&kValue);
|
||||
|
||||
// Find the index of the k-th highest element
|
||||
index_t kValueIndex = 0;
|
||||
bool foundKValue = false;
|
||||
__shared__ int32_t minIndexFound;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
minIndexFound = static_cast<int32_t>(inputSliceSize);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
|
||||
bool inRange = (i < inputSliceSize);
|
||||
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
|
||||
: static_cast<scalar_t>(0);
|
||||
bool isKValue = inRange &&
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
if (isKValue) {
|
||||
kValueIndex = i;
|
||||
foundKValue = true;
|
||||
break;
|
||||
}
|
||||
// Early exit based on best-so-far
|
||||
if (i >= minIndexFound) {
|
||||
break;
|
||||
}
|
||||
|
||||
scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
|
||||
bool isKValue =
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
|
||||
if (isKValue) {
|
||||
atomicMin(&minIndexFound, static_cast<int32_t>(i));
|
||||
break;
|
||||
}
|
||||
}
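// The shared-memory atomicMin above makes the block pick the smallest index
// whose value equals kValue, and the i >= minIndexFound check lets later
// threads stop scanning once a smaller match has already been recorded.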
|
||||
|
||||
if (foundKValue) {
|
||||
kthValueSliceStart[0] = kValue;
|
||||
indicesSliceStart[0] = kValueIndex;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
indicesSliceStart[0] = static_cast<index_t>(minIndexFound);
|
||||
kthValueSliceStart[0] = kValue;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
|
||||
C10_DEVICE __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
int64_t chunk_size,
|
||||
FusedOptimizerTensorListMetadata<3>& tl,
|
||||
const float* lr_ptr,
|
||||
const double& lr,
|
||||
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace at::native
|
||||
} // namespace at::native
|
||||
|
||||
16
aten/src/ATen/native/mps/kernels/LinearAlgebra.h
Normal file
@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
#include <c10/metal/common.h>
|
||||
|
||||
template <unsigned N = c10::metal::max_ndim>
|
||||
struct OrgqrParams {
|
||||
int32_t num_batch_dims;
|
||||
|
||||
uint32_t m;
|
||||
uint32_t n;
|
||||
uint32_t k;
|
||||
|
||||
::c10::metal::array<uint32_t, N> A_strides;
|
||||
::c10::metal::array<uint32_t, N> tau_strides;
|
||||
::c10::metal::array<uint32_t, N> H_strides;
|
||||
::c10::metal::array<uint32_t, N> H_sizes;
|
||||
};
|
||||
@ -1,3 +1,4 @@
|
||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
|
||||
#include <c10/metal/utils.h>
|
||||
#include <metal_array>
|
||||
#include <metal_simdgroup>
|
||||
@ -640,6 +641,164 @@ kernel void applyPivots(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T bool_to_float(bool b) {
|
||||
return static_cast<T>(b);
|
||||
}
|
||||
|
||||
template <>
|
||||
half2 bool_to_float(bool b) {
|
||||
return half2(b ? 1 : 0, 0);
|
||||
}
|
||||
|
||||
template <>
|
||||
float2 bool_to_float(bool b) {
|
||||
return float2(b ? 1 : 0, 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T calc_H_irc(
|
||||
device T* A,
|
||||
uint32_t A_stride_r,
|
||||
uint32_t A_stride_c,
|
||||
constant T* tau,
|
||||
uint32_t tau_stride,
|
||||
uint32_t r,
|
||||
uint32_t c,
|
||||
uint32_t i) {
|
||||
T I_val = bool_to_float<T>(r == c);
|
||||
T tau_val = tau[i * tau_stride];
|
||||
|
||||
T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]);
|
||||
T A_ri = A[r * A_stride_r + i * A_stride_c];
|
||||
|
||||
T c_eq_i = bool_to_float<T>(c == i);
|
||||
T r_eq_i = bool_to_float<T>(r == i);
|
||||
|
||||
T A_ci_ = (c > i) ? A_ci : c_eq_i;
|
||||
T A_ri_ = (r > i) ? A_ri : r_eq_i;
|
||||
|
||||
return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_));
|
||||
}
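// Note: with v_i denoting the i-th Householder vector stored in column i of A
// (v_i[j] = 0 for j < i, v_i[i] = 1, v_i[j] = A[j, i] for j > i), this returns
// one element of the reflector H_i = I - tau_i * v_i * v_i^H, i.e.
//   H_i[r, c] = delta(r, c) - tau_i * v_i[r] * conj(v_i[c]).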
|
||||
|
||||
// Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the
|
||||
// result of matrix multiplying A and B together. A and B must be size m-by-m
|
||||
// and have the same strides. The formula for this operation, written in Python
|
||||
// syntax, is:
|
||||
// (A @ B)[r, c] = A[r, :].dot(B[:, c])
|
||||
template <typename T>
|
||||
static T calc_matmul_rc(
|
||||
device T* A,
|
||||
device T* B,
|
||||
uint32_t stride_r,
|
||||
uint32_t stride_c,
|
||||
uint32_t m,
|
||||
uint32_t r,
|
||||
uint32_t c) {
|
||||
T AB_rc = 0;
|
||||
auto A_row_offset = r * stride_r;
|
||||
auto B_col_offset = c * stride_c;
|
||||
|
||||
uint32_t A_col_offset = 0;
|
||||
uint32_t B_row_offset = 0;
|
||||
|
||||
for (uint32_t j = 0; j < m;
|
||||
j++, A_col_offset += stride_c, B_row_offset += stride_r) {
|
||||
AB_rc += c10::metal::mul(
|
||||
A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]);
|
||||
}
|
||||
return AB_rc;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
kernel void orgqr(
|
||||
device T* A [[buffer(0)]],
|
||||
constant T* tau [[buffer(1)]],
|
||||
device T* H [[buffer(2)]],
|
||||
device T* H_prod [[buffer(3)]],
|
||||
constant OrgqrParams<>& params [[buffer(4)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
constant auto& A_strides = params.A_strides;
|
||||
constant auto& tau_strides = params.tau_strides;
|
||||
constant auto& H_strides = params.H_strides;
|
||||
constant auto& H_sizes = params.H_sizes;
|
||||
|
||||
auto num_batch_dims = params.num_batch_dims;
|
||||
auto m = params.m;
|
||||
auto n = params.n;
|
||||
auto k = params.k;
|
||||
|
||||
auto m2 = m * m;
|
||||
auto batch_idx = tid / m2;
|
||||
|
||||
// Find the matrices for this thread's batch index
|
||||
uint32_t A_offset = 0;
|
||||
uint32_t tau_offset = 0;
|
||||
uint32_t H_offset = 0;
|
||||
|
||||
for (auto dim = num_batch_dims - 1; dim >= 0; dim--) {
|
||||
auto dim_size = H_sizes[dim];
|
||||
auto dim_idx = batch_idx % dim_size;
|
||||
|
||||
A_offset += dim_idx * A_strides[dim];
|
||||
tau_offset += dim_idx * tau_strides[dim];
|
||||
H_offset += dim_idx * H_strides[dim];
|
||||
|
||||
batch_idx /= dim_size;
|
||||
}
|
||||
|
||||
A += A_offset;
|
||||
tau += tau_offset;
|
||||
H += H_offset;
|
||||
H_prod += H_offset;
|
||||
|
||||
auto matrix_idx = tid % m2;
|
||||
auto r = matrix_idx / m;
|
||||
auto c = matrix_idx % m;
|
||||
auto A_stride_r = A_strides[num_batch_dims];
|
||||
auto A_stride_c = A_strides[num_batch_dims + 1];
|
||||
auto tau_stride = tau_strides[num_batch_dims];
|
||||
auto H_stride_r = H_strides[num_batch_dims];
|
||||
auto H_stride_c = H_strides[num_batch_dims + 1];
|
||||
|
||||
// Find the element of H and H_prod that this thread will calculate
|
||||
device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c);
|
||||
device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c);
|
||||
|
||||
for (uint32_t i = 0; i < k; i++) {
|
||||
// Calculate and write H_i
|
||||
|
||||
T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i);
|
||||
|
||||
// Calculate element [r, c] of prod(H_0, ..., H_i)
|
||||
if (i == 0) {
|
||||
*H_prod_elem_ptr = H_irc;
|
||||
} else {
|
||||
*H_elem_ptr = H_irc;
|
||||
|
||||
// Need this sync because the below matmul requires all threads to finish
|
||||
// writing their entries to `H_prod` and `H`.
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
T H_prod_0_to_i_rc =
|
||||
calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);
|
||||
|
||||
// Need this sync because the above matmul uses the current values in
|
||||
// `H_prod`, and we don't want to overwrite those until all threads are
|
||||
// finished using them.
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
*H_prod_elem_ptr = H_prod_0_to_i_rc;
|
||||
}
|
||||
}
|
||||
|
||||
device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c);
|
||||
|
||||
if (c < n) {
|
||||
*A_elem_ptr = *H_prod_elem_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MM_OPS(DTYPE) \
|
||||
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
|
||||
constant DTYPE * mat1Data [[buffer(0)]], \
|
||||
@ -679,3 +838,19 @@ INSTANTIATE_MM_OPS(int);
|
||||
INSTANTIATE_MM_OPS(short);
|
||||
INSTANTIATE_MM_OPS(char);
|
||||
INSTANTIATE_MM_OPS(uchar);
|
||||
|
||||
#define REGISTER_ORGQR(T) \
|
||||
template [[host_name("orgqr_" #T)]] \
|
||||
kernel void orgqr<T>( \
|
||||
device T * A [[buffer(0)]], \
|
||||
constant T * tau [[buffer(1)]], \
|
||||
device T * H [[buffer(2)]], \
|
||||
device T * H_prod [[buffer(3)]], \
|
||||
constant OrgqrParams<> & params [[buffer(4)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
REGISTER_ORGQR(float);
|
||||
REGISTER_ORGQR(half);
|
||||
REGISTER_ORGQR(bfloat);
|
||||
REGISTER_ORGQR(float2);
|
||||
REGISTER_ORGQR(half2);
|
||||
|
||||
@ -92,13 +92,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
}
|
||||
|
||||
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
|
||||
if ([maskedMM dataType] != MPSDataTypeFloat32) {
|
||||
maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
|
||||
}
|
||||
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
|
||||
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
|
||||
if ([maskedMM dataType] != qTensor.dataType) {
|
||||
maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
|
||||
}
|
||||
|
||||
if (is_causal) {
|
||||
auto causalMask = [mpsGraph constantWithScalar:1.0f
|
||||
@ -112,7 +107,9 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
name:nil];
|
||||
} else if (attn_mask) {
|
||||
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
|
||||
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
|
||||
name:nil];
|
||||
}
|
||||
|
||||
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
|
||||
@ -133,8 +130,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
graph->qTensor = qTensor;
|
||||
graph->kTensor = kTensor;
|
||||
graph->vTensor = vTensor;
|
||||
graph->outputTensor = output;
|
||||
graph->attnTensor = sm;
|
||||
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
|
||||
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
|
||||
});
|
||||
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
|
||||
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
|
||||
|
||||
@ -8,6 +8,9 @@
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -28,6 +31,7 @@
|
||||
#include <ATen/ops/linalg_solve_triangular_native.h>
|
||||
#include <ATen/ops/lu_unpack_native.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/orgqr_native.h>
|
||||
#include <ATen/ops/slice.h>
|
||||
#include <ATen/ops/stack.h>
|
||||
#include <ATen/ops/triangular_solve_native.h>
|
||||
@ -338,6 +342,8 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
|
||||
". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
|
||||
}
|
||||
}
|
||||
|
||||
map_mps_decomposition_error_code_to_blas(info);
|
||||
}
|
||||
|
||||
static void linalg_solve_out_mps_impl(const Tensor& A,
|
||||
@ -1233,6 +1239,69 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper
|
||||
}
|
||||
}
|
||||
|
||||
static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) {
|
||||
if (self.numel() == 0) {
|
||||
return self;
|
||||
}
|
||||
|
||||
auto m = self.size(-2);
|
||||
auto n = self.size(-1);
|
||||
auto k = tau.size(-1);
|
||||
|
||||
if (tau.numel() == 0) {
|
||||
auto I = eye(m, self.scalar_type(), std::nullopt, self.device());
|
||||
return self.copy_(I.slice(-1, 0, n));
|
||||
}
|
||||
|
||||
auto num_batch_dims = self.dim() - 2;
|
||||
auto batch_sizes = self.sizes().slice(0, num_batch_dims);
|
||||
|
||||
std::vector<int64_t> H_sizes(num_batch_dims + 2);
|
||||
for (auto dim : c10::irange(num_batch_dims)) {
|
||||
H_sizes[dim] = self.size(dim);
|
||||
}
|
||||
H_sizes[num_batch_dims] = m;
|
||||
H_sizes[num_batch_dims + 1] = m;
|
||||
|
||||
auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous));
|
||||
auto H_prod = at::empty_like(H);
|
||||
|
||||
OrgqrParams params;
|
||||
|
||||
params.num_batch_dims = num_batch_dims;
|
||||
params.m = m;
|
||||
params.n = n;
|
||||
params.k = k;
|
||||
|
||||
for (const auto dim : c10::irange(self.dim())) {
|
||||
params.A_strides[dim] = self.stride(dim);
|
||||
|
||||
if (dim < tau.dim()) {
|
||||
params.tau_strides[dim] = tau.stride(dim);
|
||||
}
|
||||
|
||||
params.H_strides[dim] = H.stride(dim);
|
||||
params.H_sizes[dim] = H.size(dim);
|
||||
}
|
||||
|
||||
auto num_threads = H.numel();
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder();
|
||||
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self)));
|
||||
getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau});
|
||||
[compute_encoder setComputePipelineState:pipeline_state];
|
||||
mtl_setArgs(compute_encoder, self, tau, H, H_prod, params);
|
||||
mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads);
|
||||
getMPSProfiler().endProfileKernel(pipeline_state);
|
||||
}
|
||||
});
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
|
||||
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
|
||||
@ -1448,20 +1517,6 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps)
|
||||
mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info);
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
|
||||
Tensor info = at::empty({}, A.options().dtype(kInt));
|
||||
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
|
||||
return std::tie(LU, pivots);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
|
||||
Tensor LU = at::empty({0}, A.options());
|
||||
Tensor pivots = at::empty({0}, A.options().dtype(kInt));
|
||||
Tensor info = at::empty({}, A.options().dtype(kInt));
|
||||
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
|
||||
return std::make_tuple(std::move(LU), std::move(pivots));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(lu_unpack_out_mps)
|
||||
(const Tensor& LU_data,
|
||||
const Tensor& LU_pivots,
|
||||
@ -1483,4 +1538,6 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
|
||||
REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl);
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -14157,16 +14157,10 @@
|
||||
- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeImplicitAutograd: linalg_lu_factor
|
||||
MPS: linalg_lu_factor_mps
|
||||
|
||||
- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeImplicitAutograd: linalg_lu_factor_out
|
||||
MPS: linalg_lu_factor_out_mps
|
||||
|
||||
- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
|
||||
python_module: linalg
|
||||
@ -14368,12 +14362,12 @@
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA: linalg_householder_product
|
||||
CPU, CUDA, MPS: linalg_householder_product
|
||||
|
||||
- func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: linalg
|
||||
dispatch:
|
||||
CPU, CUDA: linalg_householder_product_out
|
||||
CPU, CUDA, MPS: linalg_householder_product_out
|
||||
|
||||
- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
|
||||
python_module: linalg
|
||||
|
||||
@ -40,15 +40,7 @@
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
|
||||
|
||||
#if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
|
||||
#define IS_CUSPARSE11_AVAILABLE() 1
|
||||
#else
|
||||
#define IS_CUSPARSE11_AVAILABLE() 0
|
||||
#endif
|
||||
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
#include <library_types.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
@ -103,17 +95,9 @@ struct csrMatrixRef {
|
||||
int nnz_{0};
|
||||
std::vector<int> size_{};
|
||||
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
cusparseSpMatDescr_t description_{0};
|
||||
#else
|
||||
cusparseMatDescr_t description_{0};
|
||||
#endif
|
||||
cusparseSpMatDescr_t description_{0};
|
||||
|
||||
csrMatrixRef() {
|
||||
#if !IS_CUSPARSE11_AVAILABLE()
|
||||
create_general_description_(description_);
|
||||
#endif
|
||||
}
|
||||
csrMatrixRef() = default;
|
||||
|
||||
csrMatrixRef(
|
||||
int* csr_indices,
|
||||
@ -126,7 +110,6 @@ struct csrMatrixRef {
|
||||
csr_values_{csr_values},
|
||||
nnz_{nnz},
|
||||
size_{size} {
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
|
||||
&description_,
|
||||
@ -140,17 +123,10 @@ struct csrMatrixRef {
|
||||
CUSPARSE_INDEX_32I,
|
||||
CUSPARSE_INDEX_BASE_ZERO,
|
||||
cuda_data_type));
|
||||
#else
|
||||
create_general_description_(description_);
|
||||
#endif
|
||||
}
|
||||
|
||||
~csrMatrixRef() {
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
cusparseDestroySpMat(description_);
|
||||
#else
|
||||
cusparseDestroyMatDescr(description_);
|
||||
#endif
|
||||
cusparseDestroySpMat(description_);
|
||||
}
|
||||
|
||||
int size(int index) const {
|
||||
@ -196,8 +172,6 @@ struct csrOutput {
|
||||
}
|
||||
};
|
||||
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
|
||||
// RAII guard helps to support cuSparse 11 API for `A @ B` operation
|
||||
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
|
||||
template <class scalar_t>
|
||||
@ -396,284 +370,6 @@ template struct CusparseMatrixMultiplyOp<float>;
|
||||
|
||||
template struct CusparseMatrixMultiplyOp<double>;
|
||||
|
||||
#else // if not IS_CUSPARSE11_AVAILABLE()
|
||||
|
||||
using DcsrMatrixRef = csrMatrixRef<double>;
|
||||
using ScsrMatrixRef = csrMatrixRef<float>;
|
||||
|
||||
// RAII guard helps to support cuSparse 10 API for `A @ B` operation
|
||||
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
|
||||
template <class scalar_t>
|
||||
struct CusparseMatrixMultiplyOp {
|
||||
csrOutput operator()(
|
||||
const csrMatrixRef<scalar_t>& lhs,
|
||||
const csrMatrixRef<scalar_t>& rhs,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices)
|
||||
{
|
||||
static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double.");
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization for `A @ B` operation for double values with cuSparse
|
||||
template<> struct CusparseMatrixMultiplyOp<double> {
|
||||
csrgemm2Info_t gemm2Info_;
|
||||
|
||||
CusparseMatrixMultiplyOp() {
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
|
||||
}
|
||||
~CusparseMatrixMultiplyOp() {
|
||||
cusparseDestroyCsrgemm2Info(gemm2Info_);
|
||||
}
|
||||
|
||||
csrOutput operator ()(
|
||||
const DcsrMatrixRef& lhs,
|
||||
const DcsrMatrixRef& rhs,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
double alpha = 1.0;
|
||||
DcsrMatrixRef empty;
|
||||
return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
|
||||
}
|
||||
|
||||
csrOutput Dgemm2(
|
||||
const DcsrMatrixRef& A,
|
||||
const DcsrMatrixRef& B,
|
||||
const DcsrMatrixRef& C,
|
||||
const double* alpha,
|
||||
const double* beta,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
void* buffer_{nullptr};
|
||||
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
|
||||
|
||||
csrOutput out({A.size(0), B.size(1)});
|
||||
int innerSize = confirm_mult_size(A.size_, B.size_);
|
||||
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
|
||||
|
||||
// Compute needed buffer size
|
||||
size_t new_bubber_sz;
|
||||
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
gemm2Info_,
|
||||
&new_bubber_sz));
|
||||
|
||||
// (Re)allocate buffer if needed
|
||||
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
||||
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
|
||||
buffer_ = data_ptr.get();
|
||||
|
||||
// Find the resulting non-zero pattern.
|
||||
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
&out.nnz_,
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
|
||||
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
|
||||
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
|
||||
|
||||
// Perform the gemm2 operation for doubles
|
||||
// out = alpha ∗ A ∗ B + beta ∗ C
|
||||
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_values_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_values_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_values_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_values_.data_ptr<double>(),
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
out.csr_indices_.data_ptr<int>(),
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization for `A @ B` operation for float values with cuSparse
|
||||
template<> struct CusparseMatrixMultiplyOp<float> {
|
||||
csrgemm2Info_t gemm2Info_;
|
||||
|
||||
CusparseMatrixMultiplyOp() {
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
|
||||
|
||||
}
|
||||
~CusparseMatrixMultiplyOp() {
|
||||
cusparseDestroyCsrgemm2Info(gemm2Info_);
|
||||
}
|
||||
csrOutput operator()(
|
||||
const ScsrMatrixRef& lhs,
|
||||
const ScsrMatrixRef& rhs,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
float alpha = 1.0;
|
||||
ScsrMatrixRef empty;
|
||||
return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
|
||||
}
|
||||
|
||||
csrOutput Sgemm2(
|
||||
const ScsrMatrixRef& A,
|
||||
const ScsrMatrixRef& B,
|
||||
const ScsrMatrixRef& C,
|
||||
const float* alpha,
|
||||
const float* beta,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
void* buffer_{nullptr};
|
||||
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
|
||||
|
||||
csrOutput out({A.size(0), B.size(1)});
|
||||
|
||||
int innerSize = confirm_mult_size(A.size_, B.size_);
|
||||
|
||||
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
|
||||
|
||||
// Compute needed buffer size
|
||||
size_t new_bubber_sz;
|
||||
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
gemm2Info_,
|
||||
&new_bubber_sz));
|
||||
|
||||
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
||||
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
|
||||
buffer_ = data_ptr.get();
|
||||
|
||||
// Find the resulting non-zero pattern.
|
||||
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
&out.nnz_,
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
|
||||
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
|
||||
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
|
||||
|
||||
// Perform the gemm2 operation for floats
|
||||
// out = alpha ∗ A ∗ B + beta ∗ C
|
||||
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_values_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_values_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_values_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_values_.data_ptr<float>(),
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
out.csr_indices_.data_ptr<int>(),
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
#endif // IS_CUSPARSE11_AVAILABLE()
|
||||
|
||||
template <typename scalar_t>
|
||||
void sparse_sparse_matmul_cuda_kernel(
|
||||
Tensor& result,
|
||||
@ -815,19 +511,15 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
|
||||
auto output = at::native::empty_like(mat1_);
|
||||
output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
|
||||
|
||||
#if IS_CUSPARSE11_AVAILABLE() && !defined(USE_ROCM)
|
||||
#if !defined(USE_ROCM)
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
|
||||
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
|
||||
});
|
||||
#elif IS_CUSPARSE11_AVAILABLE() && defined(USE_ROCM)
|
||||
#else
|
||||
// ROCm does not support half and bfloat16 types for sparse_matmul
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
|
||||
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
|
||||
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
|
||||
});
|
||||
#endif
|
||||
return output;
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ using namespace mps;
|
||||
#ifndef PYTORCH_JIT_COMPILE_SHADERS
|
||||
static auto& lib = MetalShaderLibrary::getBundledLibrary();
|
||||
#else
|
||||
#include <ATen/native/mps/Mul_metallib.h>
|
||||
#include <ATen/native/mps/SparseTensorMath_metallib.h>
|
||||
#endif
|
||||
|
||||
static Tensor& s_addmm_out_sparse_dense_mps(
|
||||
@ -369,12 +369,7 @@ static SparseTensor& mul_out_dense_sparse_mps(
|
||||
}
|
||||
|
||||
if (scalar_like) {
|
||||
auto scalar = dense;
|
||||
if (dense.numel() == 1 && dense.dim() > 0) {
|
||||
scalar = dense.view({});
|
||||
}
|
||||
scalar = scalar.to(values.options());
|
||||
auto out_vals = values.mul(scalar);
|
||||
auto out_vals = values.mul(dense.to(values.options()));
|
||||
if (out.scalar_type() != commonDtype) {
|
||||
out_vals = out_vals.to(out.scalar_type());
|
||||
}
|
||||
@ -508,14 +503,14 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
|
||||
const auto device = r_.device();
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
auto lhs_indices = lhs._indices();
|
||||
auto rhs_indices = rhs._indices();
|
||||
auto lhs_values = lhs._values().to(commonDtype);
|
||||
auto rhs_values = rhs._values().to(commonDtype);
|
||||
auto lhs_indices = lhs._indices().contiguous();
|
||||
auto rhs_indices = rhs._indices().contiguous();
|
||||
auto lhs_values = lhs._values().to(commonDtype).contiguous();
|
||||
auto rhs_values = rhs._values().to(commonDtype).contiguous();
|
||||
|
||||
// Flatten sparse indices to keys
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes());
|
||||
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes());
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i));
|
||||
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i));
|
||||
|
||||
// Intersect sorted keys (search the shorter in the longer)
|
||||
const bool A_is_lhs = (lhs_nnz <= rhs_nnz);
|
||||
@ -546,35 +541,54 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
|
||||
auto out_indices = at::empty({ndim_i, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong));
|
||||
auto lhs_match = outA_idx.narrow(0, 0, M);
|
||||
auto rhs_match = outB_idx.narrow(0, 0, M);
|
||||
auto out_val_sizes = lhs_values.sizes().vec();
|
||||
out_val_sizes[0] = static_cast<int64_t>(M);
|
||||
auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec();
|
||||
int64_t cols64 = 1;
|
||||
for (auto s : dense_sizes_vec) cols64 *= s;
|
||||
const uint32_t cols = static_cast<uint32_t>(std::max<int64_t>(cols64, 1));
|
||||
|
||||
auto to2d = [&](Tensor t, int64_t nnz) -> Tensor {
|
||||
const int64_t t_cols = t.numel() / nnz;
|
||||
if (t_cols == cols64) {
|
||||
return t.view({nnz, cols64});
|
||||
}
|
||||
return t.view({nnz, 1}).expand({nnz, cols64}).contiguous();
|
||||
};
|
||||
|
||||
// make both sides 2d [nnz, cols] buffers so the kernel can index it
|
||||
auto lhs_vals2d = to2d(lhs_values, lhs_nnz);
|
||||
auto rhs_vals2d = to2d(rhs_values, rhs_nnz);
|
||||
|
||||
std::vector<int64_t> out_val_sizes;
|
||||
out_val_sizes.reserve(1 + dense_sizes_vec.size());
|
||||
out_val_sizes.push_back(static_cast<int64_t>(M));
|
||||
out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end());
|
||||
auto out_values = at::empty(out_val_sizes, lhs_values.options());
|
||||
|
||||
const uint32_t cols = static_cast<uint32_t>(
|
||||
lhs_values.numel() / std::max<int64_t>(1, lhs_nnz));
|
||||
if (M > 0) {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc(
|
||||
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc(
|
||||
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
const uint32_t tew = pso.threadExecutionWidth;
|
||||
const uint32_t gridW = std::max<uint32_t>(cols, 1u);
|
||||
const uint32_t tgW = std::min(gridW, tew);
|
||||
MTLSize grid = MTLSizeMake(gridW, 1, M);
|
||||
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
|
||||
|
||||
const uint32_t tew = pso.threadExecutionWidth;
|
||||
uint32_t tgW = std::min(cols, tew);
|
||||
MTLSize grid = MTLSizeMake(cols, 1, M);
|
||||
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
|
||||
|
||||
mtl_setArgs(enc,
|
||||
lhs_values, rhs_values,
|
||||
lhs_match, rhs_match,
|
||||
lhs_indices, out_indices,
|
||||
out_values,
|
||||
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
|
||||
std::array<uint32_t, 2>{M, cols});
|
||||
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
|
||||
}
|
||||
});
|
||||
mtl_setArgs(enc,
|
||||
lhs_vals2d, rhs_vals2d,
|
||||
lhs_match, rhs_match,
|
||||
lhs_indices, out_indices,
|
||||
out_values,
|
||||
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
|
||||
std::array<uint32_t, 2>{M, cols});
|
||||
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (r_.scalar_type() != commonDtype) {
|
||||
out_values = out_values.to(r_.scalar_type());
|
||||
|
||||
@ -62,7 +62,6 @@ kernel void build_row_ptr_from_sorted_rows_by_batch(
|
||||
|
||||
template <typename T>
|
||||
kernel void spmm_bmm_coo_rows_grouped(
|
||||
device const long* rows [[buffer(0)]],
|
||||
device const long* cols [[buffer(1)]],
|
||||
device const T* vals [[buffer(2)]],
|
||||
device const T* dense [[buffer(3)]],
|
||||
@ -73,7 +72,6 @@ kernel void spmm_bmm_coo_rows_grouped(
|
||||
uint3 ltid [[thread_position_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]])
|
||||
{
|
||||
const uint B = dims.x;
|
||||
const uint I = dims.y;
|
||||
const uint J = dims.z;
|
||||
const uint K = dims.w;
|
||||
@ -197,9 +195,9 @@ kernel void fused_gather_mul_kernel(
|
||||
const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
|
||||
const ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
|
||||
|
||||
const float a = (float)lhs_vals[offL];
|
||||
const float b = (float)rhs_vals[offR];
|
||||
out_vals[offO] = (T)(a * b);
|
||||
const auto a = static_cast<accum_t<T>>(lhs_vals[offL]);
|
||||
const auto b = static_cast<accum_t<T>>(rhs_vals[offR]);
|
||||
out_vals[offO] = static_cast<T>(mul(a, b));
|
||||
}
|
||||
|
||||
// One thread per match copies the indices column
|
||||
@ -321,7 +319,6 @@ INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);
|
||||
#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \
|
||||
template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void \
|
||||
spmm_bmm_coo_rows_grouped<DTYPE>( \
|
||||
device const long* rows [[buffer(0)]], \
|
||||
device const long* cols [[buffer(1)]], \
|
||||
device const DTYPE* vals [[buffer(2)]], \
|
||||
device const DTYPE* dense [[buffer(3)]], \
|
||||
@ -58,8 +58,7 @@ def list_benchmarks():
|
||||
|
||||
def run_benchmark(
|
||||
benchmark_name: str,
|
||||
should_visualize: bool = False,
|
||||
compile_mode: str = "max-autotune-no-cudagraphs",
|
||||
script_args,
|
||||
):
|
||||
"""Run a specific benchmark."""
|
||||
if benchmark_name not in BENCHMARK_REGISTRY:
|
||||
@ -68,29 +67,29 @@ def run_benchmark(
|
||||
return False
|
||||
|
||||
print(f"Running benchmark: {benchmark_name}")
|
||||
print(f"Torch compile mode: {compile_mode}")
|
||||
print(f"Torch compile mode: {script_args.compile_mode}")
|
||||
print("=" * 60)
|
||||
|
||||
benchmark_class = BENCHMARK_REGISTRY[benchmark_name]
|
||||
benchmark = benchmark_class(compile_mode)
|
||||
benchmark = benchmark_class(script_args)
|
||||
benchmark.benchmark()
|
||||
if should_visualize:
|
||||
if script_args.visualize:
|
||||
benchmark.visualize()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def run_all_benchmarks(should_visualize: bool = False, compile_mode: str = "default"):
|
||||
def run_all_benchmarks(script_args):
|
||||
"""Run all available benchmarks."""
|
||||
print("Running all benchmarks...")
|
||||
print(f"Torch compile mode: {compile_mode}")
|
||||
print(f"Torch compile mode: {script_args.compile_mode}")
|
||||
print("=" * 60)
|
||||
|
||||
for name, cls in BENCHMARK_REGISTRY.items():
|
||||
print(f"\n{'=' * 20} {name.upper()} {'=' * 20}")
|
||||
benchmark = cls(compile_mode)
|
||||
benchmark = cls(script_args)
|
||||
benchmark.benchmark()
|
||||
if should_visualize:
|
||||
if script_args.visualize:
|
||||
benchmark.visualize()
|
||||
print()
|
||||
|
||||
@ -137,6 +136,19 @@ Examples:
|
||||
help="Torch compile mode to use (default: default)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tolerance",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Tolerance for the accuracy check",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exit-on-accuracy-failure",
|
||||
action="store_true",
|
||||
help="Whether to exit with an error message for accuracy failure",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle list option
|
||||
@ -146,7 +158,7 @@ Examples:
|
||||
|
||||
# Handle all option
|
||||
if args.all:
|
||||
run_all_benchmarks(args.visualize, args.compile_mode)
|
||||
run_all_benchmarks(args)
|
||||
return
|
||||
|
||||
# Handle specific benchmarks
|
||||
@ -157,7 +169,7 @@ Examples:
|
||||
sys.exit(1)
|
||||
|
||||
for benchmark_name in args.benchmarks:
|
||||
run_benchmark(benchmark_name, args.visualize, args.compile_mode)
|
||||
run_benchmark(benchmark_name, args)
|
||||
print() # Add spacing between benchmarks
|
||||
|
||||
|
||||
|
||||
@ -9,8 +9,8 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class CrossEntropyForward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -106,8 +106,8 @@ class CrossEntropyForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class CrossEntropyBackward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -194,8 +194,8 @@ class CrossEntropyBackward(BenchmarkKernel):
|
||||
|
||||
|
||||
class SoftmaxForward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -259,8 +259,8 @@ class SoftmaxForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class SoftmaxBackward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -329,8 +329,8 @@ class SoftmaxBackward(BenchmarkKernel):
|
||||
|
||||
|
||||
class RMSNormForward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -383,7 +383,22 @@ class RMSNormForward(BenchmarkKernel):
|
||||
from quack.rmsnorm import _rmsnorm_fwd
|
||||
|
||||
x, w = args
|
||||
return lambda: _rmsnorm_fwd(x, w, eps=1e-6)
|
||||
y = torch.empty_like(x)
|
||||
|
||||
def quack_fwd():
|
||||
_rmsnorm_fwd(
|
||||
x,
|
||||
w,
|
||||
out=y,
|
||||
bias=None,
|
||||
rstd=None,
|
||||
residual=None,
|
||||
residual_out=None,
|
||||
eps=1e-6,
|
||||
)
|
||||
return y
|
||||
|
||||
return quack_fwd
|
||||
|
||||
def liger(self, args, kwargs) -> Any:
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
@ -404,9 +419,14 @@ class RMSNormForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class RMSNormBackward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = [
|
||||
"eager",
|
||||
"compiled",
|
||||
"quack",
|
||||
"liger",
|
||||
]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
# TODO: OOM for (32768, 65536) on h100
|
||||
@ -454,8 +474,11 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
y, [x, w], grad_outputs=dy, retain_graph=True
|
||||
)
|
||||
|
||||
def compute_rstd(self, x, eps):
|
||||
return torch.rsqrt(torch.mean(x.float().square(), dim=-1, keepdim=True) + eps)
|
||||
|
||||
def quack(self, args, kwargs=None) -> Any:
|
||||
from quack.rmsnorm import _rmsnorm_backward
|
||||
from quack.rmsnorm import _get_sm_count, _rmsnorm_bwd
|
||||
|
||||
(
|
||||
x,
|
||||
@ -463,15 +486,40 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
dy,
|
||||
) = args
|
||||
M, N = x.shape
|
||||
rstd = torch.randn(M, device="cuda", dtype=torch.float32)
|
||||
return lambda: _rmsnorm_backward(x, w, dy, rstd)
|
||||
|
||||
rstd = self.compute_rstd(x, eps=1e-6)
|
||||
dx = torch.empty_like(x)
|
||||
sm_count = _get_sm_count(x.size(1), x.device)
|
||||
dw_partial = torch.empty(
|
||||
sm_count, x.size(1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
def quack_bwd():
|
||||
_rmsnorm_bwd(
|
||||
x,
|
||||
w,
|
||||
dy,
|
||||
rstd,
|
||||
dx,
|
||||
dw_partial,
|
||||
db_partial=None,
|
||||
dresidual_out=None,
|
||||
dresidual=None,
|
||||
sm_count=sm_count,
|
||||
)
|
||||
dw = dw_partial.sum(dim=0).to(w.dtype)
|
||||
return dx, dw
|
||||
|
||||
return quack_bwd
|
||||
|
||||
def liger(self, args, kwargs=None) -> Any:
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
|
||||
x, w, dy = args
|
||||
M, N = x.shape
|
||||
liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda()
|
||||
liger_rmsnorm = LigerRMSNorm(
|
||||
hidden_size=N, eps=1e-6, casting_mode="gemma"
|
||||
).cuda()
|
||||
liger_rmsnorm.weight.data.copy_(w)
|
||||
y = liger_rmsnorm(x)
|
||||
return lambda: torch.autograd.grad(
|
||||
@ -489,8 +537,8 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
|
||||
|
||||
class LayerNormForward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "quack", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -563,8 +611,8 @@ class LayerNormForward(BenchmarkKernel):
|
||||
|
||||
|
||||
class LayerNormBackward(BenchmarkKernel):
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
super().__init__(compile_mode)
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
self.available_backends = ["eager", "compiled", "liger"]
|
||||
|
||||
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
|
||||
@ -614,20 +662,31 @@ class LayerNormBackward(BenchmarkKernel):
|
||||
y, [x, w], grad_outputs=dy, retain_graph=True
|
||||
)
|
||||
|
||||
def compute_mean_rstd(self, x, eps):
|
||||
x = x.float()
|
||||
|
||||
var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
|
||||
rstd = torch.rsqrt(var + eps)
|
||||
return mean, rstd
|
||||
|
||||
def liger(self, args, kwargs) -> Any:
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
"""
|
||||
Call layer_norm_backward directly rather than calling
|
||||
liger_kernel.transformers.layer_norm.LigerLayerNorm and
|
||||
torch.autograd.grad.
|
||||
|
||||
The latter approach saves mean/rstd in x.dtype, which can fail
|
||||
the accuracy test. We call layer_norm_backward with fp32 mean and
|
||||
rstd.
|
||||
"""
|
||||
from liger_kernel.ops.layer_norm import layer_norm_backward
|
||||
|
||||
x, w, dy = args
|
||||
eps = 1e-6
|
||||
mean, rstd = self.compute_mean_rstd(x, eps)
|
||||
M, N = x.shape
|
||||
liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda()
|
||||
liger_layernorm.weight.data.copy_(w)
|
||||
liger_layernorm.bias.data.copy_(
|
||||
torch.zeros(N, device="cuda", dtype=torch.float32)
|
||||
)
|
||||
y = liger_layernorm(x)
|
||||
return lambda: torch.autograd.grad(
|
||||
y, [x, liger_layernorm.weight], grad_outputs=dy, retain_graph=True
|
||||
)
|
||||
|
||||
return lambda: layer_norm_backward(dy, x, w, None, mean, rstd)[0:2]
|
||||
|
||||
def benchmark(self):
|
||||
for M, N in self.get_shapes():
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
@ -43,10 +44,11 @@ class Performance:
|
||||
|
||||
|
||||
class BenchmarkKernel:
|
||||
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
|
||||
def __init__(self, script_args):
|
||||
self.script_args = script_args
|
||||
self.name = self.__class__.__name__
|
||||
self.available_backends: list[str] = []
|
||||
self.compile_mode: str = compile_mode
|
||||
self.compile_mode: str = script_args.compile_mode
|
||||
|
||||
# mapping from backend to list of performance results
|
||||
self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list)
|
||||
@ -106,14 +108,21 @@ class BenchmarkKernel:
|
||||
args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
|
||||
res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
|
||||
gold = res["eager"]
|
||||
|
||||
tol = {}
|
||||
if self.script_args.tolerance:
|
||||
tol = {
|
||||
"atol": self.script_args.tolerance,
|
||||
"rtol": self.script_args.tolerance,
|
||||
}
|
||||
for backend in self.available_backends:
|
||||
if backend == "eager":
|
||||
continue
|
||||
try:
|
||||
torch.testing.assert_close(res[backend], gold)
|
||||
torch.testing.assert_close(res[backend], gold, **tol)
|
||||
for t, gold_t in zip(res[backend], gold):
|
||||
if t.requires_grad:
|
||||
torch.testing.assert_close(t.grad, gold_t.grad)
|
||||
torch.testing.assert_close(t.grad, gold_t.grad, **tol)
|
||||
print(
|
||||
f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel"
|
||||
)
|
||||
@ -121,6 +130,9 @@ class BenchmarkKernel:
|
||||
print(
|
||||
f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
|
||||
)
|
||||
if self.script_args.exit_on_accuracy_failure:
|
||||
print("Exit right away since --exit-on-accuracy-failure is set")
|
||||
sys.exit(1)
|
||||
|
||||
def benchmark_single_shape(
|
||||
self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
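For reference, the `tol` dictionary built in the accuracy check above is expanded directly into `torch.testing.assert_close`. A small standalone sketch of that call (the tensor values are made up for illustration):

```python
# Sketch of how an absolute/relative tolerance pair relaxes the comparison,
# mirroring the **tol expansion in the accuracy check above.
import torch

gold = torch.tensor([1.0000, 2.0000])
approx = torch.tensor([1.0005, 1.9995])

# Fails with the default float32 tolerances, passes once atol/rtol are widened.
torch.testing.assert_close(approx, gold, atol=1e-2, rtol=1e-2)
```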
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
add_loop_eager,compile_time_instruction_count,3070000000,0.1
|
||||
add_loop_eager,compile_time_instruction_count,3184000000,0.1
|
||||
|
||||
|
||||
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
|
||||
|
||||
|
||||
|
||||
update_hint_regression,compile_time_instruction_count,1719000000,0.1
|
||||
update_hint_regression,compile_time_instruction_count,1645000000,0.1
|
||||
|
||||
|
||||
|
||||
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
|
||||
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
|
||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
|
||||
|
||||
|
||||
|
||||
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
|
||||
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
|
||||
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
|
||||
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
|
||||
|
||||
|
@ -43,6 +43,7 @@ tolerance:
|
||||
- doctr_reco_predictor
|
||||
- drq
|
||||
- phlippe_resnet
|
||||
- pytorch_CycleGAN_and_pix2pix
|
||||
|
||||
higher_bf16:
|
||||
- doctr_reco_predictor
|
||||
|
||||
@ -44,21 +44,101 @@ PyTorch,div_,div__M1_N1_K1_cpu_dtype_onetorch.float32_dtype_twotorch.float32,sho
|
||||
PyTorch,div_,div__M64_N64_K64_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.241161,0.000000
|
||||
PyTorch,div_,div__M64_N64_K128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.852816,0.000000
|
||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,57.006677,0.000000
|
||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,88.167000,0.000000
|
||||
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.519000,0.000000
|
||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,55.606088,0.000000
|
||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,86.551000,0.000000
|
||||
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.864088,0.000000
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,71.641000,0.000000
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,83.073000,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,67.570000,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.895000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.024000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
|
||||
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
|
||||
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
|
||||
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
|
||||
@ -71,6 +151,9 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
|
||||
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
|
||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
|
||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
|
||||
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
|
||||
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
|
||||
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
|
||||
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
|
||||
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
|
||||
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000
|
||||
|
||||
|
@ -580,6 +580,9 @@ class BenchmarkRunner:
|
||||
else "unknown"
|
||||
)
|
||||
|
||||
# Extract operator name from test_name
|
||||
operator_name = test_name.split("_")[0]
|
||||
|
||||
# Create the record
|
||||
@dataclass
|
||||
class BenchmarkInfo:
|
||||
@ -593,6 +596,7 @@ class BenchmarkRunner:
|
||||
name: str
|
||||
type: str
|
||||
origins: list[str]
|
||||
extra_info: dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class MetricInfo:
|
||||
@ -618,10 +622,14 @@ class BenchmarkRunner:
|
||||
"device": device,
|
||||
"arch": device_arch,
|
||||
"use_compile": use_compile,
|
||||
"operator_name": operator_name,
|
||||
},
|
||||
),
|
||||
model=ModelInfo(
|
||||
name=test_name, type="micro-benchmark", origins=["pytorch"]
|
||||
name=test_name,
|
||||
type="micro-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={"operator_name": operator_name},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
name="latency",
|
||||
|
||||
@ -25,7 +25,7 @@ binary_configs_broadcast = op_bench.config_list(
|
||||
],
|
||||
cross_product_configs={
|
||||
"device": ["cpu"],
|
||||
"dtype": [torch.float],
|
||||
"dtype": [torch.float, torch.bfloat16, torch.float64],
|
||||
},
|
||||
tags=["short"],
|
||||
)
|
||||
@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
|
||||
],
|
||||
cross_product_configs={
|
||||
"device": ["cpu", "cuda"],
|
||||
"dtype_one": [torch.int32],
|
||||
"dtype_two": [torch.int32],
|
||||
"dtype_one": [torch.int32, torch.uint8],
|
||||
"dtype_two": [torch.int32, torch.uint8],
|
||||
},
|
||||
tags=["short"],
|
||||
)
|
||||
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
|
||||
N=[32, 64],
|
||||
K=[256, 512],
|
||||
device=["cpu", "cuda"],
|
||||
dtype_one=[torch.int8, torch.int32],
|
||||
dtype_two=[torch.int8, torch.int32],
|
||||
dtype_one=[torch.int8, torch.int32, torch.uint8],
|
||||
dtype_two=[torch.int8, torch.int32, torch.uint8],
|
||||
tags=["long"],
|
||||
)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
@ -176,8 +176,8 @@ THIRD_PARTY_LIBS = {
|
||||
"omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
|
||||
"pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
|
||||
"psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
|
||||
"pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
|
||||
"pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
|
||||
"pthreadpool": ["fbsource//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
|
||||
"pthreadpool_header": ["fbsource//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
|
||||
"moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
|
||||
"pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
|
||||
"rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
|
||||
@ -1729,8 +1729,10 @@ def define_buck_targets(
|
||||
"torch/csrc/jit/backends/backend_debug_info.cpp",
|
||||
"torch/csrc/jit/backends/backend_interface.cpp",
|
||||
],
|
||||
compiler_flags = get_pt_compiler_flags(),
|
||||
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
|
||||
compiler_flags = get_pt_compiler_flags() + select({
|
||||
"DEFAULT": [],
|
||||
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
|
||||
}),
|
||||
# @lint-ignore BUCKLINT link_whole
|
||||
link_whole = True,
|
||||
linker_flags = get_no_as_needed_linker_flag(),
|
||||
@ -2023,6 +2025,9 @@ def define_buck_targets(
|
||||
"ovr_config//os:android-x86_64": [
|
||||
"-mssse3",
|
||||
],
|
||||
}) + select({
|
||||
"DEFAULT": [],
|
||||
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
|
||||
}),
|
||||
exported_preprocessor_flags = get_aten_preprocessor_flags(),
|
||||
exported_deps = [
|
||||
|
||||
@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [
|
||||
"torch/csrc/cuda/Stream.cpp",
|
||||
"torch/csrc/cuda/Graph.cpp",
|
||||
"torch/csrc/cuda/MemPool.cpp",
|
||||
"torch/csrc/cuda/GreenContext.cpp",
|
||||
"torch/csrc/cuda/shared/cudart.cpp",
|
||||
"torch/csrc/cuda/shared/nvtx.cpp",
|
||||
"torch/csrc/cuda/utils.cpp",
|
||||
|
||||
@ -13,7 +13,17 @@
|
||||
namespace c10::CachingAllocator {
|
||||
|
||||
// "large" allocations may be packed in 20 MiB blocks
|
||||
const size_t kLargeBuffer = 20971520;
|
||||
constexpr size_t kLargeBuffer = 20971520;
|
||||
// "small" allocations are packed in 2 MiB blocks
|
||||
constexpr size_t kSmallBuffer = 2097152;
|
||||
// all sizes are rounded to at least 512 bytes
|
||||
constexpr size_t kMinBlockSize = 512;
|
||||
// largest "small" allocation is 1 MiB
|
||||
constexpr size_t kSmallSize = 1048576;
|
||||
// allocations between 1 and 10 MiB may use kLargeBuffer
|
||||
constexpr size_t kMinLargeAlloc = 10485760;
|
||||
// round up large allocations to 2 MiB
|
||||
constexpr size_t kRoundLarge = 2097152;
|
||||
|
||||
// A utility class for tokenizing allocator configuration strings into discrete
|
||||
// parts. For example, the config string:
|
||||
|
||||
@ -223,7 +223,7 @@ inline DispatchKey backendToDispatchKey(Backend b) {
|
||||
case Backend::PrivateUse1:
|
||||
return DispatchKey::PrivateUse1;
|
||||
default:
|
||||
throw std::runtime_error("Unknown backend");
|
||||
TORCH_CHECK(false, "Unknown backend");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -336,7 +336,7 @@ class C10_API Scalar {
|
||||
} else if (isBoolean()) {
|
||||
return ScalarType::Bool;
|
||||
} else {
|
||||
throw std::runtime_error("Unknown scalar type.");
|
||||
TORCH_CHECK(false, "Unknown scalar type.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -228,7 +228,7 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
|
||||
case c10::ScalarType::Float4_e2m1fn_x2:
|
||||
return std::make_pair("float4_e2m1fn_x2", "");
|
||||
default:
|
||||
throw std::runtime_error("Unimplemented scalar type");
|
||||
TORCH_CHECK(false, "Unimplemented scalar type");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -52,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
|
||||
#undef DEFINE_CONSTANT
|
||||
|
||||
inline const char* toString(ScalarType t) {
|
||||
#define DEFINE_CASE(_, name) \
|
||||
case ScalarType::name: \
|
||||
return #name;
|
||||
|
||||
switch (t) {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
|
||||
default:
|
||||
return "UNKNOWN_SCALAR";
|
||||
}
|
||||
#undef DEFINE_CASE
|
||||
}
|
||||
|
||||
inline size_t elementSize(ScalarType t) {
|
||||
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
|
||||
case ScalarType::name: \
|
||||
@ -150,22 +137,6 @@ inline ScalarType toQIntType(ScalarType t) {
|
||||
}
|
||||
}
|
||||
|
||||
inline ScalarType toUnderlying(ScalarType t) {
|
||||
switch (t) {
|
||||
case ScalarType::QUInt8:
|
||||
case ScalarType::QUInt4x2:
|
||||
[[fallthrough]];
|
||||
case ScalarType::QUInt2x4:
|
||||
return ScalarType::Byte;
|
||||
case ScalarType::QInt8:
|
||||
return ScalarType::Char;
|
||||
case ScalarType::QInt32:
|
||||
return ScalarType::Int;
|
||||
default:
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool isSignedType(ScalarType t) {
|
||||
#define CASE_ISSIGNED(name) \
|
||||
case ScalarType::name: \
|
||||
@ -308,12 +279,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
|
||||
|
||||
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& stream,
|
||||
at::ScalarType scalar_type) {
|
||||
return stream << toString(scalar_type);
|
||||
}
|
||||
|
||||
// Returns a pair of strings representing the names for each dtype.
|
||||
// The returned pair is (name, legacy_name_if_applicable)
|
||||
C10_API std::pair<std::string, std::string> getDtypeNames(
|
||||
|
||||
@ -87,9 +87,7 @@ bool ThreadPool::inThreadPool() const {
|
||||
}
|
||||
|
||||
void ThreadPool::run(std::function<void()> func) {
|
||||
if (threads_.empty()) {
|
||||
throw std::runtime_error("No threads to run a task");
|
||||
}
|
||||
TORCH_CHECK(threads_.size() > 0, "No threads to run a task");
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
|
||||
// Set task and signal condition variable so that a worker thread will
|
||||
|
||||
@ -131,15 +131,6 @@ namespace Native {
|
||||
* notifyCaptureDestroy.
|
||||
*/
|
||||
|
||||
constexpr size_t kMinBlockSize =
|
||||
512; // all sizes are rounded to at least 512 bytes
|
||||
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
|
||||
constexpr size_t kSmallBuffer =
|
||||
2097152; // "small" allocations are packed in 2 MiB blocks
|
||||
constexpr size_t kMinLargeAlloc =
|
||||
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
|
||||
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
|
||||
|
||||
static char SHAREABLE_HANDLE_VERSION = 2;
|
||||
enum ShareableHandleType : char {
|
||||
SHAREABLE_CUDA_MALLOC = 'c',
|
||||
@ -4478,7 +4469,10 @@ struct BackendStaticInitializer {
|
||||
if (key == "backend") {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
i++; // Move to the value after the colon
|
||||
if (tokenizer[i] == "cudaMallocAsync"
|
||||
// break up token to trick hipify
|
||||
if (tokenizer[i] ==
|
||||
"c"
|
||||
"udaMallocAsync"
|
||||
#ifdef USE_ROCM
|
||||
// convenience for ROCm users to allow either CUDA or HIP env var
|
||||
|| tokenizer[i] == "hipMallocAsync"
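For context, the config string being tokenized here usually reaches the allocator through the `PYTORCH_CUDA_ALLOC_CONF` environment variable. A minimal sketch of setting the documented `backend` key (the surrounding Python is only for illustration):

```python
# Illustrates the kind of config string the tokenizer above parses.
# PYTORCH_CUDA_ALLOC_CONF must be set before torch initializes CUDA.
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch  # imported after setting the variable so the allocator sees it

if torch.cuda.is_available():
    x = torch.empty(1024, device="cuda")  # allocated through the async backend
```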
|
||||
|
||||
@ -913,7 +913,9 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
|
||||
}
|
||||
}
|
||||
std::string name() override {
|
||||
return "cudaMallocAsync";
|
||||
// break up token to trick hipify
|
||||
return "c"
|
||||
"udaMallocAsync";
|
||||
}
|
||||
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
||||
C10_CUDA_CHECK(
|
||||
|
||||
@ -51,6 +51,17 @@
|
||||
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
||||
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
|
||||
_(cuCtxFromGreenCtx, 12080) \
|
||||
_(cuCtxGetCurrent, 12080) \
|
||||
_(cuCtxPopCurrent, 12080) \
|
||||
_(cuCtxPushCurrent, 12080) \
|
||||
_(cuCtxSetCurrent, 12080) \
|
||||
_(cuGreenCtxCreate, 12080) \
|
||||
_(cuGreenCtxDestroy, 12080) \
|
||||
_(cuDevSmResourceSplitByCount, 12080) \
|
||||
_(cuDeviceGet, 12080) \
|
||||
_(cuDeviceGetDevResource, 12080) \
|
||||
_(cuDevResourceGenerateDesc, 12080) \
|
||||
_(cuMulticastAddDevice, 12030) \
|
||||
_(cuMulticastBindMem, 12030) \
|
||||
_(cuMulticastCreate, 12030) \
|
||||
|
||||
@ -328,6 +328,21 @@ struct pair {
|
||||
T2 second;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static T conj(T a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <>
|
||||
half2 conj(half2 a) {
|
||||
return half2(a.x, -a.y);
|
||||
}
|
||||
|
||||
template <>
|
||||
float2 conj(float2 a) {
|
||||
return float2(a.x, -a.y);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FOR_ALL_TYPES(MACRO) \
|
||||
MACRO(float); \
|
||||
MACRO(half); \
|
||||
|
||||
@ -45,14 +45,7 @@ constexpr bool is_pod_v = is_pod<T>::value;
|
||||
|
||||
namespace guts {
|
||||
|
||||
#if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__)
|
||||
|
||||
template <class F, class Tuple>
|
||||
C10_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) {
|
||||
return std::apply(std::forward<F>(f), std::forward<Tuple>(t));
|
||||
}
|
||||
|
||||
#else
|
||||
#if defined(__HIP__)
|
||||
|
||||
// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but
|
||||
// modified)
|
||||
|
||||
@ -14,16 +14,6 @@ using namespace c10::CachingDeviceAllocator;
|
||||
|
||||
// newly allocated memory with 512-byte alignment.
|
||||
constexpr size_t kDeviceAlignment = 512;
|
||||
// all sizes are rounded to at least 512 bytes
|
||||
constexpr size_t kMinBlockSize = 512;
|
||||
// largest "small" allocation is 1 MiB
|
||||
constexpr size_t kSmallSize = 1048576;
|
||||
// "small" allocations are packed in 2 MiB blocks
|
||||
constexpr size_t kSmallBuffer = 2097152;
|
||||
// allocations between 1 and 10 MiB may use kLargeBuffer
|
||||
constexpr size_t kMinLargeAlloc = 10485760;
|
||||
// round up large allocations to 2 MiB
|
||||
constexpr size_t kRoundLarge = 2097152;
|
||||
|
||||
namespace {
|
||||
using stream_set = ska::flat_hash_set<xpu::XPUStream>;
|
||||
|
||||
@ -607,6 +607,12 @@ if(USE_CUDA)
|
||||
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
|
||||
endif()
|
||||
endif()
|
||||
if(NOT WIN32)
|
||||
set_source_files_properties(
|
||||
${TORCH_ROOT}/aten/src/ATen/cuda/CUDAGreenContext.cpp
|
||||
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
|
||||
)
|
||||
endif()
|
||||
set_source_files_properties(
|
||||
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
||||
PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"
|
||||
|
||||
@ -1638,38 +1638,7 @@ if(USE_KINETO)
|
||||
message(STATUS " KINETO_LIBRARY_TYPE = ${KINETO_LIBRARY_TYPE}")
|
||||
|
||||
if(NOT LIBKINETO_NOCUPTI)
|
||||
set(CUDA_SOURCE_DIR "${CUDA_TOOLKIT_ROOT_DIR}" CACHE STRING "")
|
||||
message(STATUS " CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}")
|
||||
message(STATUS " CUDA_INCLUDE_DIRS = ${CUDA_INCLUDE_DIRS}")
|
||||
|
||||
if(NOT MSVC)
|
||||
if(USE_CUPTI_SO)
|
||||
set(CUPTI_LIB_NAME "libcupti.so")
|
||||
else()
|
||||
set(CUPTI_LIB_NAME "libcupti_static.a")
|
||||
endif()
|
||||
else()
|
||||
set(CUPTI_LIB_NAME "cupti.lib")
|
||||
endif()
|
||||
|
||||
find_library(CUPTI_LIBRARY_PATH ${CUPTI_LIB_NAME} PATHS
|
||||
${CUDA_SOURCE_DIR}
|
||||
${CUDA_SOURCE_DIR}/extras/CUPTI/lib64
|
||||
${CUDA_SOURCE_DIR}/lib
|
||||
${CUDA_SOURCE_DIR}/lib64
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_path(CUPTI_INCLUDE_DIR cupti.h PATHS
|
||||
${CUDA_SOURCE_DIR}/extras/CUPTI/include
|
||||
${CUDA_INCLUDE_DIRS}
|
||||
${CUDA_SOURCE_DIR}
|
||||
${CUDA_SOURCE_DIR}/include
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
if(CUPTI_LIBRARY_PATH AND CUPTI_INCLUDE_DIR)
|
||||
message(STATUS " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}")
|
||||
set(CUDA_cupti_LIBRARY ${CUPTI_LIBRARY_PATH})
|
||||
message(STATUS " CUDA_cupti_LIBRARY = ${CUDA_cupti_LIBRARY}")
|
||||
if(TARGET CUDA::cupti)
|
||||
message(STATUS "Found CUPTI")
|
||||
set(LIBKINETO_NOCUPTI OFF CACHE STRING "" FORCE)
|
||||
|
||||
@ -1682,7 +1651,7 @@ if(USE_KINETO)
|
||||
if(NOT APPLE)
|
||||
set(CMAKE_REQUIRED_LIBRARIES ${CMAKE_REQUIRED_LIBRARIES} "dl" "pthread")
|
||||
endif()
|
||||
set(CMAKE_REQUIRED_LINK_OPTIONS "-Wl,--whole-archive,${CUPTI_LIBRARY_PATH},--no-whole-archive")
|
||||
set(CMAKE_REQUIRED_LIBRARIES ${CMAKE_REQUIRED_LIBRARIES} $<LINK_LIBRARY:WHOLE_ARCHIVE,CUDA::cupti_static>)
|
||||
check_cxx_source_runs("#include <stdexcept>
|
||||
int main() {
|
||||
try {
|
||||
|
||||
@ -272,7 +272,7 @@ Here, we'll briefly introduce the implementation process of custom operators, fo
|
||||
* Name: `input`
|
||||
* Output Type: `Tensor`
|
||||
|
||||
2. **Register Operator&Autograd Fallback:**
|
||||
2. **Register Operator**
|
||||
|
||||
::::{tab-set}
|
||||
|
||||
@ -285,19 +285,11 @@ Here, we'll briefly introduce the implementation process of custom operators, fo
|
||||
:end-before: LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT
|
||||
:linenos:
|
||||
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK
|
||||
:end-before: LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK
|
||||
:emphasize-lines: 2
|
||||
:linenos:
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
::::
|
||||
|
||||
Use `TORCH_LIBRARY_IMPL` to register the `wrapper_custom_abs` implementation for the `custom_abs` operator in `PrivateUse1`. However, because `Autograd` is always enabled in PyTorch, PyTorch defaults to finding and executing the corresponding backward implementation even if only forward computation is required(will fallthrough in backward implementation). Therefore, we also need to register the corresponding implementation for `AutogradPrivateUse1` of the `custom_abs` operator. Fortunately, PyTorch also provides a general `Autograd Fallback` mechanism named `torch::autograd::autogradNotImplementedFallback`, if only forward computation is involved, it is equivalent to a fallthrough operation, selecting the next DispatchKey for computation; if backward computation is involved, an error is thrown.
|
||||
Use `TORCH_LIBRARY_IMPL` to register the `wrapper_custom_abs` implementation for the `custom_abs` operator in `PrivateUse1`. Because `Autograd` is always enabled in PyTorch, PyTorch defaults to finding and executing the corresponding backward implementation even if only forward computation is required (it will fall through in the backward implementation). Fortunately, PyTorch has implemented a general `Autograd Fallback` for `PrivateUse1` as well: if only forward computation is involved, it is equivalent to a fallthrough operation, selecting the next DispatchKey for computation; if backward computation is involved, an error is thrown.
|
||||
|
||||
3. **Register Metadata (optional, but required by the graph mode, etc.):**
|
||||
|
||||
|
||||
@ -258,6 +258,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
|
||||
|
||||
```
|
||||
|
||||
## Green Contexts (experimental)
|
||||
|
||||
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
|
||||
to enable more general carveout of SM resources for CUDA kernels.
|
||||
|
||||
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
|
||||
|
||||
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
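To make the SM-carveout idea concrete, a heavily hedged sketch follows; the `create`, `make_current`, and `pop_current` names are assumptions for illustration only, not the confirmed `GreenContext` interface, so defer to the generated class docs referenced above.

```python
# Hypothetical sketch only: the method names below are assumptions, not the
# confirmed torch.cuda.green_contexts API (requires CUDA >= 12.8).
import torch
from torch.cuda.green_contexts import GreenContext

if torch.cuda.is_available():
    ctx = GreenContext.create(8, 0)  # hypothetical: carve out 8 SMs on device 0
    ctx.make_current()               # hypothetical: subsequent kernels use only the carved-out SMs
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    c = a @ b                        # runs within the reduced SM budget
    ctx.pop_current()                # hypothetical: restore the default context
```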
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.cuda.green_contexts
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
GreenContext
|
||||
```
|
||||
|
||||
|
||||
% This module needs to be documented. Adding here in the meantime
|
||||
|
||||
% for tracking purposes
|
||||
@ -270,6 +292,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
|
||||
.. py:module:: torch.cuda.gds
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. py:module:: torch.cuda.green_contexts
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. py:module:: torch.cuda.jiterator
|
||||
```
|
||||
|
||||
@ -44,9 +44,9 @@ following invariants. More specifications about the IR can be found
|
||||
- **Normalized**: There are no Python semantics within the graph. Submodules
|
||||
from the original programs are inlined to form one fully flattened
|
||||
computational graph.
|
||||
- **Graph properties**: The graph is purely functional, meaning it does not
|
||||
contain operations with side effects such as mutations or aliasing. It does
|
||||
not mutate any intermediate values, parameters, or buffers.
|
||||
- **Graph properties**: By default, the graph may contain both functional and
|
||||
non-functional operators (including mutations). To obtain a purely functional
|
||||
graph, use `run_decompositions()` which removes mutations and aliasing.
|
||||
- **Metadata**: The graph contains metadata captured during tracing, such as a
|
||||
stacktrace from user's code.
|
||||
|
||||
@ -56,8 +56,8 @@ Under the hood, `torch.export` leverages the following latest technologies:
|
||||
called the Frame Evaluation API to safely trace PyTorch graphs. This
|
||||
provides a massively improved graph capturing experience, with much fewer
|
||||
rewrites needed in order to fully trace the PyTorch code.
|
||||
- **AOT Autograd** provides a functionalized PyTorch graph and ensures the graph
|
||||
is decomposed/lowered to the ATen operator set.
|
||||
- **AOT Autograd** ensures the graph is decomposed/lowered to the ATen operator
|
||||
set. When using `run_decompositions()`, it can also provide functionalization.
|
||||
- **Torch FX (torch.fx)** is the underlying representation of the graph,
|
||||
allowing flexible Python-based transformations.
|
||||
|
||||
@ -444,23 +444,31 @@ saved_exported_program = torch.export.load('exported_program.pt2')
|
||||
|
||||
(training-export)=
|
||||
|
||||
## Export IR, Decompositions
|
||||
## Export IR: Training vs Inference
|
||||
|
||||
`torch.export` produces a graph containing only
|
||||
[ATen operators](https://pytorch.org/cppdocs/#aten), which are the basic unit of
|
||||
computation in PyTorch. As there are over
|
||||
3000 ATen operators, export provides a way to narrow down the operator set used
|
||||
in the graph based on certain characteristics, creating different IRs.
|
||||
computation in PyTorch. Export provides different IR levels based on your use case:
|
||||
|
||||
By default, export produces the most generic IR which contains all ATen
|
||||
operators, including both functional and non-functional operators. A functional
|
||||
operator is one that does not contain any mutations or aliasing of the inputs.
|
||||
| IR Type | How to Obtain | Properties | Operator Count | Use Case |
|---------|---------------|------------|----------------|----------|
| Training IR | `torch.export.export()` (default) | May contain mutations | ~3000 | Training with autograd |
| Inference IR | `ep.run_decompositions(decomp_table={})` | Purely functional | ~2000 | Inference deployment |
| Core ATen IR | `ep.run_decompositions(decomp_table=None)` | Purely functional, highly decomposed | ~180 | Minimal backend support |
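The table rows correspond to three short calls. A minimal sketch (the module and input shape are arbitrary):

```python
# Minimal sketch of producing each IR level listed in the table above.
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

ep_training = torch.export.export(M(), (torch.randn(2, 3),))       # Training IR (default)
ep_inference = ep_training.run_decompositions(decomp_table={})     # Inference IR: functionalize only
ep_core_aten = ep_training.run_decompositions(decomp_table=None)   # Core ATen IR: default decompositions
print(ep_core_aten.graph_module.print_readable(print_output=False))
```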
|
||||
|
||||
### Training IR (Default)
|
||||
|
||||
By default, export produces a **Training IR** which contains all ATen
|
||||
operators, including both functional and non-functional (mutating) operators.
|
||||
A functional operator is one that does not contain any mutations or aliasing
|
||||
of the inputs, while non-functional operators may modify their inputs in-place.
|
||||
You can find a list of all ATen operators
|
||||
[here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml)
|
||||
and you can inspect if an operator is functional by checking
|
||||
`op._schema.is_mutable`.
|
||||
|
||||
-This generic IR can be used to train in eager PyTorch Autograd.
+This Training IR, which may contain mutations, is designed for training use
+cases and can be used with eager PyTorch Autograd.

 ```{code-cell}
 import torch
@@ -480,15 +488,18 @@ ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
 print(ep_for_training.graph_module.print_readable(print_output=False))
 ```

-However, if you want to use the IR for inference, or decrease the amount of
-operators being used, you can lower the graph through the
-{func}`ExportedProgram.run_decompositions` API. This method decomposes the
-ATen operators into the ones specified in the decomposition table, and
-functionalizes the graph.
+### Inference IR (via run_decompositions)

-By specifying an empty set, we're only performing functionalization, and does
-not do any additional decompositions. This results in an IR which contains ~2000
-operators (instead of the 3000 operators above), and is ideal for inference cases.
+To obtain an **Inference IR** suitable for deployment, use the
+{func}`ExportedProgram.run_decompositions` API. This method automatically:
+1. Functionalizes the graph (removes all mutations and converts them to functional equivalents)
+2. Optionally decomposes ATen operators based on the provided decomposition table
+
+This produces a purely functional graph ideal for inference scenarios.
+
+By specifying an empty decomposition table (`decomp_table={}`), you get just
+the functionalization without additional decompositions. This produces an
+Inference IR with ~2000 functional operators (compared to 3000+ in Training IR).

 ```{code-cell}
 import torch
@@ -514,11 +525,14 @@ As we can see, the previously in-place operator,
 `torch.ops.aten.add_.default` has now been replaced with
 `torch.ops.aten.add.default`, a functional operator.

-We can also further lower this exported program to an operator set which only
-contains the
+### Core ATen IR
+
+We can further lower the Inference IR to the
 `Core ATen Operator Set <https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir>`__,
-which is a collection of only ~180 operators. This IR is optimal for backends
-who do not want to reimplement all ATen operators.
+which contains only ~180 operators. This is achieved by passing `decomp_table=None`
+(which uses the default decomposition table) to `run_decompositions()`. This IR
+is optimal for backends who want to minimize the number of operators they need
+to implement.

 ```{code-cell}
 import torch
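
Putting the three IR levels together, a minimal end-to-end sketch; the module `M` and the input shape here are hypothetical and chosen only to produce one in-place op:

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        y.add_(1)  # in-place op; the Training IR can keep this as aten.add_.default
        return y

# Training IR: the default output of export, may contain mutating ops.
ep_training = torch.export.export(M(), (torch.randn(3),))

# Inference IR: functionalization only (empty decomposition table), ~2000 ops.
ep_inference = ep_training.run_decompositions(decomp_table={})

# Core ATen IR: functionalization plus the default decomposition table, ~180 ops.
ep_core_aten = ep_training.run_decompositions(decomp_table=None)

print(ep_inference.graph_module.print_readable(print_output=False))
```

The same exported program can be lowered to either level; which one to pick is purely a question of how many operators the downstream backend wants to handle.
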
@@ -208,6 +208,7 @@ select = [
    "PLC1802", # len({expression}) used as condition without comparison
    "PLC0205", # string as __slots__
    "PLC3002", # unnecessary-direct-lambda-call
    "PLC0414", # Import alias does not rename original package
    "PLE",
    "PLR0133", # constant comparison
    "PLR0206", # property with params
@@ -53,3 +53,40 @@ TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)

#undef DEFINE_CHECK
#undef TEST_FORALL

TEST(TestScalarType, toString) {
  using torch::headeronly::ScalarType;

#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}

TEST(TestScalarType, operator_left_shift) {
  using torch::headeronly::ScalarType;

#define DEFINE_CHECK(_, name)   \
  {                             \
    std::stringstream ss;       \
    ss << ScalarType::name;     \
    EXPECT_EQ(ss.str(), #name); \
  }
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}

TEST(TestScalarType, toUnderlying) {
  using torch::headeronly::ScalarType;
  using torch::headeronly::toUnderlying;

  EXPECT_EQ(toUnderlying(ScalarType::QUInt8), ScalarType::Byte);
  EXPECT_EQ(toUnderlying(ScalarType::QUInt4x2), ScalarType::Byte);
  EXPECT_EQ(toUnderlying(ScalarType::QUInt2x4), ScalarType::Byte);
  EXPECT_EQ(toUnderlying(ScalarType::QInt8), ScalarType::Char);
  EXPECT_EQ(toUnderlying(ScalarType::QInt32), ScalarType::Int);
#define DEFINE_CHECK(_, name) \
  EXPECT_EQ(toUnderlying(ScalarType::name), ScalarType::name);
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK);
  AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK);
#undef DEFINE_CHECK
}
@@ -156,12 +156,6 @@ TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) {
}
// LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT

// LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
  m.fallback(torch::autograd::autogradNotImplementedFallback());
}
// LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK

// The rest is for testing purposes
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  /*
@ -67,7 +67,21 @@ class TestFullyShardMemory(FSDPTest):
|
||||
# allocate the cuBLAS workspaces before measuring the memory usage
|
||||
# since the workspace size can differ between hardwares
|
||||
lin = torch.nn.Linear(768, 768, device=device_type)
|
||||
inp = torch.randn(1, 768, device=device_type)
|
||||
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
|
||||
# the input shape was (1, 768), so that the forward gemm used
|
||||
# cublaslt, and the backward used cublas.
|
||||
# With the aforementioned PR, and with shape (1, 768),
|
||||
# the cublas path is used both in forward and in backward,
|
||||
# altering peak memory usage not accounting for cublaslt.
|
||||
# Here we change the input shape to (2, 768), and that swaps
|
||||
# the cublas/cublaslt selection in the forward/backward,
|
||||
# but that does not affect the peak memory usage stored in `base_mem_mb`.
|
||||
# Reasons for the flip:
|
||||
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
|
||||
# after PR: no Lt in addmm when either mat1 or mat2 have nrows/ncols <= 1,
|
||||
# since the input preparation can swap matrices based on output
|
||||
# row-/col-majorness.
|
||||
inp = torch.randn(2, 768, device=device_type)
|
||||
lin(inp).sum().backward()
|
||||
torch.get_device_module(device_type).empty_cache()
|
||||
base_mem_mb = self._get_peak_active_memory_mb()
|
||||
|
||||
@ -127,8 +127,9 @@ def echo1(msg: str, exitcode: int = 0) -> str:
|
||||
print(f"exit {exitcode} from {rank}", file=sys.stderr)
|
||||
sys.exit(exitcode)
|
||||
else:
|
||||
print(f"{msg} stdout from {rank}")
|
||||
print(f"{msg} stderr from {rank}", file=sys.stderr)
|
||||
for m in msg.split(","):
|
||||
print(f"{m} stdout from {rank}")
|
||||
print(f"{m} stderr from {rank}", file=sys.stderr)
|
||||
return f"{msg}_{rank}"
|
||||
|
||||
|
||||
@ -247,6 +248,13 @@ class _StartProcessesTest(TestCase):
|
||||
for line in expected:
|
||||
self.assertIn(line, actual)
|
||||
|
||||
def assert_not_in_file(self, lines: list[str], filename: str) -> None:
|
||||
lines = [f"{line.rstrip()}\n" for line in lines]
|
||||
with open(filename) as fp:
|
||||
actual = fp.readlines()
|
||||
for line in lines:
|
||||
self.assertNotIn(line, actual)
|
||||
|
||||
def assert_pids_noexist(self, pids: dict[int, int]):
|
||||
for local_rank, pid in pids.items():
|
||||
with self.assertRaises(
|
||||
@ -360,8 +368,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
|
||||
self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
|
||||
self.assertIsNotNone(pc.wait(period=0.1))
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_pcontext_wait_on_a_child_thread(self):
|
||||
asyncio.run(asyncio.to_thread(self.test_pcontext_wait))
|
||||
@ -379,8 +387,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
pids = pc.pids()
|
||||
pc.close()
|
||||
self.assert_pids_noexist(pids)
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_function_with_tensor(self):
|
||||
for start_method in self._start_methods:
|
||||
@ -482,8 +490,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
int(error_file_data["message"]["extraInfo"]["timestamp"]),
|
||||
int(failure.timestamp),
|
||||
)
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_wait_for_all_child_procs_to_exit(self):
|
||||
"""
|
||||
@ -580,8 +588,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
self.assert_in_file([], results.stdouts[0])
|
||||
self.assertFalse(results.stderrs[1])
|
||||
self.assertFalse(results.stdouts[1])
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
failure = results.failures[1]
|
||||
self.assertEqual(-15, failure.exitcode)
|
||||
@ -731,8 +739,37 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
|
||||
self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
|
||||
self.assertFalse(pc.stdouts[1])
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_binary_duplicate_log_filters(self):
|
||||
pc = start_processes(
|
||||
name="trainer",
|
||||
entrypoint=bin("echo1.py"),
|
||||
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
|
||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||
logs_specs=DefaultLogsSpecs(
|
||||
log_dir=self.log_dir(),
|
||||
redirects={0: Std.ERR, 1: Std.NONE},
|
||||
tee={0: Std.OUT, 1: Std.ERR},
|
||||
),
|
||||
log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"},
|
||||
duplicate_stdout_filters=["helloA"],
|
||||
duplicate_stderr_filters=["worldA", "B"],
|
||||
start_method="spawn",
|
||||
)
|
||||
|
||||
result = pc.wait()
|
||||
|
||||
self.assertFalse(result.is_failed())
|
||||
self.assert_in_file(["[rank0]:helloA stdout from 0"], pc.filtered_stdout)
|
||||
self.assert_not_in_file(
|
||||
["[rank0]:helloB stdout from 0"], pc.filtered_stdout
|
||||
)
|
||||
self.assert_in_file(["[rank1]:worldA stderr from 1"], pc.filtered_stderr)
|
||||
self.assert_in_file(["[rank1]:worldB stderr from 1"], pc.filtered_stderr)
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
|
||||
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
|
||||
@ -794,8 +831,44 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
||||
self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
|
||||
self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
|
||||
self.assertFalse(pc.stdouts[1])
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_function_duplicate_log_filters(self):
|
||||
for start_method in self._start_methods:
|
||||
with self.subTest(start_method=start_method):
|
||||
pc = start_processes(
|
||||
name="trainer",
|
||||
entrypoint=echo1,
|
||||
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
|
||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||
logs_specs=DefaultLogsSpecs(
|
||||
log_dir=self.log_dir(),
|
||||
redirects={0: Std.ERR, 1: Std.NONE},
|
||||
tee={0: Std.OUT, 1: Std.ERR},
|
||||
),
|
||||
duplicate_stdout_filters=["helloA"],
|
||||
duplicate_stderr_filters=["worldA", "B"],
|
||||
start_method="spawn",
|
||||
)
|
||||
|
||||
result = pc.wait()
|
||||
|
||||
self.assertFalse(result.is_failed())
|
||||
self.assert_in_file(
|
||||
["[trainer0]:helloA stdout from 0"], pc.filtered_stdout
|
||||
)
|
||||
self.assert_not_in_file(
|
||||
["[trainer0]:helloB stdout from 0"], pc.filtered_stdout
|
||||
)
|
||||
self.assert_in_file(
|
||||
["[trainer1]:worldA stderr from 1"], pc.filtered_stderr
|
||||
)
|
||||
self.assert_in_file(
|
||||
["[trainer1]:worldB stderr from 1"], pc.filtered_stderr
|
||||
)
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_function(self):
|
||||
for start_method, redirs in product(self._start_methods, redirects_all()):
|
||||
@ -880,8 +953,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
||||
self.assertFalse(results.stdouts[0])
|
||||
self.assertFalse(results.stderrs[1])
|
||||
self.assertFalse(results.stdouts[1])
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_no_zombie_process_function(self):
|
||||
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
|
||||
|
||||
@ -23,5 +23,6 @@ if __name__ == "__main__":
|
||||
print(f"exit {exitcode} from {rank}", file=sys.stderr)
|
||||
sys.exit(exitcode)
|
||||
else:
|
||||
print(f"{args.msg} stdout from {rank}")
|
||||
print(f"{args.msg} stderr from {rank}", file=sys.stderr)
|
||||
for msg in args.msg.split(","):
|
||||
print(f"{msg} stdout from {rank}")
|
||||
print(f"{msg} stderr from {rank}", file=sys.stderr)
|
||||
|
||||
@ -84,6 +84,53 @@ class TailLogTest(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(tail.stopped())
|
||||
|
||||
def test_tail_write_to_dst_file(self):
|
||||
"""
|
||||
writer() writes 0 - max (one number on each line) to a log file.
|
||||
Run nprocs such writers and tail the log files into a temp file
|
||||
and validate that all lines are accounted for.
|
||||
"""
|
||||
nprocs = 32
|
||||
max = 1000
|
||||
interval_sec = 0.0001
|
||||
|
||||
log_files = {
|
||||
local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
|
||||
for local_rank in range(nprocs)
|
||||
}
|
||||
|
||||
dst = os.path.join(self.test_dir, "tailed_stdout.log")
|
||||
tail = TailLog(
|
||||
name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec
|
||||
).start()
|
||||
# sleep here is intentional to ensure that the log tail
|
||||
# can gracefully handle and wait for non-existent log files
|
||||
time.sleep(interval_sec * 10)
|
||||
|
||||
futs = []
|
||||
for local_rank, file in log_files.items():
|
||||
f = self.threadpool.submit(
|
||||
write, max=max, sleep=interval_sec * local_rank, file=file
|
||||
)
|
||||
futs.append(f)
|
||||
|
||||
wait(futs, return_when=ALL_COMPLETED)
|
||||
self.assertFalse(tail.stopped())
|
||||
tail.stop()
|
||||
|
||||
actual: dict[int, set[int]] = {}
|
||||
with open(dst) as dst_file:
|
||||
for line in dst_file:
|
||||
header, num = line.split(":")
|
||||
nums = actual.setdefault(header, set())
|
||||
nums.add(int(num))
|
||||
|
||||
self.assertEqual(nprocs, len(actual))
|
||||
self.assertEqual(
|
||||
{f"[writer{i}]": set(range(max)) for i in range(nprocs)}, actual
|
||||
)
|
||||
self.assertTrue(tail.stopped())
|
||||
|
||||
def test_tail_with_custom_prefix(self):
|
||||
"""
|
||||
writer() writes 0 - max (one number on each line) to a log file.
|
||||
@ -131,6 +178,52 @@ class TailLogTest(unittest.TestCase):
|
||||
self.assertIn(f"[worker{i}][{i}]", headers)
|
||||
self.assertTrue(tail.stopped())
|
||||
|
||||
def test_tail_with_custom_filter(self):
|
||||
"""
|
||||
writer() writes 0 - max (one number on each line) to a log file.
|
||||
Run nprocs such writers and tail the log files into an IOString
|
||||
and validate that all lines are accounted for.
|
||||
"""
|
||||
nprocs = 3
|
||||
max = 20
|
||||
interval_sec = 0.0001
|
||||
|
||||
log_files = {
|
||||
local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
|
||||
for local_rank in range(nprocs)
|
||||
}
|
||||
|
||||
dst = io.StringIO()
|
||||
tail = TailLog(
|
||||
"writer",
|
||||
log_files,
|
||||
dst,
|
||||
interval_sec=interval_sec,
|
||||
log_line_filter=lambda line: "2" in line, # only print lines containing '2'
|
||||
).start()
|
||||
# sleep here is intentional to ensure that the log tail
|
||||
# can gracefully handle and wait for non-existent log files
|
||||
time.sleep(interval_sec * 10)
|
||||
futs = []
|
||||
for local_rank, file in log_files.items():
|
||||
f = self.threadpool.submit(
|
||||
write, max=max, sleep=interval_sec * local_rank, file=file
|
||||
)
|
||||
futs.append(f)
|
||||
wait(futs, return_when=ALL_COMPLETED)
|
||||
self.assertFalse(tail.stopped())
|
||||
tail.stop()
|
||||
dst.seek(0)
|
||||
|
||||
actual: dict[int, set[int]] = {}
|
||||
for line in dst.readlines():
|
||||
header, num = line.split(":")
|
||||
nums = actual.setdefault(header, set())
|
||||
nums.add(int(num))
|
||||
self.assertEqual(nprocs, len(actual))
|
||||
self.assertEqual({f"[writer{i}]": {2, 12} for i in range(nprocs)}, actual)
|
||||
self.assertTrue(tail.stopped())
|
||||
|
||||
def test_tail_no_files(self):
|
||||
"""
|
||||
Ensures that the log tail can gracefully handle no log files
|
||||
|
||||
@ -55,9 +55,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Remove environment variable if it exists to test default behavior
|
||||
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
||||
@ -84,8 +85,8 @@ class SignalHandlingTest(TestCase):
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
# Verify _stdout_tail.start() and _stderr_tail.start() were called
|
||||
mock_pcontext._stdout_tail.start.assert_called_once()
|
||||
mock_pcontext._stderr_tail.start.assert_called_once()
|
||||
mock_stdout_tail.start.assert_called_once()
|
||||
mock_stderr_tail.start.assert_called_once()
|
||||
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.threading")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.signal")
|
||||
@ -99,9 +100,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Set custom signals in the environment variable
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGUSR1,SIGUSR2"
|
||||
@ -139,9 +141,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Set invalid signals in the environment variable
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,INVALID_SIGNAL"
|
||||
@ -180,9 +183,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Set signals including ones not supported on Windows
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGHUP,SIGUSR1"
|
||||
@ -234,9 +238,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.current_thread.return_value = MagicMock() # Not the main thread
|
||||
mock_threading.main_thread.return_value = MagicMock()
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Call the start method
|
||||
PContext.start(mock_pcontext)
|
||||
@ -262,9 +267,10 @@ class SignalHandlingTest(TestCase):
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
# Mock the stdout_tail and stderr_tail
|
||||
mock_stdout_tail = MagicMock()
|
||||
mock_stderr_tail = MagicMock()
|
||||
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
|
||||
|
||||
# Set environment variable to include SIGUSR1 and SIGUSR2
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGUSR1,SIGUSR2"
|
||||
@ -323,8 +329,8 @@ class SignalHandlingTest(TestCase):
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
# Verify _stdout_tail.start() and _stderr_tail.start() were called
|
||||
mock_pcontext._stdout_tail.start.assert_called_once()
|
||||
mock_pcontext._stderr_tail.start.assert_called_once()
|
||||
mock_stdout_tail.start.assert_called_once()
|
||||
mock_stderr_tail.start.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
from torch.utils._debug_mode import DebugMode
|
||||
from torch.utils._debug_mode import _OpCall, _RedistributeCall, DebugMode
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
|
||||
@ -60,6 +60,10 @@ class TestDTensorDebugMode(TestCase):
|
||||
aten::sum(t: f32[1, 32])""",
|
||||
)
|
||||
|
||||
self.assertTrue(isinstance(debug_mode.operators[0], _OpCall))
|
||||
self.assertTrue(isinstance(debug_mode.operators[2], _RedistributeCall))
|
||||
self.assertEqual(next(iter(debug_mode.operators[1])), torch.ops.aten.mm.default)
|
||||
|
||||
def test_debug_string_inside_context(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
@ -330,6 +334,46 @@ class TestDTensorDebugMode(TestCase):
|
||||
f(x)
|
||||
self.assertEqual(len(debug_mode.debug_string()), 0)
|
||||
|
||||
def test_nn_module(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l1 = torch.nn.Linear(4, 4)
|
||||
self.l2 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.l2(self.l1(x))
|
||||
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.abc = Foo()
|
||||
self.xyz = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.xyz(self.abc(x))
|
||||
|
||||
mod = Bar()
|
||||
inp = torch.randn(4, 4)
|
||||
with DebugMode(record_nn_module=True) as debug_mode:
|
||||
_ = mod(inp)
|
||||
|
||||
self.assertExpectedInline(
|
||||
debug_mode.debug_string(),
|
||||
"""\
|
||||
[nn.Mod] Bar
|
||||
[nn.Mod] Bar.abc
|
||||
[nn.Mod] Bar.abc.l1
|
||||
aten::t(t: f32[4, 4])
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
|
||||
[nn.Mod] Bar.abc.l2
|
||||
aten::t(t: f32[4, 4])
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
|
||||
[nn.Mod] Bar.xyz
|
||||
aten::t(t: f32[4, 4])
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
|
||||
)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDTensorDebugMode)
|
||||
|
||||
|
||||
@ -6,7 +6,10 @@ import unittest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
||||
from torch._dynamo.functional_export import (
|
||||
_dynamo_graph_capture_for_export,
|
||||
dynamo_graph_capture_for_export,
|
||||
)
|
||||
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
|
||||
from torch._functorch.partitioners import min_cut_rematerialization_partition
|
||||
from torch._guards import tracing, TracingContext
|
||||
@ -96,6 +99,13 @@ def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
|
||||
return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
|
||||
|
||||
|
||||
def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
|
||||
gm = dynamo_graph_capture_for_export(model)(inputs)
|
||||
fake_mode = gm.meta.get("fake_mode", None)
|
||||
with tracing(TracingContext(fake_mode)):
|
||||
return aot_export_joint_with_descriptors_alone(gm, inputs)
|
||||
|
||||
|
||||
def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
|
||||
with torch._dynamo.config.patch(install_free_tensors=True):
|
||||
# TODO: switch to use the official graph_capture API once it is ready
|
||||
@ -288,6 +298,7 @@ class DTensorExportTest(TestCase):
|
||||
@parametrize(
|
||||
"export_fn",
|
||||
[
|
||||
graph_capture_and_aot_export_joint_with_descriptors_v2,
|
||||
graph_capture_and_aot_export_joint_with_descriptors,
|
||||
aot_export_joint_with_descriptors_alone,
|
||||
],
|
||||
@ -307,7 +318,21 @@ class DTensorExportTest(TestCase):
|
||||
def test_annotate_aot_export_joint_with_descriptors_alone(self):
|
||||
self._run_test(aot_export_joint_with_descriptors_alone, True)
|
||||
|
||||
def test_dynamic_shapes(self):
|
||||
@parametrize(
|
||||
"export_fn_with_answer",
|
||||
[
|
||||
(
|
||||
graph_capture_and_aot_export_joint_with_descriptors_v2,
|
||||
"[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
|
||||
),
|
||||
(
|
||||
graph_capture_and_aot_export_joint_with_descriptors,
|
||||
"[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_dynamic_shapes(self, export_fn_with_answer):
|
||||
export_fn, answer = export_fn_with_answer
|
||||
dp_degree = 2
|
||||
tp_degree = self.world_size // dp_degree
|
||||
|
||||
@ -331,7 +356,7 @@ class DTensorExportTest(TestCase):
|
||||
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
|
||||
torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
|
||||
|
||||
joint_gm = graph_capture_and_aot_export_joint_with_descriptors(tp_model, inputs)
|
||||
joint_gm = export_fn(tp_model, inputs)
|
||||
|
||||
res = []
|
||||
for node in joint_gm.graph.nodes:
|
||||
@ -341,12 +366,16 @@ class DTensorExportTest(TestCase):
|
||||
if isinstance(fake_val, torch._subclasses.fake_tensor.FakeTensor):
|
||||
res.append(list(fake_val.shape))
|
||||
|
||||
self.assertExpectedInline(
|
||||
str(res),
|
||||
"""[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
|
||||
)
|
||||
self.assertEqual(str(res), answer)
|
||||
|
||||
def test_einsum_dtensor_export(self):
|
||||
@parametrize(
|
||||
"export_fn",
|
||||
[
|
||||
dynamo_graph_capture_for_export,
|
||||
_dynamo_graph_capture_for_export,
|
||||
],
|
||||
)
|
||||
def test_einsum_dtensor_export(self, export_fn):
|
||||
"""Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
|
||||
world_size = 4
|
||||
# Create device mesh
|
||||
@ -366,9 +395,7 @@ class DTensorExportTest(TestCase):
|
||||
output = model(x_dtensor, y_dtensor, z_dtensor)
|
||||
with torch._dynamo.config.patch(install_free_tensors=True):
|
||||
# TODO: switch to use the official graph_capture API once it is ready
|
||||
gm = _dynamo_graph_capture_for_export(model)(
|
||||
x_dtensor, y_dtensor, z_dtensor
|
||||
)
|
||||
gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
|
||||
output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
|
||||
self.assertEqual(output, output_gm)
|
||||
|
||||
|
||||
@ -44,9 +44,22 @@ device_type = str(get_devtype())
|
||||
|
||||
def apply_reordering_and_get_graph(graph, out_li) -> None:
|
||||
gm = graph.owning_module
|
||||
from torch._inductor.config import aten_distributed_optimizations as dist_opts
|
||||
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
|
||||
|
||||
schedule_overlap_bucketing(gm)
|
||||
# Read config values, only pass non-None values to use function defaults
|
||||
kwargs: dict[str, object] = {}
|
||||
config_keys = (
|
||||
"collective_bucketing",
|
||||
"max_compute_pre_fetch",
|
||||
"custom_runtime_estimation",
|
||||
"insert_overlap_deps",
|
||||
)
|
||||
for key in config_keys:
|
||||
if (val := getattr(dist_opts, key)) is not None:
|
||||
kwargs[key] = val
|
||||
|
||||
schedule_overlap_bucketing(gm, **kwargs)
|
||||
gm.graph.lint()
|
||||
out_li.append(str(gm.graph))
|
||||
|
||||
@ -62,14 +75,14 @@ def run_and_get_aten_graph(fn, *inputs):
|
||||
|
||||
def get_patches():
|
||||
return {
|
||||
"test_configs.estimate_aten_runtime": estimate_aten_runtime,
|
||||
"aten_distributed_optimizations.custom_runtime_estimation": estimate_aten_runtime,
|
||||
"reorder_for_locality": False,
|
||||
"triton.native_matmul": False,
|
||||
"reorder_for_compute_comm_overlap_passes": [],
|
||||
"compile_threads": 1,
|
||||
"force_disable_caches": True,
|
||||
# Messes up existing test strings
|
||||
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
|
||||
"aten_distributed_optimizations.insert_overlap_deps": False,
|
||||
# interferes with testing, / custom estimation
|
||||
"test_configs.assume_bucketing_reduces_latency": False,
|
||||
}
|
||||
@ -351,21 +364,56 @@ graph():
|
||||
# these have no overlap opportunities
|
||||
self.assertEqual(counters["inductor"]["overlap_scheduling_bad_exposed"], 0)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_overlap_scheduling_via_config(self):
|
||||
"""Test overlap scheduling enabled via config in post_grad pass."""
|
||||
|
||||
def func(a):
|
||||
ar = _functional_collectives.all_reduce(a, "sum", "0")
|
||||
b = torch.matmul(a, a)
|
||||
return torch.matmul(ar, b)
|
||||
|
||||
patches = {
|
||||
**get_patches(),
|
||||
"aten_distributed_optimizations.enable_overlap_scheduling": True,
|
||||
}
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
|
||||
with torch._inductor.config.patch(patches):
|
||||
compiled_func = torch.compile(func)
|
||||
out, code = run_and_get_code(compiled_func, inputs)
|
||||
|
||||
# Verify that wait_tensor is sinked below matmul
|
||||
FileCheck().check("all_reduce").check("mm").check("wait_tensor").check(
|
||||
"mm"
|
||||
).run(code[0])
|
||||
|
||||
correct = func(inputs)
|
||||
self.assertTrue(same(out, correct))
|
||||
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
|
||||
|
||||
|
||||
def get_bucket_patches(compute_multiplier=1.0):
|
||||
estimate_aten_runtime_part = functools.partial(
|
||||
estimate_aten_runtime, compute_multiplier=compute_multiplier
|
||||
)
|
||||
return {
|
||||
"test_configs.estimate_aten_runtime": estimate_aten_runtime_part,
|
||||
"test_configs.aten_fx_overlap_preserving_bucketing": True,
|
||||
"aten_distributed_optimizations.custom_runtime_estimation": estimate_aten_runtime_part,
|
||||
"aten_distributed_optimizations.collective_bucketing": True,
|
||||
"reorder_for_locality": False,
|
||||
"triton.native_matmul": False,
|
||||
"reorder_for_compute_comm_overlap_passes": [],
|
||||
"compile_threads": 1,
|
||||
"force_disable_caches": True,
|
||||
# messes up test strings
|
||||
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
|
||||
"aten_distributed_optimizations.insert_overlap_deps": False,
|
||||
# interferes with testing, / custom estimation
|
||||
"test_configs.assume_bucketing_reduces_latency": False,
|
||||
}
|
||||
@ -806,7 +854,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
),
|
||||
torch._inductor.config.patch(
|
||||
"test_configs.aten_fx_overlap_insert_overlap_deps", True
|
||||
"aten_distributed_optimizations.insert_overlap_deps", True
|
||||
),
|
||||
torch._inductor.config.patch(post_grad_custom_post_pass=apply),
|
||||
):
|
||||
|
||||
@ -471,6 +471,67 @@ from user code:
|
||||
assert hasattr(backend_result.compiled_fn, "serialize")
|
||||
self.assertIsNotNone(backend_result.compiled_fn.serialize)
|
||||
|
||||
def test_fullgraph_capture_with_pytree_module(self):
|
||||
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(3, 3)
|
||||
self.linear1 = torch.nn.Linear(3, 3)
|
||||
self.linear2 = torch.nn.Linear(3, 3)
|
||||
self.linear3 = torch.nn.Linear(3, 3)
|
||||
|
||||
def forward(self, x):
|
||||
return {
|
||||
"y": self.linear2(x[2] + 1),
|
||||
"z": self.linear3(x[1] - 1),
|
||||
"w": self.linear(x[0]["b"] + 2),
|
||||
"v": self.linear1(x[0]["a"] - 2),
|
||||
}
|
||||
|
||||
mod = Module()
|
||||
compiled_mod = dynamo_graph_capture_for_export(mod)(
|
||||
(
|
||||
{"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
|
||||
torch.randn(3, 3),
|
||||
torch.randn(3, 3),
|
||||
)
|
||||
)
|
||||
|
||||
inputs = (
|
||||
{"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
|
||||
torch.randn(3, 3),
|
||||
torch.randn(3, 3),
|
||||
)
|
||||
self.assertEqual(compiled_mod(inputs), mod(inputs))
|
||||
|
||||
def test_fullgraph_capture_with_pytree_func(self):
|
||||
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||
|
||||
def foo(x):
|
||||
return {
|
||||
"y": x[2] + 1,
|
||||
"z": x[1] - 1,
|
||||
"w": x[0]["b"] + 2,
|
||||
"v": x[0]["a"] - 2,
|
||||
}
|
||||
|
||||
compiled_foo = dynamo_graph_capture_for_export(foo)(
|
||||
(
|
||||
{"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
|
||||
torch.randn(2, 3),
|
||||
torch.randn(3, 4),
|
||||
)
|
||||
)
|
||||
|
||||
inputs = (
|
||||
{"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
|
||||
torch.randn(2, 3),
|
||||
torch.randn(3, 4),
|
||||
)
|
||||
self.assertEqual(compiled_foo(inputs), foo(inputs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -2302,30 +2302,27 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
return augment(x)
|
||||
|
||||
# # This is to test the new syntax for pattern matching
|
||||
# # ("match ... case ...") added on python 3.10.
|
||||
# # Uncomment these test cases if you run on 3.10+
|
||||
# @make_test
|
||||
# def test_match_sequence(a):
|
||||
# point = (5, 8)
|
||||
# match point:
|
||||
# case (0, 0):
|
||||
# return a
|
||||
# case (0, y):
|
||||
# return a - y
|
||||
# case (x, 0):
|
||||
# return a + x
|
||||
# case (x, y):
|
||||
# return a + x - y
|
||||
@make_test
|
||||
def test_match_sequence(a):
|
||||
point = (5, 8)
|
||||
match point:
|
||||
case (0, 0):
|
||||
return a
|
||||
case (0, y):
|
||||
return a - y
|
||||
case (x, 0):
|
||||
return a + x
|
||||
case (x, y):
|
||||
return a + x - y
|
||||
|
||||
# @make_test
|
||||
# def test_match_mapping_and_match_keys(x):
|
||||
# param = {"a": 0.5}
|
||||
# match param:
|
||||
# case {"a": param}:
|
||||
# return x * param
|
||||
# case {"b": param}:
|
||||
# return x / param
|
||||
@make_test
|
||||
def test_match_mapping_and_match_keys(x):
|
||||
param = {"a": 0.5}
|
||||
match param:
|
||||
case {"a": param}:
|
||||
return x * param
|
||||
case {"b": param}:
|
||||
return x / param
|
||||
|
||||
def test_math_radians(self):
|
||||
def func(x, a):
|
||||
|
||||
@ -288,6 +288,18 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
|
||||
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_graph_break(self):
|
||||
def fn(x):
|
||||
with torch.fx.traceback.annotate({"pp_stage": 0}):
|
||||
x = torch.sin(x)
|
||||
torch._dynamo.graph_break()
|
||||
x = torch.cos(x)
|
||||
return x
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -1,17 +1,24 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import functools
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._inductor.test_case
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._inductor.test_case import run_tests
|
||||
from torch._inductor.utils import run_fw_bw_and_get_code
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
from torch.fx.passes.regional_inductor import regional_inductor
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
|
||||
|
||||
@ -36,7 +43,29 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
# f) disallow nested regional compile
|
||||
|
||||
|
||||
def aot_eager_regional_inductor():
|
||||
def aot_eager_regional_inductor(serialize=False):
|
||||
if serialize:
|
||||
|
||||
def regional_inductor_pickle(gm, *example_args):
|
||||
result = regional_inductor(gm, *example_args)
|
||||
serialized = GraphPickler.dumps(result)
|
||||
|
||||
fake_mode = detect_fake_mode(example_args)
|
||||
assert fake_mode is not None
|
||||
# Serialize and deserialize the result to confirm pickling works
|
||||
# Use a fresh tracing context on the new process
|
||||
context = torch._guards.TracingContext(fake_mode)
|
||||
with torch._guards.tracing(context):
|
||||
result = GraphPickler.loads(serialized, fake_mode)
|
||||
assert isinstance(result, torch.fx.GraphModule)
|
||||
result.recompile()
|
||||
return result
|
||||
|
||||
return aot_autograd(
|
||||
fw_compiler=regional_inductor_pickle,
|
||||
bw_compiler=regional_inductor_pickle,
|
||||
)
|
||||
|
||||
return aot_autograd(
|
||||
fw_compiler=regional_inductor,
|
||||
bw_compiler=regional_inductor,
|
||||
@ -44,8 +73,10 @@ def aot_eager_regional_inductor():
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
|
||||
@instantiate_parametrized_tests
|
||||
class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
def test_simple(self):
|
||||
@parametrize("serialize", [False, True])
|
||||
def test_simple(self, serialize):
|
||||
def fn(x, y):
|
||||
sin = torch.sin(x)
|
||||
|
||||
@ -56,7 +87,7 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
return torch.sin(add)
|
||||
|
||||
opt_fn = torch.compile(
|
||||
fn, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
fn, backend=aot_eager_regional_inductor(serialize=serialize), fullgraph=True
|
||||
)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
y = torch.randn(10, requires_grad=True)
|
||||
@ -65,7 +96,8 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x, y))
|
||||
self.assertEqual(len(codes), 2)
|
||||
|
||||
def test_repeated_blocks(self):
|
||||
@parametrize("serialize", [False, True])
|
||||
def test_repeated_blocks(self, serialize):
|
||||
def fn(x, y):
|
||||
sin = torch.sin(x)
|
||||
|
||||
@ -86,7 +118,9 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
mod = Mod()
|
||||
|
||||
opt_mod = torch.compile(
|
||||
mod, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
mod,
|
||||
backend=aot_eager_regional_inductor(serialize=serialize),
|
||||
fullgraph=True,
|
||||
)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
y = torch.randn(10, requires_grad=True)
|
||||
@ -96,7 +130,8 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
_, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
|
||||
self.assertEqual(len(codes), 4)
|
||||
|
||||
def test_invoke_subgraph(self):
|
||||
@parametrize("serialize", [False, True])
|
||||
def test_invoke_subgraph(self, serialize):
|
||||
# Checks that get_attr nodes custom metadata is propagated
|
||||
@torch.compiler.nested_compile_region
|
||||
def gn(x):
|
||||
@ -109,15 +144,21 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
return torch.sigmoid(z)
|
||||
|
||||
opt_fn = torch.compile(
|
||||
fn, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
fn, backend=aot_eager_regional_inductor(serialize=serialize), fullgraph=True
|
||||
)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
|
||||
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
|
||||
self.assertEqual(len(codes), 2)
|
||||
|
||||
def test_invoke_subgraph_inner(self):
|
||||
@parametrize("serialize", [False, True])
|
||||
def test_invoke_subgraph_inner(self, serialize):
|
||||
# Checks that the inductor regions are searched recursively.
|
||||
|
||||
# TODO: GraphPickler does not recompile nested subgraphs?
|
||||
if serialize:
|
||||
raise unittest.SkipTest("GraphPickler doesn't recompile nested subgraphs")
|
||||
|
||||
@torch.compiler.nested_compile_region
|
||||
def gn(x):
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
@ -131,7 +172,7 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
return torch.sigmoid(x)
|
||||
|
||||
opt_fn = torch.compile(
|
||||
fn, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
fn, backend=aot_eager_regional_inductor(serialize=serialize), fullgraph=True
|
||||
)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
|
||||
@ -141,7 +182,14 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
self.assertEqual(len(codes), 2)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_flex_attention(self):
|
||||
@parametrize("serialize", [False, True])
|
||||
def test_flex_attention(self, serialize):
|
||||
if serialize:
|
||||
# TODO: Fixed in next PR
|
||||
raise unittest.SkipTest(
|
||||
"FlexAttentionBackward isn't marked cacheable even though it is"
|
||||
)
|
||||
|
||||
def _squared(score, b, h, m, n):
|
||||
return score * score
|
||||
|
||||
@ -170,7 +218,7 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
|
||||
opt_fn = torch.compile(
|
||||
fn,
|
||||
backend=aot_eager_regional_inductor(),
|
||||
backend=aot_eager_regional_inductor(serialize),
|
||||
fullgraph=True,
|
||||
)
|
||||
|
||||
@ -179,7 +227,13 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
self.assertEqual(len(codes), 2)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_selective_ac_flex(self):
|
||||
@parametrize("serialize", [False, True])
|
||||
def test_selective_ac_flex(self, serialize):
|
||||
if serialize:
|
||||
raise unittest.SkipTest(
|
||||
"FlexAttentionBackward isn't marked cacheable even though it is"
|
||||
)
|
||||
|
||||
class FlexAttentionModule(torch.nn.Module):
|
||||
def __init__(self, hidden_size, num_heads):
|
||||
super().__init__()
|
||||
|
||||
@ -8101,14 +8101,6 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
||||
res = gm(x, y)
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
def test_current_accelerator(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(x):
|
||||
torch.accelerator.current_accelerator()
|
||||
return x + 1
|
||||
|
||||
self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ReproTests)
|
||||
|
||||
|
||||
@ -402,6 +402,43 @@ def forward(self, x):
|
||||
|
||||
self.assertEqual(res_export, res_eager)
|
||||
|
||||
def test_dynamo_graph_capture(self):
|
||||
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, dct, lst, bleh):
|
||||
x = dct["a"] * lst[1][0]
|
||||
y = dct["b"] * lst[0]
|
||||
out_dict = {}
|
||||
|
||||
# Mutate and get a new entry in there
|
||||
lst_copy = lst.copy()
|
||||
lst_copy.append(lst[0])
|
||||
out_dict["a"] = x
|
||||
out_dict["b"] = y
|
||||
return (
|
||||
dct["a"],
|
||||
out_dict["b"],
|
||||
bleh,
|
||||
lst_copy[-1],
|
||||
out_dict["a"],
|
||||
[5, 6],
|
||||
)
|
||||
|
||||
foo = Foo()
|
||||
|
||||
def make_inputs():
|
||||
return (
|
||||
{"a": torch.randn(2, 3), "b": torch.randn(2, 3)},
|
||||
[torch.randn(2, 3), (torch.randn(2, 3),)],
|
||||
torch.randn(2, 3),
|
||||
)
|
||||
|
||||
trace_inputs = make_inputs()
|
||||
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
|
||||
test_inputs = make_inputs()
|
||||
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
Some files were not shown because too many files have changed in this diff.