mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-28 02:04:53 +08:00
Compare commits
99 Commits
csl/workfl
...
ciflow/tru
| Author | SHA1 | Date | |
|---|---|---|---|
| ace9adfaf0 | |||
| f89a7e9fe8 | |||
| f2c81635c8 | |||
| e214af6ae8 | |||
| 7ce723d21c | |||
| 4295a9a158 | |||
| 90d7be35e9 | |||
| 8d4e48831e | |||
| 90b30ebf7e | |||
| 173bcda436 | |||
| 6530bc70fb | |||
| 4c38887346 | |||
| 81fa4a204c | |||
| 4e6afa8c07 | |||
| 79aa88cc5d | |||
| fa4cb91846 | |||
| c58d0ad85d | |||
| 000f49551b | |||
| 9940e894ea | |||
| 27302a4932 | |||
| 507614ba43 | |||
| 86f9f1d0ab | |||
| 154e4d36e9 | |||
| a2b6afeac5 | |||
| 262830d86c | |||
| e4c01011c2 | |||
| a60d9e1f6d | |||
| f863550192 | |||
| 84b14f3a10 | |||
| 5121499f6b | |||
| 8f80892359 | |||
| cdb60e44eb | |||
| 25909d2629 | |||
| c7eee49525 | |||
| 621ba05107 | |||
| 39a70cead1 | |||
| d97f6550a2 | |||
| 516e58965a | |||
| b55b779ad3 | |||
| 74e53d0761 | |||
| 798a6d2be1 | |||
| b0e9c86971 | |||
| 661a56002f | |||
| c9bc00f016 | |||
| ec51b139e1 | |||
| eb83c3ca23 | |||
| 7924e3aacf | |||
| 78bcfcf870 | |||
| 1e2e7cb18b | |||
| 003601a70d | |||
| 1d58d5fe25 | |||
| de7fdfe41a | |||
| b31bad1b8f | |||
| 2efcf3ca98 | |||
| 761f946043 | |||
| 8aa465f18e | |||
| 0a5d68d92d | |||
| 42bd210fff | |||
| 1d13c314b3 | |||
| 0c9763a5a0 | |||
| 79a4a9c02e | |||
| 9d0b77f4cd | |||
| d486eee234 | |||
| cddd5f74ab | |||
| dfdb68e51f | |||
| 98c818320a | |||
| cc20b7ad72 | |||
| bc11a42b3f | |||
| 4fc06f2e0a | |||
| 82473c3d59 | |||
| b6a4236e5d | |||
| b04173be9b | |||
| 32ac38f85d | |||
| c9b49e506e | |||
| 6038e476e8 | |||
| 2c851c16e5 | |||
| 31584f2d91 | |||
| 0442125362 | |||
| fdcf402d82 | |||
| 13cda9b89e | |||
| fa6d911dda | |||
| 0db6bcc015 | |||
| 60ac039998 | |||
| 380d440d1c | |||
| 9038a30cee | |||
| 690c8c13b9 | |||
| 28ee6b62ed | |||
| 81577bdb3f | |||
| e67e3d95f3 | |||
| 27af8480ea | |||
| 6494cdc40c | |||
| ac7074efa2 | |||
| 263901cec4 | |||
| c12293dcbe | |||
| 5a4997dcae | |||
| 47f638eae7 | |||
| 882b834082 | |||
| b146ea411e | |||
| 8625ffbd45 |
354
.claude/skills/pytorch-docstring.md
Normal file
354
.claude/skills/pytorch-docstring.md
Normal file
@ -0,0 +1,354 @@
|
||||
# PyTorch Docstring Writing Guide
|
||||
|
||||
This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.
|
||||
|
||||
## General Principles
|
||||
|
||||
- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
|
||||
- Follow **Sphinx/reStructuredText** (reST) format for documentation
|
||||
- Be **concise but complete** - include all essential information
|
||||
- Always include **examples** when possible
|
||||
- Use **cross-references** to related functions/classes
|
||||
|
||||
## Docstring Structure
|
||||
|
||||
### 1. Function Signature (First Line)
|
||||
|
||||
Start with the function signature showing all parameters:
|
||||
|
||||
```python
|
||||
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
|
||||
```
|
||||
|
||||
**Notes:**
|
||||
- Include the function name
|
||||
- Show positional and keyword-only arguments (use `*` separator)
|
||||
- Include default values
|
||||
- Show return type annotation
|
||||
- This line should NOT end with a period
|
||||
|
||||
### 2. Brief Description
|
||||
|
||||
Provide a one-line description of what the function does:
|
||||
|
||||
```python
|
||||
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
|
||||
|
||||
Applies a 2D convolution over an input image composed of several input
|
||||
planes.
|
||||
```
|
||||
|
||||
### 3. Mathematical Formulas (if applicable)
|
||||
|
||||
Use Sphinx math directives for mathematical expressions:
|
||||
|
||||
```python
|
||||
.. math::
|
||||
\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
|
||||
```
|
||||
|
||||
Or inline math: `:math:\`x^2\``
|
||||
|
||||
### 4. Cross-References
|
||||
|
||||
Link to related classes and functions using Sphinx roles:
|
||||
|
||||
- `:class:\`~torch.nn.ModuleName\`` - Link to a class
|
||||
- `:func:\`torch.function_name\`` - Link to a function
|
||||
- `:meth:\`~Tensor.method_name\`` - Link to a method
|
||||
- `:attr:\`attribute_name\`` - Reference an attribute
|
||||
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
See :class:`~torch.nn.Conv2d` for details and output shape.
|
||||
```
|
||||
|
||||
### 5. Notes and Warnings
|
||||
|
||||
Use admonitions for important information:
|
||||
|
||||
```python
|
||||
.. note::
|
||||
This function doesn't work directly with NLLLoss,
|
||||
which expects the Log to be computed between the Softmax and itself.
|
||||
Use log_softmax instead (it's faster and has better numerical properties).
|
||||
|
||||
.. warning::
|
||||
:func:`new_tensor` always copies :attr:`data`. If you have a Tensor
|
||||
``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
|
||||
or :func:`torch.Tensor.detach`.
|
||||
```
|
||||
|
||||
### 6. Args Section
|
||||
|
||||
Document all parameters with type annotations and descriptions:
|
||||
|
||||
```python
|
||||
Args:
|
||||
input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
|
||||
weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)`
|
||||
bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
|
||||
stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
|
||||
tuple `(sH, sW)`. Default: 1
|
||||
```
|
||||
|
||||
**Formatting rules:**
|
||||
- Parameter name in **lowercase**
|
||||
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
|
||||
- Description follows the type
|
||||
- For optional parameters, include "Default: ``value``" at the end
|
||||
- Use double backticks for inline code: ``` ``None`` ```
|
||||
- Indent continuation lines by 2 spaces
|
||||
|
||||
### 7. Keyword Args Section (if applicable)
|
||||
|
||||
Sometimes keyword arguments are documented separately:
|
||||
|
||||
```python
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
|
||||
Default: if None, same :class:`torch.dtype` as this tensor.
|
||||
device (:class:`torch.device`, optional): the desired device of returned tensor.
|
||||
Default: if None, same :class:`torch.device` as this tensor.
|
||||
requires_grad (bool, optional): If autograd should record operations on the
|
||||
returned tensor. Default: ``False``.
|
||||
```
|
||||
|
||||
### 8. Returns Section (if needed)
|
||||
|
||||
Document the return value:
|
||||
|
||||
```python
|
||||
Returns:
|
||||
Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
|
||||
If ``hard=True``, the returned samples will be one-hot, otherwise they will
|
||||
be probability distributions that sum to 1 across `dim`.
|
||||
```
|
||||
|
||||
Or simply include it in the function signature line if obvious from context.
|
||||
|
||||
### 9. Examples Section
|
||||
|
||||
Always include examples when possible:
|
||||
|
||||
```python
|
||||
Examples::
|
||||
|
||||
>>> inputs = torch.randn(33, 16, 30)
|
||||
>>> filters = torch.randn(20, 16, 5)
|
||||
>>> F.conv1d(inputs, filters)
|
||||
|
||||
>>> # With square kernels and equal stride
|
||||
>>> filters = torch.randn(8, 4, 3, 3)
|
||||
>>> inputs = torch.randn(1, 4, 5, 5)
|
||||
>>> F.conv2d(inputs, filters, padding=1)
|
||||
```
|
||||
|
||||
**Formatting rules:**
|
||||
- Use `Examples::` with double colon
|
||||
- Use `>>>` prompt for Python code
|
||||
- Include comments with `#` when helpful
|
||||
- Show actual output when it helps understanding (indent without `>>>`)
|
||||
|
||||
### 10. External References
|
||||
|
||||
Link to papers or external documentation:
|
||||
|
||||
```python
|
||||
.. _Link Name:
|
||||
https://arxiv.org/abs/1611.00712
|
||||
```
|
||||
|
||||
Reference them in text: ```See `Link Name`_```
|
||||
|
||||
## Method Types
|
||||
|
||||
### Native Python Functions
|
||||
|
||||
For regular Python functions, use a standard docstring:
|
||||
|
||||
```python
|
||||
def relu(input: Tensor, inplace: bool = False) -> Tensor:
|
||||
r"""relu(input, inplace=False) -> Tensor
|
||||
|
||||
Applies the rectified linear unit function element-wise. See
|
||||
:class:`~torch.nn.ReLU` for more details.
|
||||
"""
|
||||
# implementation
|
||||
```
|
||||
|
||||
### C-Bound Functions (using add_docstr)
|
||||
|
||||
For C-bound functions, use `_add_docstr`:
|
||||
|
||||
```python
|
||||
conv1d = _add_docstr(
|
||||
torch.conv1d,
|
||||
r"""
|
||||
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
|
||||
|
||||
Applies a 1D convolution over an input signal composed of several input
|
||||
planes.
|
||||
|
||||
See :class:`~torch.nn.Conv1d` for details and output shape.
|
||||
|
||||
Args:
|
||||
input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
|
||||
weight: filters of shape :math:`(\text{out\_channels} , kW)`
|
||||
...
|
||||
""",
|
||||
)
|
||||
```
|
||||
|
||||
### In-Place Variants
|
||||
|
||||
For in-place operations (ending with `_`), reference the original:
|
||||
|
||||
```python
|
||||
add_docstr_all(
|
||||
"abs_",
|
||||
r"""
|
||||
abs_() -> Tensor
|
||||
|
||||
In-place version of :meth:`~Tensor.abs`
|
||||
""",
|
||||
)
|
||||
```
|
||||
|
||||
### Alias Functions
|
||||
|
||||
For aliases, simply reference the original:
|
||||
|
||||
```python
|
||||
add_docstr_all(
|
||||
"absolute",
|
||||
r"""
|
||||
absolute() -> Tensor
|
||||
|
||||
Alias for :func:`abs`
|
||||
""",
|
||||
)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Shape Documentation
|
||||
|
||||
Use LaTeX math notation for tensor shapes:
|
||||
|
||||
```python
|
||||
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
|
||||
```
|
||||
|
||||
### Reusable Argument Definitions
|
||||
|
||||
For commonly used arguments, define them once and reuse:
|
||||
|
||||
```python
|
||||
common_args = parse_kwargs(
|
||||
"""
|
||||
dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
|
||||
Default: if None, same as this tensor.
|
||||
"""
|
||||
)
|
||||
|
||||
# Then use with .format():
|
||||
r"""
|
||||
...
|
||||
|
||||
Keyword args:
|
||||
{dtype}
|
||||
{device}
|
||||
""".format(**common_args)
|
||||
```
|
||||
|
||||
### Template Insertion
|
||||
|
||||
Insert reproducibility notes or other common text:
|
||||
|
||||
```python
|
||||
r"""
|
||||
{tf32_note}
|
||||
|
||||
{cudnn_reproducibility_note}
|
||||
""".format(**reproducibility_notes, **tf32_notes)
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
Here's a complete example showing all elements:
|
||||
|
||||
```python
|
||||
def gumbel_softmax(
|
||||
logits: Tensor,
|
||||
tau: float = 1,
|
||||
hard: bool = False,
|
||||
eps: float = 1e-10,
|
||||
dim: int = -1,
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Sample from the Gumbel-Softmax distribution and optionally discretize.
|
||||
|
||||
Args:
|
||||
logits (Tensor): `[..., num_features]` unnormalized log probabilities
|
||||
tau (float): non-negative scalar temperature
|
||||
hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
|
||||
but will be differentiated as if it is the soft sample in autograd. Default: ``False``
|
||||
dim (int): A dimension along which softmax will be computed. Default: -1
|
||||
|
||||
Returns:
|
||||
Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
|
||||
If ``hard=True``, the returned samples will be one-hot, otherwise they will
|
||||
be probability distributions that sum to 1 across `dim`.
|
||||
|
||||
.. note::
|
||||
This function is here for legacy reasons, may be removed from nn.Functional in the future.
|
||||
|
||||
Examples::
|
||||
>>> logits = torch.randn(20, 32)
|
||||
>>> # Sample soft categorical using reparametrization trick:
|
||||
>>> F.gumbel_softmax(logits, tau=1, hard=False)
|
||||
>>> # Sample hard categorical using "Straight-through" trick:
|
||||
>>> F.gumbel_softmax(logits, tau=1, hard=True)
|
||||
|
||||
.. _Link 1:
|
||||
https://arxiv.org/abs/1611.00712
|
||||
"""
|
||||
# implementation
|
||||
```
|
||||
|
||||
## Quick Checklist
|
||||
|
||||
When writing a PyTorch docstring, ensure:
|
||||
|
||||
- [ ] Use raw string (`r"""`)
|
||||
- [ ] Include function signature on first line
|
||||
- [ ] Provide brief description
|
||||
- [ ] Document all parameters in Args section with types
|
||||
- [ ] Include default values for optional parameters
|
||||
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
|
||||
- [ ] Add mathematical formulas if applicable
|
||||
- [ ] Include at least one example in Examples section
|
||||
- [ ] Add warnings/notes for important caveats
|
||||
- [ ] Link to related module class with `:class:`
|
||||
- [ ] Use proper math notation for tensor shapes
|
||||
- [ ] Follow consistent formatting and indentation
|
||||
|
||||
## Common Sphinx Roles Reference
|
||||
|
||||
- `:class:\`~torch.nn.Module\`` - Class reference
|
||||
- `:func:\`torch.function\`` - Function reference
|
||||
- `:meth:\`~Tensor.method\`` - Method reference
|
||||
- `:attr:\`attribute\`` - Attribute reference
|
||||
- `:math:\`equation\`` - Inline math
|
||||
- `:ref:\`label\`` - Internal reference
|
||||
- ``` ``code`` ``` - Inline code (use double backticks)
|
||||
|
||||
## Additional Notes
|
||||
|
||||
- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
|
||||
- **Line length**: Try to keep lines under 100 characters when possible
|
||||
- **Periods**: End sentences with periods, but not the signature line
|
||||
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
|
||||
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.
|
||||
7
.github/actions/setup-rocm/action.yml
vendored
7
.github/actions/setup-rocm/action.yml
vendored
@ -124,3 +124,10 @@ runs:
|
||||
id: login-ecr
|
||||
continue-on-error: true
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
|
||||
- name: Preserve github env variables for use in docker
|
||||
shell: bash
|
||||
run: |
|
||||
env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
|
||||
env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
|
||||
env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
|
||||
|
||||
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
0fa6e3129e61143224663e1ec67980d12b7ec4eb
|
||||
df6798dfb931ce7c7fe5bed2447cd1092a5981af
|
||||
|
||||
5
.github/ci_configs/vllm/Dockerfile
vendored
5
.github/ci_configs/vllm/Dockerfile
vendored
@ -283,6 +283,9 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
|
||||
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
|
||||
fi
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --pre apache-tvm-ffi==0.1.0b15
|
||||
|
||||
# Install the vllm wheel from previous stage
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system /wheels/vllm/*.whl --verbose
|
||||
@ -295,6 +298,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
|
||||
# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
|
||||
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.14.post1"
|
||||
|
||||
|
||||
9
.github/label_to_label.yml
vendored
9
.github/label_to_label.yml
vendored
@ -15,6 +15,11 @@
|
||||
- "module: reinplacing"
|
||||
then:
|
||||
- "module: pt2-dispatcher"
|
||||
- any:
|
||||
- "vllm-compile"
|
||||
then:
|
||||
- "module: vllm"
|
||||
- "oncall: pt2"
|
||||
- any:
|
||||
- "module: vmap"
|
||||
then:
|
||||
@ -27,10 +32,6 @@
|
||||
- "module: pt2 optimizer"
|
||||
then:
|
||||
- "module: dynamo"
|
||||
- any:
|
||||
- "module: flex attention"
|
||||
then:
|
||||
- "module: higher order operators"
|
||||
- any:
|
||||
- "module: aotinductor"
|
||||
then:
|
||||
|
||||
@ -833,8 +833,7 @@ exclude_patterns = [
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/grep_linter.py',
|
||||
'--pattern=cudaSetDevice(',
|
||||
'--pattern=cudaGetDevice(',
|
||||
'--pattern=(cudaSetDevice|cudaGetDevice)\\(',
|
||||
'--linter-name=RAWCUDADEVICE',
|
||||
'--error-name=raw CUDA API usage',
|
||||
"""--error-description=\
|
||||
|
||||
@ -38,7 +38,7 @@ set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT)
|
||||
|
||||
configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
|
||||
# TODO: Do not generate CUDAConfig.h for ROCm BUILDS
|
||||
# At the moment, `jit_macors.h` include CUDAConfig.h for both CUDA and HIP builds
|
||||
# At the moment, `jit_macros.h` include CUDAConfig.h for both CUDA and HIP builds
|
||||
if(USE_CUDA OR USE_ROCM)
|
||||
configure_file(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h")
|
||||
endif()
|
||||
|
||||
@ -122,7 +122,7 @@ void FunctionalTensorWrapper::freeze_storage() const {
|
||||
// | have their own storages, but backends like functorch |
|
||||
// \/ are allowed to re-alias underneath the pass \/
|
||||
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
|
||||
// | underyling_storage | | underyling_storage |
|
||||
// | underlying_storage | | underlying_storage |
|
||||
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
|
||||
//
|
||||
// This constructor is only used by view ops.
|
||||
|
||||
@ -1534,7 +1534,7 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
|
||||
|
||||
// XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
|
||||
// Nothing beyond this point is important for meta functions, so it's fine to exit early here.
|
||||
// Extend the condition to MAIA tesnors as MAIA tensors also don't have storage.
|
||||
// Extend the condition to MAIA tensors as MAIA tensors also don't have storage.
|
||||
if (privateuse1_without_storage ||
|
||||
common_device_.type() == DeviceType::XLA ||
|
||||
common_device_.type() == DeviceType::IPU ||
|
||||
|
||||
@ -94,11 +94,11 @@ struct PinnedReserveSegment {
|
||||
struct TORCH_API HostStats {
|
||||
// COUNT: total allocations (active)
|
||||
Stat active_requests;
|
||||
// SUM: bytes allocated/reserved by this memory alocator. (active)
|
||||
// SUM: bytes allocated/reserved by this memory allocator. (active)
|
||||
Stat active_bytes;
|
||||
// COUNT: total allocations (active + free)
|
||||
Stat allocations;
|
||||
// SUM: bytes allocated/reserved by this memory alocator. This accounts
|
||||
// SUM: bytes allocated/reserved by this memory allocator. This accounts
|
||||
// for both free and in-use blocks.
|
||||
Stat allocated_bytes;
|
||||
|
||||
@ -127,7 +127,7 @@ struct alignas(hardware_destructive_interference_size) HostStatsStaged {
|
||||
// COUNT: total allocations (active + free)
|
||||
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
|
||||
Stat allocations;
|
||||
// SUM: bytes allocated/reserved by this memory alocator. This accounts
|
||||
// SUM: bytes allocated/reserved by this memory allocator. This accounts
|
||||
// for both free and in-use blocks.
|
||||
Stat allocated_bytes;
|
||||
// COUNT: number of allocations per bucket (active)
|
||||
@ -455,7 +455,7 @@ struct CachingHostAllocatorImpl {
|
||||
}
|
||||
|
||||
void resetAccumulatedStats() {
|
||||
// Reseting accumulated memory stats requires concurrently holding both the
|
||||
// Resetting accumulated memory stats requires concurrently holding both the
|
||||
// free list mutexes and the blocks mutex. Previously, this was only done in
|
||||
// empty_cache function.
|
||||
for (size_t i = 0; i < free_list_.size(); ++i) {
|
||||
@ -482,7 +482,7 @@ struct CachingHostAllocatorImpl {
|
||||
}
|
||||
|
||||
void resetPeakStats() {
|
||||
// Reseting peak memory stats requires concurrently holding both the
|
||||
// Resetting peak memory stats requires concurrently holding both the
|
||||
// free list mutexes and the blocks mutex. Previously, this was only done in
|
||||
// empty_cache function.
|
||||
for (size_t i = 0; i < free_list_.size(); ++i) {
|
||||
|
||||
@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
|
||||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
|
||||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
#undef AUTOGRAD_FALLBACK
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -148,7 +148,7 @@ struct TORCH_API ClassType : public NamedType {
|
||||
|
||||
void checkNotExist(const std::string& name, const std::string& what) const;
|
||||
|
||||
// Attributes are stored in a specific slot at runtime for effiency.
|
||||
// Attributes are stored in a specific slot at runtime for efficiency.
|
||||
// When emitting instructions we specify the slot so that attribute access is
|
||||
// a constant lookup
|
||||
std::optional<size_t> findAttributeSlot(const std::string& name) const {
|
||||
@ -412,7 +412,7 @@ struct TORCH_API ClassType : public NamedType {
|
||||
// Holds method attributes
|
||||
std::weak_ptr<CompilationUnit> compilation_unit_;
|
||||
|
||||
// Holds all atrributes, attribute details are found on ClassAttribute
|
||||
// Holds all attributes, attribute details are found on ClassAttribute
|
||||
std::vector<ClassAttribute> attributes_;
|
||||
// Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef.
|
||||
// Never fill this without using the appropriate provideNewClassAttribute method
|
||||
|
||||
@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
|
||||
|
||||
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
|
||||
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
|
||||
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
|
||||
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
|
||||
TORCH_CHECK(
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
|
||||
backendFallbackKernels_[idx].debug, ", new registration ", debug
|
||||
);
|
||||
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ",
|
||||
dispatchKey,
|
||||
"; previous registration ",
|
||||
backendFallbackKernels_[idx].debug,
|
||||
", new registration ",
|
||||
debug);
|
||||
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
|
||||
// cannot be unboxed
|
||||
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
|
||||
@ -531,7 +537,7 @@ int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchK
|
||||
|
||||
// Note: this records a sequence number for both Autograd keys, and for
|
||||
// non-Autograd keys where the dispatchKeySet still contains an autograd key.
|
||||
// This means that we might collect the same sequence nubmer two different
|
||||
// This means that we might collect the same sequence number two different
|
||||
// events if they all occurred above Autograd and still had the Autograd
|
||||
// dispatch key in the dispatch key set.
|
||||
// However, this usually doesn't happen: normally the first call will
|
||||
|
||||
@ -585,7 +585,7 @@ class TORCH_API OperatorHandle {
|
||||
|
||||
// We need to store this iterator in order to make
|
||||
// Dispatcher::cleanup() fast -- it runs a lot on program
|
||||
// termination (and presuambly library unloading).
|
||||
// termination (and presumably library unloading).
|
||||
std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
|
||||
};
|
||||
|
||||
|
||||
@ -365,7 +365,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
|
||||
// For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration
|
||||
// to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd].
|
||||
// For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of
|
||||
// its backends and ask backend extender to request a decicated Autograd key for the backend.
|
||||
// its backends and ask backend extender to request a dedicated Autograd key for the backend.
|
||||
// See Note [Ambiguity in AutogradOther kernel] for more details.
|
||||
// A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't
|
||||
// cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available)
|
||||
|
||||
@ -261,7 +261,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
|
||||
//
|
||||
// There are 2 cases
|
||||
// 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'.
|
||||
// without the extra parenthesis, the c++ schem parser can not parse it.
|
||||
// without the extra parenthesis, the c++ scheme parser can not parse it.
|
||||
// 2. something like '-> ((str, str))'. Need extra parenthesis so the return
|
||||
// type is a single tuple rather than two strings.
|
||||
// PR (https://github.com/pytorch/pytorch/pull/23204) has more context about
|
||||
|
||||
@ -1176,7 +1176,7 @@ struct TORCH_API IValue final {
|
||||
using HashIdentityIValueMap =
|
||||
std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;
|
||||
|
||||
// Chechs if this and rhs has a subvalues in common.
|
||||
// Checks if this and rhs has a subvalues in common.
|
||||
// [t1,t2] and [t2, t3] returns true.
|
||||
bool overlaps(const IValue& rhs) const;
|
||||
|
||||
|
||||
@ -1501,7 +1501,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
||||
// However, the CompilationUnit holds ownership of the type's graphs, so
|
||||
// inserting a constant object into a Graph would create a reference cycle if
|
||||
// that constant object held a shared_ptr to its CU. For these objects we
|
||||
// instatiate them with non-owning references to its CU
|
||||
// instantiate them with non-owning references to its CU
|
||||
Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
|
||||
slots_.resize(numSlots);
|
||||
}
|
||||
|
||||
@ -373,7 +373,7 @@ struct TORCH_API SymbolicShape {
|
||||
// Unranked shape constructor.
|
||||
SymbolicShape() : dims_(std::nullopt) {}
|
||||
|
||||
// Known rank but unknown dimentions.
|
||||
// Known rank but unknown dimensions.
|
||||
SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) {
|
||||
if(!rank) {
|
||||
return;
|
||||
@ -884,9 +884,9 @@ struct TORCH_API ListType
|
||||
|
||||
// global singleton
|
||||
// Given an inner type T and an identifier,
|
||||
// this function wil return the global singleton type pointer
|
||||
// this function will return the global singleton type pointer
|
||||
// the type List<T>.
|
||||
// The extra "identifier" argument is needed beccause we have multiple container types
|
||||
// The extra "identifier" argument is needed because we have multiple container types
|
||||
// that all re-use this function (List<T>, array<T, N>, etc.)
|
||||
static TypePtr get(const std::string& identifier, TypePtr inner);
|
||||
|
||||
|
||||
@ -185,11 +185,11 @@ struct TORCH_API Type {
|
||||
: repr_(nullptr) {}
|
||||
|
||||
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
|
||||
: repr_(p) {}
|
||||
: repr_(makeSingletonSharedPtr(p.get())) {}
|
||||
|
||||
template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
|
||||
: repr_(SingletonTypePtr<T>(p.get())) {}
|
||||
: repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {}
|
||||
|
||||
|
||||
// We need to support construction from T* for pybind. The problem
|
||||
@ -202,8 +202,8 @@ struct TORCH_API Type {
|
||||
// Case 2: if T is exactly Type, we need to do a dynamic_cast to
|
||||
// check if it's a SharedType and do the right thing.
|
||||
//
|
||||
// Case 3: Otherwise, T is not a SharedType. (debug-check this
|
||||
// assumption!) Use a singleton pointer.
|
||||
// Case 3: Otherwise, T is not a SharedType. Use a singleton
|
||||
// pointer.
|
||||
|
||||
template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
|
||||
@ -211,15 +211,15 @@ struct TORCH_API Type {
|
||||
template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(T* p) {
|
||||
if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
|
||||
repr_ = Repr(shared_p->shared_from_this());
|
||||
repr_ = shared_p->shared_from_this();
|
||||
} else {
|
||||
repr_ = Repr(p);
|
||||
repr_ = makeSingletonSharedPtr(p);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
|
||||
/* implicit */ SingletonOrSharedTypePtr(T* p)
|
||||
: repr_(p) {
|
||||
: repr_(makeSingletonSharedPtr(p)) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
|
||||
}
|
||||
|
||||
@ -230,19 +230,19 @@ struct TORCH_API Type {
|
||||
~SingletonOrSharedTypePtr() = default;
|
||||
|
||||
T* get() const {
|
||||
return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
|
||||
return repr_.get();
|
||||
}
|
||||
|
||||
operator bool() const {
|
||||
return repr_.isNonNull();
|
||||
return repr_ != nullptr;
|
||||
}
|
||||
|
||||
bool operator==(std::nullptr_t) const {
|
||||
return !repr_.isNonNull();
|
||||
return repr_ == nullptr;
|
||||
}
|
||||
|
||||
bool operator!=(std::nullptr_t) const {
|
||||
return repr_.isNonNull();
|
||||
return repr_ != nullptr;
|
||||
}
|
||||
|
||||
template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
|
||||
@ -255,138 +255,14 @@ struct TORCH_API Type {
|
||||
}
|
||||
|
||||
private:
|
||||
// NOTE: SharedPtrWrapper exists to work around a baffling bug in
|
||||
// nvcc; see comment in destroy() below.
|
||||
struct SharedPtrWrapper {
|
||||
SharedPtrWrapper(std::shared_ptr<T> &&x)
|
||||
: repr_(std::move(x)) {}
|
||||
std::shared_ptr<T> repr_;
|
||||
};
|
||||
union Repr {
|
||||
Repr() : Repr(nullptr) {}
|
||||
// Use shared_ptr's aliasing constructor to create a non-owning pointer
|
||||
// to a singleton. The lifetime is tied to the null shared_ptr, so there's
|
||||
// no reference counting overhead for the singleton itself.
|
||||
static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) {
|
||||
return std::shared_ptr<T>(std::shared_ptr<T>(), ptr);
|
||||
}
|
||||
|
||||
explicit Repr(std::shared_ptr<T> x)
|
||||
: shared_(std::move(x)) {}
|
||||
|
||||
explicit Repr(std::nullptr_t)
|
||||
: singletonRepr_(nullptr) {}
|
||||
|
||||
explicit Repr(SingletonTypePtr<T> p)
|
||||
: singletonRepr_(p.get()) {}
|
||||
|
||||
~Repr() {
|
||||
destroy();
|
||||
}
|
||||
|
||||
// NOTE: the only non-UB way to access our null state is through
|
||||
// rawRepr(), because our copy operation doesn't preserve which
|
||||
// union member is active for null pointers.
|
||||
Repr(const Repr& rhs) {
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
new (&shared_) SharedPtrWrapper(rhs.shared_);
|
||||
} else {
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Repr(Repr&& rhs) noexcept {
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
|
||||
} else {
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Repr& operator=(const Repr& rhs) {
|
||||
if (&rhs == this) {
|
||||
return *this;
|
||||
}
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
if (isSharedAndNonNull()) {
|
||||
shared_ = rhs.shared_;
|
||||
} else {
|
||||
new (&shared_) SharedPtrWrapper(rhs.shared_);
|
||||
}
|
||||
} else {
|
||||
if (isSharedAndNonNull()) {
|
||||
destroy();
|
||||
}
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Repr& operator=(Repr&& rhs) noexcept {
|
||||
if (&rhs == this) {
|
||||
return *this;
|
||||
}
|
||||
if (rhs.isSharedAndNonNull()) {
|
||||
if (isSharedAndNonNull()) {
|
||||
shared_ = std::move(rhs.shared_);
|
||||
} else {
|
||||
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
|
||||
}
|
||||
} else {
|
||||
if (isSharedAndNonNull()) {
|
||||
destroy();
|
||||
}
|
||||
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
|
||||
singletonRepr_.unused_ = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
SharedPtrWrapper shared_;
|
||||
|
||||
struct SingletonRepr {
|
||||
explicit SingletonRepr(T* s) : singleton_(s) {}
|
||||
T* singleton_;
|
||||
void* unused_ = nullptr;
|
||||
} singletonRepr_;
|
||||
struct RawRepr {
|
||||
void* first;
|
||||
void* nullIfSingleton_;
|
||||
};
|
||||
|
||||
// It is UB to read the singleton part of Repr if it was
|
||||
// constructed as a shared_ptr and vice versa, but memcpying out
|
||||
// the representation is always OK, so here's an accessor to obey
|
||||
// the letter of the law.
|
||||
RawRepr rawRepr() const {
|
||||
RawRepr repr{};
|
||||
memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
|
||||
return repr;
|
||||
}
|
||||
|
||||
bool isNonNull() const {
|
||||
auto repr = rawRepr();
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
|
||||
return repr.first != nullptr;
|
||||
}
|
||||
|
||||
bool isSharedAndNonNull() const {
|
||||
return rawRepr().nullIfSingleton_ != nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
void destroy() {
|
||||
if (isSharedAndNonNull()) {
|
||||
// Without SharedPtrWrapper, this line would read
|
||||
// `shared_.~shared_ptr()` and nvcc would complain with
|
||||
// "error: expected primary-expression before '>' token"
|
||||
// referring to the "t" in "shared_ptr". SharedPtrWrapper
|
||||
// exists to work around this compiler bug.
|
||||
shared_.~SharedPtrWrapper();
|
||||
}
|
||||
}
|
||||
} repr_;
|
||||
std::shared_ptr<T> repr_;
|
||||
};
|
||||
|
||||
using TypePtr = SingletonOrSharedTypePtr<Type>;
|
||||
|
||||
@ -21,7 +21,7 @@ namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
// The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
|
||||
// We do this because every argument in a function schema is expected to be convertable
|
||||
// We do this because every argument in a function schema is expected to be convertible
|
||||
// to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of.
|
||||
// See Note [Plumbing Keys Through The Dispatcher]
|
||||
template<class KernelFunctor>
|
||||
|
||||
@ -251,7 +251,7 @@ TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnbox
|
||||
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA));
|
||||
EXPECT_TRUE(called_kernel_cpu);
|
||||
|
||||
// Ensure that disptach key from tensor is not used here.
|
||||
// Ensure that dispatch key from tensor is not used here.
|
||||
called_kernel_cpu = false;
|
||||
expectThrows<c10::Error>([&] {
|
||||
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU));
|
||||
|
||||
@ -172,7 +172,7 @@ VaryingShape<Stride> TensorType::computeStrideProps(
|
||||
// The logic below follows what TensorIterator uses in its logic:
|
||||
// 1. Fast_set_up is the short-cut to identify a. channels_last and
|
||||
// b. contiguous format, which is what we have in the below logic.
|
||||
// 2. In more generla cases, it does best effort to preserve permutatoin.
|
||||
// 2. In more general cases, it does best effort to preserve permutatoin.
|
||||
if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) {
|
||||
// case 1.a. short cut channels last
|
||||
std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
|
||||
|
||||
@ -104,71 +104,6 @@ class Vectorized<float> {
|
||||
}
|
||||
return b;
|
||||
}
|
||||
// Implementation is picked from
|
||||
// https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
|
||||
inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
|
||||
const auto c1 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
|
||||
const auto c2 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
|
||||
const auto c3 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
|
||||
const auto c4 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
|
||||
const auto c5 =
|
||||
svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
|
||||
const auto shift = svreinterpret_f32_u32(
|
||||
svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
|
||||
const auto inv_ln2 = svreinterpret_f32_u32(
|
||||
svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
|
||||
const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
|
||||
0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
|
||||
const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
|
||||
0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
|
||||
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
|
||||
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
|
||||
const auto zero = svdup_n_f32(0.f);
|
||||
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
|
||||
// Range reduction:
|
||||
// e^x = 2^n * e^r
|
||||
// where:
|
||||
// n = floor(x / ln(2))
|
||||
// r = x - n * ln(2)
|
||||
//
|
||||
// By adding x / ln(2) with 2^23 + 127 (shift):
|
||||
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
|
||||
// forces decimal part
|
||||
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e.
|
||||
// n) + 127 will occupy the whole fraction part of z in FP32 format.
|
||||
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
|
||||
// of x / ln(2) (i.e. n) because the decimal part has been pushed out
|
||||
// and lost.
|
||||
// * The addition of 127 makes the FP32 fraction part of z ready to be
|
||||
// used as the exponent
|
||||
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
|
||||
const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
|
||||
const auto n = svsub_f32_z(pg, z, shift);
|
||||
const auto scale = svreinterpret_f32_u32(
|
||||
svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
|
||||
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
|
||||
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in
|
||||
// term of accuracy and performance.
|
||||
const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
|
||||
const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
|
||||
// Compute the truncated Taylor series of e^r.
|
||||
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
|
||||
const auto r2 = svmul_f32_z(pg, r, r);
|
||||
const auto p1 = svmul_f32_z(pg, c1, r);
|
||||
const auto p23 = svmla_f32_z(pg, c2, c3, r);
|
||||
const auto p45 = svmla_f32_z(pg, c4, c5, r);
|
||||
const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
|
||||
const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
|
||||
auto poly = svmla_f32_z(pg, scale, p12345, scale);
|
||||
// Handle underflow and overflow.
|
||||
poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
|
||||
poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
|
||||
return poly;
|
||||
}
|
||||
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
|
||||
if (count == size())
|
||||
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
|
||||
@ -313,11 +248,41 @@ class Vectorized<float> {
|
||||
return USE_SLEEF(
|
||||
Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
|
||||
}
|
||||
// Implementation copied from Arm Optimized Routines:
|
||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
|
||||
Vectorized<float> exp_u20() const {
|
||||
return exp();
|
||||
// special case to handle special inputs that are too large or too small
|
||||
// i.e. where there's at least one element x, s.t. |x| >= 87.3...
|
||||
svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f);
|
||||
if (svptest_any(svptrue_b32(), is_special_case)) {
|
||||
return exp();
|
||||
}
|
||||
const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
|
||||
const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
|
||||
const svfloat32_t c1 = svdup_n_f32(0.5f);
|
||||
const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
|
||||
|
||||
const float shift = 0x1.803f8p17f;
|
||||
|
||||
/* n = round(x/(ln2/N)). */
|
||||
svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift);
|
||||
svfloat32_t n = svsub_x(svptrue_b32(), z, shift);
|
||||
|
||||
/* r = x - n*ln2/N. */
|
||||
svfloat32_t r = values;
|
||||
r = svmls_x(svptrue_b32(), r, n, ln2_hi);
|
||||
r = svmls_x(svptrue_b32(), r, n, ln2_lo);
|
||||
|
||||
/* scale = 2^(n/N). */
|
||||
svfloat32_t scale = svexpa(svreinterpret_u32(z));
|
||||
|
||||
/* poly(r) = exp(r) - 1 ~= r + 0.5 r^2. */
|
||||
svfloat32_t r2 = svmul_x(svptrue_b32(), r, r);
|
||||
svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
|
||||
return svmla_x(svptrue_b32(), scale, scale, poly);
|
||||
}
|
||||
Vectorized<float> fexp_u20() const {
|
||||
return exp();
|
||||
return exp_u20();
|
||||
}
|
||||
Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
|
||||
{ return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
|
||||
@ -453,9 +418,11 @@ class Vectorized<float> {
|
||||
ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
|
||||
|
||||
// Step 2: Calculate exp(2 * x), where x is the clamped value.
|
||||
// svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
|
||||
// the result.
|
||||
svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
|
||||
// svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of
|
||||
// the result (via Vectorized<float>, then auto-converts back to
|
||||
// svfloat32_t).
|
||||
svfloat32_t exp2x =
|
||||
Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20();
|
||||
|
||||
// Step 3: Calculate the numerator of the tanh function, which is exp(2x)
|
||||
// - 1.
|
||||
|
||||
@ -5,6 +5,114 @@
|
||||
namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||
|
||||
// Enable auto-vectorization for GCC-13+ and clang-17+
|
||||
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
|
||||
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
|
||||
|
||||
template <typename from_type, typename to_type>
|
||||
inline void convertImpl(
|
||||
const from_type* __restrict src,
|
||||
to_type* __restrict dst,
|
||||
int64_t n) {
|
||||
uint64_t len = static_cast<uint64_t>(n);
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
dst[i] = static_cast<to_type>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#define CONVERT_TEMPLATE(from_type, to_type) \
|
||||
template <> \
|
||||
inline void convert(const from_type* src, to_type* dst, int64_t n) { \
|
||||
return convertImpl<from_type, to_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_TEMPLATE(uint8_t, uint8_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int8_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int16_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int32_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int64_t)
|
||||
CONVERT_TEMPLATE(uint8_t, float)
|
||||
CONVERT_TEMPLATE(uint8_t, double)
|
||||
CONVERT_TEMPLATE(int8_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int8_t, int8_t)
|
||||
CONVERT_TEMPLATE(int8_t, int16_t)
|
||||
CONVERT_TEMPLATE(int8_t, int32_t)
|
||||
CONVERT_TEMPLATE(int8_t, int64_t)
|
||||
CONVERT_TEMPLATE(int8_t, float)
|
||||
CONVERT_TEMPLATE(int8_t, double)
|
||||
CONVERT_TEMPLATE(int16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int16_t, int8_t)
|
||||
CONVERT_TEMPLATE(int16_t, int16_t)
|
||||
CONVERT_TEMPLATE(int16_t, int32_t)
|
||||
CONVERT_TEMPLATE(int16_t, int64_t)
|
||||
CONVERT_TEMPLATE(int16_t, float)
|
||||
CONVERT_TEMPLATE(int16_t, double)
|
||||
CONVERT_TEMPLATE(int32_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int32_t, int8_t)
|
||||
CONVERT_TEMPLATE(int32_t, int16_t)
|
||||
CONVERT_TEMPLATE(int32_t, int32_t)
|
||||
CONVERT_TEMPLATE(int32_t, int64_t)
|
||||
CONVERT_TEMPLATE(int32_t, float)
|
||||
CONVERT_TEMPLATE(int32_t, double)
|
||||
CONVERT_TEMPLATE(int64_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int64_t, int8_t)
|
||||
CONVERT_TEMPLATE(int64_t, int16_t)
|
||||
CONVERT_TEMPLATE(int64_t, int32_t)
|
||||
CONVERT_TEMPLATE(int64_t, int64_t)
|
||||
CONVERT_TEMPLATE(int64_t, float)
|
||||
CONVERT_TEMPLATE(int64_t, double)
|
||||
CONVERT_TEMPLATE(float, uint8_t)
|
||||
CONVERT_TEMPLATE(float, int8_t)
|
||||
CONVERT_TEMPLATE(float, int16_t)
|
||||
CONVERT_TEMPLATE(float, int32_t)
|
||||
CONVERT_TEMPLATE(float, int64_t)
|
||||
CONVERT_TEMPLATE(float, float)
|
||||
CONVERT_TEMPLATE(float, double)
|
||||
CONVERT_TEMPLATE(double, uint8_t)
|
||||
CONVERT_TEMPLATE(double, int8_t)
|
||||
CONVERT_TEMPLATE(double, int16_t)
|
||||
CONVERT_TEMPLATE(double, int32_t)
|
||||
CONVERT_TEMPLATE(double, int64_t)
|
||||
CONVERT_TEMPLATE(double, float)
|
||||
CONVERT_TEMPLATE(double, double)
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
CONVERT_TEMPLATE(float16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(float16_t, int8_t)
|
||||
CONVERT_TEMPLATE(float16_t, int16_t)
|
||||
CONVERT_TEMPLATE(float16_t, int32_t)
|
||||
CONVERT_TEMPLATE(float16_t, int64_t)
|
||||
CONVERT_TEMPLATE(float16_t, float16_t)
|
||||
CONVERT_TEMPLATE(float16_t, float)
|
||||
CONVERT_TEMPLATE(float16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, float16_t)
|
||||
CONVERT_TEMPLATE(int8_t, float16_t)
|
||||
CONVERT_TEMPLATE(int16_t, float16_t)
|
||||
CONVERT_TEMPLATE(int32_t, float16_t)
|
||||
CONVERT_TEMPLATE(int64_t, float16_t)
|
||||
CONVERT_TEMPLATE(float, float16_t)
|
||||
CONVERT_TEMPLATE(double, float16_t)
|
||||
#endif
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, float)
|
||||
CONVERT_TEMPLATE(bfloat16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(float, bfloat16_t)
|
||||
CONVERT_TEMPLATE(double, bfloat16_t)
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
||||
@ -307,11 +307,49 @@ class Vectorized<float> {
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
|
||||
// Implementation copied from Arm Optimized Routine
|
||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
|
||||
Vectorized<float> exp_u20() const {
|
||||
return exp();
|
||||
// bail out to sleef if it's a special case:
|
||||
// i.e. there's an input s.t. |input| > 87.3....
|
||||
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
|
||||
uint32x4_t cmp = vcagtq_f32(values, special_bound);
|
||||
if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) {
|
||||
return exp();
|
||||
}
|
||||
|
||||
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
|
||||
const float ln2_hi = 0x1.62e4p-1f;
|
||||
const float ln2_lo = 0x1.7f7d1cp-20f;
|
||||
const float c0 = 0x1.0e4020p-7f;
|
||||
const float c2 = 0x1.555e66p-3f;
|
||||
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
|
||||
|
||||
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
|
||||
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
|
||||
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
|
||||
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
|
||||
|
||||
/* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
|
||||
x = ln2*n + r, with r in [-ln2/2, ln2/2]. */
|
||||
|
||||
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
|
||||
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0);
|
||||
r = vfmsq_laneq_f32(r, n, ln2_c02, 1);
|
||||
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
|
||||
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias));
|
||||
|
||||
float32x4_t r2 = vmulq_f32(r, r);
|
||||
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2);
|
||||
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3);
|
||||
q = vfmaq_f32(q, p, r2);
|
||||
p = vmulq_f32(c4, r);
|
||||
float32x4_t poly = vfmaq_f32(p, q, r2);
|
||||
|
||||
return vfmaq_f32(scale, poly, scale);
|
||||
}
|
||||
Vectorized<float> fexp_u20() const {
|
||||
return exp();
|
||||
return exp_u20();
|
||||
}
|
||||
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
|
||||
fmod,
|
||||
@ -540,42 +578,6 @@ inline Vectorized<float> Vectorized<float>::le(
|
||||
return (*this <= other) & Vectorized<float>(1.0f);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const float* src, int32_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||
i += Vectorized<float>::size()) {
|
||||
vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<int32_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const int32_t* src, float* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||
i += Vectorized<float>::size()) {
|
||||
vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<float>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<float> inline fmadd(
|
||||
const Vectorized<float>& a,
|
||||
@ -632,8 +634,7 @@ inline Vectorized<float> Vectorized<float>::erf() const {
|
||||
// - exp(- x * x)
|
||||
auto pow_2 = (*this) * (*this);
|
||||
auto neg_pow_2 = pow_2 ^ neg_zero_vec;
|
||||
auto tmp4 = neg_pow_2.map(
|
||||
std::exp); // This can be swapped for a faster implementation of exp.
|
||||
auto tmp4 = neg_pow_2.exp();
|
||||
auto tmp5 = tmp4 ^ neg_zero_vec;
|
||||
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
|
||||
auto tmp6 = t * tmp5;
|
||||
|
||||
@ -234,7 +234,7 @@ class Vectorized<c10::Half> : public Vectorized16<
|
||||
vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
|
||||
return vaddvq_u16(bits_vec);
|
||||
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
// use known working implmentation.
|
||||
// use known working implementation.
|
||||
__at_align__ value_type tmp[size()];
|
||||
store(tmp);
|
||||
int mask = 0;
|
||||
@ -569,46 +569,6 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
|
||||
return (*this <= other) & Vectorized<c10::Half>(1);
|
||||
}
|
||||
|
||||
// These are global functions, so the defaults in vec_base.h should
|
||||
// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available.
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
template <>
|
||||
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||
i += Vectorized<c10::Half>::size()) {
|
||||
vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<int16_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||
i += Vectorized<c10::Half>::size()) {
|
||||
vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<float16_t>(src[i]);
|
||||
}
|
||||
}
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template <>
|
||||
Vectorized<c10::Half> inline fmadd(
|
||||
const Vectorized<c10::Half>& a,
|
||||
|
||||
@ -1740,7 +1740,7 @@ Vectorized<int16_t> inline shift_256_16(
|
||||
|
||||
// Control masks for shuffle operation, treating 256 bits as an
|
||||
// array of 16-bit elements, and considering pairs of neighboring
|
||||
// elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and
|
||||
// elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
|
||||
// M!=N) is set so that shuffle will move element with index M from
|
||||
// input pair into element with index N in output pair, and element
|
||||
// with index M in output pair will be set to all 0s.
|
||||
@ -1875,7 +1875,7 @@ Vectorized<T> inline shift_256_8(
|
||||
|
||||
// Control masks for shuffle operation, treating 256 bits as an
|
||||
// array of 8-bit elements, and considering quadruples of
|
||||
// neighboring elements. Specifially, a mask named "ctl_M_N" (M,N
|
||||
// neighboring elements. Specifically, a mask named "ctl_M_N" (M,N
|
||||
// in [0,1,2,3], and M!=N) is set so that shuffle will move element
|
||||
// with index M from input quadruple into element with index N in
|
||||
// output quadruple, and other elements in output quadruple will be
|
||||
|
||||
@ -143,7 +143,7 @@ class Vectorized<double> {
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b,
|
||||
const Vectorized<double>& mask) {
|
||||
// the mask used here returned by comparision of vec256
|
||||
// the mask used here returned by comparison of vec256
|
||||
|
||||
return {
|
||||
vec_sel(a._vec0, b._vec0, mask._vecb0),
|
||||
|
||||
@ -142,7 +142,7 @@ class Vectorized<float> {
|
||||
const Vectorized<float>& a,
|
||||
const Vectorized<float>& b,
|
||||
const Vectorized<float>& mask) {
|
||||
// the mask used here returned by comparision of vec256
|
||||
// the mask used here returned by comparison of vec256
|
||||
// assuming this we can use the same mask directly with vec_sel
|
||||
return {
|
||||
vec_sel(a._vec0, b._vec0, mask._vecb0),
|
||||
|
||||
@ -202,7 +202,7 @@ class Vectorized<int16_t> {
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b,
|
||||
const Vectorized<int16_t>& mask) {
|
||||
// the mask used here returned by comparision of vec256
|
||||
// the mask used here returned by comparison of vec256
|
||||
// assuming this we can use the same mask directly with vec_sel
|
||||
// warning intel style mask will not work properly
|
||||
return {
|
||||
|
||||
@ -155,7 +155,7 @@ class Vectorized<int32_t> {
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b,
|
||||
const Vectorized<int32_t>& mask) {
|
||||
// the mask used here returned by comparision of vec256
|
||||
// the mask used here returned by comparison of vec256
|
||||
// assuming this we can use the same mask directly with vec_sel
|
||||
// warning intel style mask will not work properly
|
||||
return {
|
||||
|
||||
@ -119,7 +119,7 @@ class Vectorized<int64_t> {
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b,
|
||||
const Vectorized<int64_t>& mask) {
|
||||
// the mask used here returned by comparision of vec256
|
||||
// the mask used here returned by comparison of vec256
|
||||
|
||||
return {
|
||||
vec_sel(a._vec0, b._vec0, mask._vecb0),
|
||||
|
||||
@ -397,7 +397,7 @@ inline Vectorized<bool> operator&&(
|
||||
const __m512i* other_ = reinterpret_cast<const __m512i*>(other.as_bytes());
|
||||
__m512i out = _mm512_and_si512(*self_, *other_);
|
||||
Vectorized<bool> ret;
|
||||
// We do not have a constructer that takes __m512i, so we need to memcpy
|
||||
// We do not have a constructor that takes __m512i, so we need to memcpy
|
||||
std::memcpy(ret, &out, ret.size() * sizeof(bool));
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -1852,7 +1852,7 @@ Vectorized<T> inline shift_512_8(
|
||||
|
||||
// Control masks for shuffle operation, treating 512 bits as an
|
||||
// array of 8-bit elements, and considering pairs of neighboring
|
||||
// elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and
|
||||
// elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
|
||||
// M!=N) is set so that shuffle will move element with index M from
|
||||
// input pair into element with index N in output pair, and element
|
||||
// with index M in output pair will be set to all 0s.
|
||||
|
||||
@ -634,7 +634,7 @@ struct Vectorized {
|
||||
}
|
||||
Vectorized<T> neg() const {
|
||||
// NB: the trailing return type is needed because we need to coerce the
|
||||
// return value back to T in the case of unary operator- incuring a
|
||||
// return value back to T in the case of unary operator- incurring a
|
||||
// promotion
|
||||
return map([](T x) -> T { return -x; });
|
||||
}
|
||||
|
||||
@ -1958,7 +1958,7 @@ void scaled_gemm(
|
||||
ScalarType result_dtype,
|
||||
bool use_fast_accum,
|
||||
const std::optional<Tensor>& alpha) {
|
||||
// Note: see `cublasCommonArgs` for various non-intuitive manupulations
|
||||
// Note: see `cublasCommonArgs` for various non-intuitive manipulations
|
||||
// of input arguments to this function.
|
||||
const auto computeType = CUBLAS_COMPUTE_32F;
|
||||
const auto scaleType = CUDA_R_32F;
|
||||
|
||||
@ -168,11 +168,9 @@ void CUDAGraph::instantiate() {
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
|
||||
// cudaGraphInstantiateWithFlags
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
|
||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
|
||||
int version = 0;
|
||||
AT_CUDA_CHECK(cudaDriverGetVersion(&version));
|
||||
if (version < 11040) {
|
||||
#endif
|
||||
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
|
||||
// who prefer not to report error message through these arguments moving forward
|
||||
// (they prefer return value, or errors on api calls internal to the capture)
|
||||
@ -183,13 +181,11 @@ void CUDAGraph::instantiate() {
|
||||
#endif
|
||||
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
|
||||
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
|
||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
|
||||
} else {
|
||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||
graph_,
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
||||
}
|
||||
#endif
|
||||
has_graph_exec_ = true;
|
||||
}
|
||||
|
||||
@ -311,7 +307,7 @@ CUDAGraph::~CUDAGraph() {
|
||||
// There are recent HIP changes where hipGraphExecDestroy doesn't immediately free memory.
|
||||
// They wait for next sync point in order to free the memory, this is to ensure that all
|
||||
// hipGraphLaunch are finished before we release any memory. This feature was enabled in rocm6.2.
|
||||
// We need to ensure all async opreations finish before deleting the object.
|
||||
// We need to ensure all async operations finish before deleting the object.
|
||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200)
|
||||
if (capture_dev_ != UNDEFINED_DEVICE) // check if capture_dev_ contains the real device id
|
||||
{
|
||||
|
||||
270
aten/src/ATen/cuda/CUDAScaledBlas.cpp
Normal file
270
aten/src/ATen/cuda/CUDAScaledBlas.cpp
Normal file
@ -0,0 +1,270 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace at::cuda::scaled {
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs a single scale, {Tensorwise (float)}
|
||||
*/
|
||||
bool check_tensorwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs scales, {Rowwise (float)}
|
||||
*/
|
||||
bool check_rowwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {RowWise, dp32} for A & B
|
||||
if (recipe_a[0] != ScalingType::RowWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::RowWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Two-level scaling, canonical NVFP4
|
||||
* Both inputs must be fp4
|
||||
* A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
|
||||
*/
|
||||
bool check_nvfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Single-level scaling, what PyT currently understands
|
||||
* Both inputs must be fp4
|
||||
* A, B need 1 scale, {Blockwise_1x16 (e4m3)}
|
||||
*/
|
||||
bool check_nvfp4_recipe_single_scale
|
||||
(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float)
|
||||
*/
|
||||
bool check_deepseek_recipe(ScalingType expected_recipe_a,
|
||||
ScalingType expected_recipe_b,
|
||||
c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
|
||||
if (recipe_a[0] != expected_recipe_a) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != expected_recipe_b) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp8_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp4
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace at::native::cuda::blas::scaled
|
||||
174
aten/src/ATen/cuda/CUDAScaledBlas.h
Normal file
174
aten/src/ATen/cuda/CUDAScaledBlas.h
Normal file
@ -0,0 +1,174 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace at::cuda::scaled {
|
||||
|
||||
static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
|
||||
#ifdef USE_ROCM
|
||||
static const std::vector<std::string> archs = {
|
||||
"gfx942",
|
||||
#if ROCM_VERSION >= 60300
|
||||
"gfx1200", "gfx1201",
|
||||
#endif
|
||||
#if ROCM_VERSION >= 60500
|
||||
"gfx950"
|
||||
#endif
|
||||
};
|
||||
return at::detail::getCUDAHooks().isGPUArch(archs);
|
||||
#else
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
|
||||
if (sm90_only || sm100_only) {
|
||||
return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
|
||||
} else {
|
||||
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool _scaled_mm_is_fnuz() {
|
||||
return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
|
||||
}
|
||||
#endif
|
||||
/**
|
||||
* Track concrete implementations available
|
||||
*/
|
||||
enum class ScaledGemmImplementation {
|
||||
NONE = 0,
|
||||
TENSORWISE_TENSORWISE = 1,
|
||||
ROWWISE_ROWWISE = 2,
|
||||
BLOCK_128x128_1x128 = 3,
|
||||
BLOCK_1x128_128x128 = 4,
|
||||
BLOCK_1x128_1x128 = 5,
|
||||
MXFP8_MXFP8 = 6,
|
||||
NVFP4_NVFP4 = 7,
|
||||
NVFP4_NVFP4_SINGLE_SCALE = 8,
|
||||
MXFP4_MXFP4 = 9,
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert passed int (enum) from python back into a
|
||||
* strictly-typed enum
|
||||
*/
|
||||
template <class EnumType, class ArrayType>
|
||||
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
|
||||
std::vector<EnumType> converted;
|
||||
converted.reserve(v.size());
|
||||
|
||||
for (auto vi : v) {
|
||||
converted.push_back(static_cast<EnumType>(vi));
|
||||
}
|
||||
return converted;
|
||||
}
|
||||
|
||||
bool check_tensorwise_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
|
||||
bool check_rowwise_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_nvfp4_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_nvfp4_recipe_single_scale
|
||||
(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_deepseek_recipe(ScalingType,
|
||||
ScalingType,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_mxfp8_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_mxfp4_recipe(c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
} // namespace at::native::cuda::blas::scaled
|
||||
@ -137,7 +137,7 @@ struct CUDACachingHostAllocatorImpl
|
||||
void free_block_slowpath(Block* block) {
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
// Users may change the allocator config at will. torch unit tests do this.
|
||||
// However, allocations using cudaHostRegister should use corresonding
|
||||
// However, allocations using cudaHostRegister should use corresponding
|
||||
// cudaHostUnregister and similarly for cudaHostAlloc / cudaFreeHost.
|
||||
void* ptr = block->ptr_;
|
||||
bool use_register = false;
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
#include <ATen/cuda/CUDAConfig.h>
|
||||
|
||||
// NOTE: These templates are intentionally not defined in this header,
|
||||
// which aviods re-compiling them for each translation unit. If you get
|
||||
// which avoids re-compiling them for each translation unit. If you get
|
||||
// a link error, you need to add an explicit instantiation for your
|
||||
// types in cub.cu
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
|
||||
GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033
|
||||
```
|
||||
|
||||
Note the "Validator" lines. If you change a library verison, or ROCm version, or PyTorch version, TunableOp will detect
|
||||
Note the "Validator" lines. If you change a library version, or ROCm version, or PyTorch version, TunableOp will detect
|
||||
this and reject the tunings file because the prior tunings are likely affected by other software changes.
|
||||
|
||||
The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of
|
||||
|
||||
@ -235,7 +235,7 @@ class TunableOp {
|
||||
// numeric check option is controlled by non-static env var, so check it once per tuned operator
|
||||
bool do_numerics_check = ctx->IsNumericsCheckEnabled();
|
||||
|
||||
// calcaulte a reference answer for numerical check
|
||||
// calculate a reference answer for numerical check
|
||||
if (do_numerics_check) {
|
||||
reference_params = params->DeepCopy(false);
|
||||
TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
|
||||
|
||||
@ -12,7 +12,7 @@ namespace at {
|
||||
|
||||
// AcceleratorHooksInterface is a shared interface provided by all
|
||||
// accelerators to allow generic code.
|
||||
// This inferface is hook-based as it corresponds to all the functions
|
||||
// This interface is hook-based as it corresponds to all the functions
|
||||
// that are going to be called in a generic way from the CPU code.
|
||||
|
||||
struct TORCH_API AcceleratorHooksInterface {
|
||||
|
||||
@ -38,7 +38,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
|
||||
|
||||
Generator getNewGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
// TODO(FFFrog): Perserved for BC and will be removed in the future.
|
||||
// TODO(FFFrog): Preserved for BC and will be removed in the future.
|
||||
if (at::GetGeneratorPrivate().has_value())
|
||||
return at::GetGeneratorForPrivateuse1(device_index);
|
||||
|
||||
|
||||
@ -283,7 +283,7 @@ inline void boxed_existing_bdim_all_batch_rule(
|
||||
// Use when all tensors arguments accept one (normal) batch dim.
|
||||
// This batching rule expands the batch dim on all Tensors, reshapes it into
|
||||
// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
|
||||
// This is not the most efficient thing; if there are alternatives, plese try
|
||||
// This is not the most efficient thing; if there are alternatives, please try
|
||||
// to use them. Use this only as a last resort.
|
||||
#define EXISTING_BDIM_ALL_BOXED(op) \
|
||||
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
|
||||
|
||||
@ -384,7 +384,7 @@ fourOutputs solve_ex_batch_rule(
|
||||
|
||||
// NOTE [ solve_ex Batch Rule Contiguity ]
|
||||
// A determines whether or not linalg_solve takes an optimized path. We need the check on A_ to match the one run on
|
||||
// A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behvaior
|
||||
// A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behavior
|
||||
// differs based on whether or not the optimized path was taken
|
||||
const auto batched_A_was_contiguous = A_bdim.has_value() ? at::select(A, *A_bdim, 0).is_contiguous() : A.is_contiguous();
|
||||
if (batched_A_was_contiguous && !A.is_complex()) {
|
||||
|
||||
@ -282,7 +282,7 @@ static std::tuple<Tensor, std::optional<int64_t>> _softmax_backward_batch_rule(
|
||||
|
||||
dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim);
|
||||
|
||||
// Not sure why output_ needs to be marked as .contiguous(). Someting must
|
||||
// Not sure why output_ needs to be marked as .contiguous(). Something must
|
||||
// have changed in PyTorch (and output of softmax is probably always contiguous)
|
||||
return std::make_tuple(at::_softmax_backward_data(grad_output_, output_.contiguous(), dim, input_dtype), 0);
|
||||
}
|
||||
|
||||
@ -224,7 +224,7 @@ static Tensor safeStack(TensorList tensors) {
|
||||
// is possible for the backward function to return an undefined grad for some
|
||||
// grad_input for each example. In that case, we return an undefined grad.
|
||||
//
|
||||
// It is theoretically posssible for *some* of the examples to produce an
|
||||
// It is theoretically possible for *some* of the examples to produce an
|
||||
// undefined grad (a kernel could peek at the gradient values and return an
|
||||
// undefined tensor if it determines the gradient is full of zeros). We
|
||||
// could handle this by treating the undefined grad as a zero-filled tensor
|
||||
|
||||
@ -113,7 +113,7 @@ SymIntArrayRef BatchedTensorImpl::sym_sizes_custom() const {
|
||||
return sym_sizes_default();
|
||||
}
|
||||
|
||||
// The following are publically exposed as methods of Tensor
|
||||
// The following are publicly exposed as methods of Tensor
|
||||
|
||||
IntArrayRef BatchedTensorImpl::strides_custom() const {
|
||||
return strides_default();
|
||||
|
||||
@ -37,7 +37,7 @@ namespace at::functorch {
|
||||
// how to perform the transform.
|
||||
//
|
||||
// TODO: we can excise DynamicLayer in favor of Interpreter,
|
||||
// But I am going to leave it for now as a compatiblity shim to avoid
|
||||
// But I am going to leave it for now as a compatibility shim to avoid
|
||||
// needing to refactor a lot of callsites...
|
||||
struct TORCH_API DynamicLayer {
|
||||
explicit DynamicLayer(
|
||||
|
||||
@ -88,7 +88,7 @@ std::ostream& operator<<(std::ostream& os, const TransformType& t);
|
||||
// >>> VmapInterpreterPtr(&interpreter).batchSize()
|
||||
//
|
||||
// Finally, Interpreter::process switches on the type of the interpreter
|
||||
// and calls one of {Transform}Intepreter::processImpl under the hood.
|
||||
// and calls one of {Transform}Interpreter::processImpl under the hood.
|
||||
// Same for Interpreter::sendToNextInterpreter :)
|
||||
|
||||
struct VmapInterpreterMeta {
|
||||
|
||||
@ -733,7 +733,7 @@ TORCH_LIBRARY_IMPL(_, FuncTorchBatched, m) {
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
|
||||
// still legacy b/c teturns multiple tensors
|
||||
// still legacy b/c returns multiple tensors
|
||||
m.impl("split.Tensor", split_batching_rule);
|
||||
m.impl("split_with_sizes", split_with_sizes_batching_rule);
|
||||
m.impl("split_with_sizes_copy", split_with_sizes_copy_batching_rule);
|
||||
|
||||
@ -158,7 +158,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t
|
||||
endKernelCoalescing();
|
||||
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
|
||||
|
||||
// For some reason fillBufferfor stopped working for lengh > 4Gb on MacOS 26
|
||||
// For some reason fillBufferfor stopped working for length > 4Gb on MacOS 26
|
||||
// See https://github.com/pytorch/pytorch/issues/163962
|
||||
// Workaround by batching copy commands into 4Gb chunks
|
||||
constexpr size_t max_copy_size = 0x100000000; // 4GB
|
||||
|
||||
@ -148,7 +148,7 @@ inline void checkInputsSolver(const Tensor& A,
|
||||
|
||||
inline bool is_row_or_column_contiguous(const Tensor& t) {
|
||||
// This could be made more general, similar to how it's checked in matmul, which would allow to
|
||||
// ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
|
||||
// elide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
|
||||
// We choose to be conservative for simplicity
|
||||
return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
|
||||
}
|
||||
|
||||
@ -21,7 +21,7 @@ enum class fft_norm_mode {
|
||||
// NOTE [ Fourier Transform Conjugate Symmetry ]
|
||||
//
|
||||
// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is,
|
||||
// assuming X is the transformed K-dimensionsal signal, we have
|
||||
// assuming X is the transformed K-dimensional signal, we have
|
||||
//
|
||||
// X[i_1, ..., i_K] = X[j_i, ..., j_K]*,
|
||||
//
|
||||
|
||||
@ -128,7 +128,7 @@ at::Tensor PackedLinearWeight::apply_impl(
|
||||
auto* input_tr_ptr =
|
||||
reinterpret_cast<uint8_t*>(input_tr.data_ptr<c10::quint8>());
|
||||
// TODO: Activation transpose before and after the kernel can be removed if we
|
||||
// keep activation tensor always tranposed.
|
||||
// keep activation tensor always transposed.
|
||||
fbgemm::transpose_simd<uint8_t>(
|
||||
batch_size, K, input_ptr, K, input_tr_ptr, batch_size);
|
||||
|
||||
|
||||
@ -520,7 +520,7 @@ cpu_adaptive_avg_pool3d_channels_last(
|
||||
scalar_t* out = output_data + i * channels;
|
||||
int64_t size = channels;
|
||||
|
||||
// Note: For oridinary usage scenario, each out lane should
|
||||
// Note: For ordinary usage scenario, each out lane should
|
||||
// fit in L1 cache; otherwise consider block dim C.
|
||||
// Pass I: zero the out lane
|
||||
int64_t d1 = 0;
|
||||
|
||||
@ -34,7 +34,7 @@ struct Dist {
|
||||
// finish : This tells what to do with the aggregated value to compute
|
||||
// the norm. Generally this is the result of val ^ (1 / p).
|
||||
// backward : This is the gradient for that norm. Arguments are pretty
|
||||
// self explanitory.
|
||||
// self explanatory.
|
||||
//
|
||||
// There are a few cases where these aren't used. The 0 norm has no backward,
|
||||
// because it's always 0, so that's shortcircuited earlier. There's a special
|
||||
|
||||
@ -30,7 +30,7 @@ vec::Vectorized<scalar_t> is_nan_vec(vec::Vectorized<scalar_t> vec) {
|
||||
return vec.isnan();
|
||||
}
|
||||
|
||||
// TODO: use is_integeral/is_same to check the scalar_t and simplify the implementation
|
||||
// TODO: use is_integral/is_same to check the scalar_t and simplify the implementation
|
||||
// currently it does not work
|
||||
template <>
|
||||
vec::Vectorized<unsigned char> is_nan_vec<unsigned char>(vec::Vectorized<unsigned char> vec) {
|
||||
|
||||
@ -74,7 +74,7 @@ it to sum up the entire array into a single value.
|
||||
|
||||
`ReduceOpsKernel.cpp` uses the `CPU_CAPABILITY_*` macros to "know" under which
|
||||
compiler flags it is currently compiled. This allows the programmer to write
|
||||
generic code, which will be compiled under multipled compilation settings.
|
||||
generic code, which will be compiled under multiplied compilation settings.
|
||||
|
||||
`../ReduceOps.cpp` now includes the header `ReduceOpsKernel.h`, which contains
|
||||
a generic definition of `sumImplAll`. This function allows the user to reduce
|
||||
|
||||
@ -889,7 +889,7 @@ void ImagingResampleHorizontalConvolution8u(
|
||||
_mm_loadu_si128((__m128i *) (lineIn_min + stride * i))),
|
||||
_mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1);
|
||||
|
||||
// Extract lower part of each lane, cast to epi16 and reoder RGBARGBA -> RRGGBBAA
|
||||
// Extract lower part of each lane, cast to epi16 and reorder RGBARGBA -> RRGGBBAA
|
||||
// RGBA: pix1 = [
|
||||
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
|
||||
// r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0
|
||||
|
||||
@ -240,7 +240,7 @@ _PS256_CONST(coscof_p2, 4.166664568298827E-002);
|
||||
_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
|
||||
|
||||
|
||||
/* evaluation of 8 sines at onces using AVX intrinsics
|
||||
/* evaluation of 8 sines at once using AVX intrinsics
|
||||
|
||||
The code is the exact rewriting of the cephes sinf function.
|
||||
Precision is excellent as long as x < 8192 (I did not bother to
|
||||
|
||||
@ -311,7 +311,7 @@ void GroupNormKernelImplChannelsLastInternal(
|
||||
const bool gamma_null = (gamma_data == nullptr);
|
||||
const bool beta_null = beta_data == nullptr;
|
||||
|
||||
// NB: About algorithm choosen:
|
||||
// NB: About algorithm chosen:
|
||||
//
|
||||
// On channels last, GroupNorm has a input shape of {N, H, W, GD},
|
||||
// Mean and rstd are collected per each n and g, which involves reduction
|
||||
|
||||
@ -930,7 +930,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
}
|
||||
};
|
||||
|
||||
// Dynamically Quantize the float32 input to 8 bit assymetric
|
||||
// Dynamically Quantize the float32 input to 8 bit asymmetric
|
||||
input_quant_pack_8bit_channelwise(m, k, lhs_f32, (int8_t*)lhs_qa8dx);
|
||||
|
||||
const size_t lhs_stride =
|
||||
@ -1163,7 +1163,7 @@ void dyn_quant_matmul_4bit_kernel(
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
if (weight_packed_size == packed_weights.numel()) {
|
||||
// KleidiAI interface intenally handles the Channelwise and groupwise
|
||||
// KleidiAI interface internally handles the Channelwise and groupwise
|
||||
// distinction
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(
|
||||
output, inp, packed_weights, M, N, K, block_size);
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/CUDAScaledBlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
@ -360,7 +361,7 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
|
||||
// and the leading stride is at least max(1, other dim length), so we might
|
||||
// end up with contiguous cols but not rows (i.e. holes between different rows)
|
||||
// and vice versa.
|
||||
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
&& (
|
||||
// filter by dtype
|
||||
@ -504,7 +505,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
// Handle whether to use the Lt interface {
|
||||
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
|
||||
// if lt path fails, we recurse back into this function here and force the lt path to off
|
||||
// we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent
|
||||
// we cannot update variable disable_addmm_cuda_lt from above since it is static and would be permanent
|
||||
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#ifdef USE_ROCM
|
||||
// Conditioned on the device index, which is not persistent
|
||||
@ -1628,104 +1629,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 1,
|
||||
"scale must be a 1D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D, arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(dim) * scale_multiplier,
|
||||
"scale must have the same length as mat for arg ",
|
||||
arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 2,
|
||||
"scale must be a 2D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.stride(1) == 1,
|
||||
"scale must be contiguous in the last dimension for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(0),
|
||||
"scale must have the same batch dimension as mat for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(1) == mat.size(1 + dim),
|
||||
"scale must have the same first dimension as mat for arg ",
|
||||
arg_idx);
|
||||
}
|
||||
}
|
||||
|
||||
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
|
||||
// that are converted to blocked padded format individually,
|
||||
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim(),
|
||||
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
|
||||
|
||||
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
|
||||
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
|
||||
// * weight is transposed prior to the call, scale stays non-transposed.
|
||||
bool LHS = arg_idx == 0;
|
||||
int scale_dim_to_check = 0;
|
||||
int mat_dim_to_check = LHS ? 0 : 1;
|
||||
TORCH_CHECK(
|
||||
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
|
||||
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
|
||||
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
|
||||
} else {
|
||||
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
|
||||
// so we can check the exact expected scale sizes here without a d2h sync.
|
||||
auto round_up = [](auto x, auto y) {
|
||||
return ((x + y - 1) / y) * y;
|
||||
};
|
||||
|
||||
// TODO: this is for 3d tensor in 2d-3d case specifically.
|
||||
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
|
||||
int64_t G = mat.size(0);
|
||||
int64_t K = mat.size(1);
|
||||
int64_t N = mat.size(2);
|
||||
int64_t blocked_scale_K = round_up(K/32, 4);
|
||||
int64_t blocked_scale_N = round_up(N, 128);
|
||||
|
||||
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim() - 1,
|
||||
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
|
||||
);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
|
||||
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
|
||||
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
|
||||
if (using_fp8_rowwise) {
|
||||
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
|
||||
} else if (using_mxfp8) {
|
||||
_check_scales_mxfp8(mat, scale, dim, arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor
|
||||
_scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
@ -1740,261 +1643,26 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
|
||||
}
|
||||
|
||||
/**
|
||||
* Track concrete implementations available
|
||||
*/
|
||||
enum class ScaledGemmImplementation {
|
||||
NONE = 0,
|
||||
TENSORWISE_TENSORWISE = 1,
|
||||
ROWWISE_ROWWISE = 2,
|
||||
BLOCK_128x128_1x128 = 3,
|
||||
BLOCK_1x128_128x128 = 4,
|
||||
BLOCK_1x128_1x128 = 5,
|
||||
MXFP8_MXFP8 = 6,
|
||||
NVFP4_NVFP4 = 7,
|
||||
NVFP4_NVFP4_SINGLE_SCALE = 8,
|
||||
MXFP4_MXFP4 = 9,
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert passed int (enum) from python back into a
|
||||
* strictly-typed enum
|
||||
*/
|
||||
template <class EnumType, class ArrayType>
|
||||
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
|
||||
std::vector<EnumType> converted;
|
||||
converted.reserve(v.size());
|
||||
|
||||
for (auto vi : v) {
|
||||
converted.push_back(static_cast<EnumType>(vi));
|
||||
}
|
||||
return converted;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs a single scale, {Tensorwise (float)}
|
||||
*/
|
||||
bool check_tensorwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs scales, {Rowwise (float)}
|
||||
*/
|
||||
bool check_rowwise_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {RowWise, dp32} for A & B
|
||||
if (recipe_a[0] != ScalingType::RowWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::RowWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Two-level scaling, canonical NVFP4
|
||||
* Both inputs must be fp4
|
||||
* A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
|
||||
*/
|
||||
bool check_nvfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Single-level scaling, what PyT currently understands
|
||||
* Both inputs must be fp4
|
||||
* A, B need 1 scale, {Blockwise_1x16 (e4m3)}
|
||||
*/
|
||||
bool check_nvfp4_recipe_single_scale
|
||||
(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2 scales, 2 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float)
|
||||
*/
|
||||
bool check_deepseek_recipe(ScalingType expected_recipe_a,
|
||||
ScalingType expected_recipe_b,
|
||||
c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
|
||||
if (recipe_a[0] != expected_recipe_a) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
|
||||
if (recipe_b[0] != expected_recipe_b) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp8_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp4
|
||||
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
|
||||
*/
|
||||
bool check_mxfp4_recipe(c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp4
|
||||
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scales, 1 recipes for each input
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
|
||||
using namespace std::placeholders;
|
||||
|
||||
namespace scaled_blas = at::cuda::scaled;
|
||||
using scaled_blas::ScaledGemmImplementation;
|
||||
using scaled_blas::convert_int_to_enum;
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 9> scale_kernel_dispatch = {{
|
||||
{ "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
|
||||
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),
|
||||
{ "tensorwise_tensorwise", scaled_blas::check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
|
||||
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "block_1x128_128x128", std::bind(scaled_blas::check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),
|
||||
ScaledGemmImplementation::BLOCK_1x128_128x128},
|
||||
{ "block_128x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise128x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
|
||||
{ "block_128x128_1x128", std::bind(scaled_blas::check_deepseek_recipe, ScalingType::BlockWise128x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
|
||||
ScaledGemmImplementation::BLOCK_128x128_1x128},
|
||||
{ "block_1x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
|
||||
{ "block_1x128_1x128", std::bind(scaled_blas::check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
|
||||
ScaledGemmImplementation::BLOCK_1x128_1x128},
|
||||
{ "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
|
||||
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
|
||||
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
|
||||
{ "mxfp4_mxfp4", check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}};
|
||||
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
|
||||
{ "nvfp4_nvfp4_single_scale", scaled_blas::check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
|
||||
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
|
||||
{ "mxfp4_mxfp4", scaled_blas::check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}};
|
||||
|
||||
Tensor&
|
||||
_scaled_tensorwise_tensorwise(
|
||||
@ -2596,410 +2264,6 @@ _scaled_mm_cuda_v2(
|
||||
out);
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d
|
||||
// scaling=MXFP8
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
bool b_is_3d = mat_b.dim() == 3;
|
||||
bool is_2d_2d = a_is_2d && b_is_2d;
|
||||
bool is_2d_3d = a_is_2d && b_is_3d;
|
||||
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
|
||||
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
|
||||
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
|
||||
#else
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
|
||||
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
|
||||
#endif
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
|
||||
fbgemm_gpu::mx8mx8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs.value(),
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI");
|
||||
#endif
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
|
||||
|
||||
at::cuda::detail::f8f8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// only being called for rocm
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM)
|
||||
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
// FBGEMM expects B matrix shape to be (.., N, K)
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
|
||||
#endif
|
||||
return out;
|
||||
|
||||
}
|
||||
|
||||
// Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"For grouped FP8 rowwise, both scales must be float32 tensors");
|
||||
#ifndef USE_ROCM
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
#else
|
||||
// NOTE: ignore use_fast_accum
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias")
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
std::optional<c10::ScalarType> out_dtype,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(
|
||||
(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
|
||||
(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
|
||||
"For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
|
||||
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
|
||||
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
|
||||
// MXFP8 grouped GEMM dispatching
|
||||
bool is_mx8mx8bf16 = (
|
||||
mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
|
||||
scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
|
||||
);
|
||||
#else
|
||||
bool is_mx8mx8bf16 = false;
|
||||
#endif
|
||||
|
||||
if (is_mx8mx8bf16) {
|
||||
// Note: Passing implied SwizzleType here, correctness of scale previously checked
|
||||
// in `check_scale` call
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
scale_b,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
|
||||
// If we're not MXFP8, then we're row-wise scaling.
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda_v2(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
if (contraction_dim.size() > 0) {
|
||||
const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]);
|
||||
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
|
||||
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
|
||||
mat_b.size(dim_b));
|
||||
// Note: only (-1, -2) is currently supported
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
|
||||
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
|
||||
"No gemm implementation was found");
|
||||
|
||||
switch (gemm_impl) {
|
||||
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
|
||||
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
swizzle_a_enum[0],
|
||||
scale_b[0],
|
||||
swizzle_b_enum[0],
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
std::optional<c10::ScalarType> out_dtype) {
|
||||
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
|
||||
bool a_b_and_out_are_bf16 = (
|
||||
mat_a.dtype() == at::kBFloat16 &&
|
||||
mat_b.dtype() == at::kBFloat16 &&
|
||||
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
|
||||
);
|
||||
#ifndef USE_ROCM
|
||||
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
|
||||
#else
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
if (use_fast_path) {
|
||||
// fast path, no d2h sync needed
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
|
||||
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
|
||||
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
|
||||
|
||||
@ -494,7 +494,7 @@ void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen)
|
||||
auto value = static_cast<scalar_t>(rand * range + from);
|
||||
// reverse the bounds of curand4 from (0, 1] to [0, 1)
|
||||
// Note that this method is from legacy THCTensorRandom and is likely to give
|
||||
// you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
|
||||
// you more 0-s, since, the probability of getting 1-s is higher than 0-s and
|
||||
// by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
|
||||
// BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
|
||||
auto reverse_bound_value = value == to ? from : value;
|
||||
|
||||
574
aten/src/ATen/native/cuda/GroupedBlas.cpp
Normal file
574
aten/src/ATen/native/cuda/GroupedBlas.cpp
Normal file
@ -0,0 +1,574 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/CUDAScaledBlas.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace scaled_blas = at::cuda::scaled;
|
||||
using scaled_blas::ScaledGemmImplementation;
|
||||
using scaled_blas::convert_int_to_enum;
|
||||
using scaled_blas::_scaled_mm_allowed_device;
|
||||
|
||||
namespace at::native {
|
||||
|
||||
namespace {
|
||||
|
||||
// 2d-2d and 2d-3d
|
||||
// scaling=MXFP8
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
bool b_is_3d = mat_b.dim() == 3;
|
||||
bool is_2d_2d = a_is_2d && b_is_2d;
|
||||
bool is_2d_3d = a_is_2d && b_is_3d;
|
||||
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
|
||||
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
|
||||
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
|
||||
#else
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
|
||||
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
|
||||
#endif
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
|
||||
fbgemm_gpu::mx8mx8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs.value(),
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI");
|
||||
#endif
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// CUDA-only
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
|
||||
|
||||
at::cuda::detail::f8f8bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// 2d-2d and 2d-3d cases
|
||||
// scaling=rowwise
|
||||
// only being called for rocm
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM)
|
||||
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
// FBGEMM expects B matrix shape to be (.., N, K)
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
|
||||
#endif
|
||||
return out;
|
||||
|
||||
}
|
||||
|
||||
// Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda
|
||||
Tensor&
|
||||
_f8_f8_bf16_rowwise_grouped_mm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"For grouped FP8 rowwise, both scales must be float32 tensors");
|
||||
#ifndef USE_ROCM
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
#else
|
||||
// NOTE: ignore use_fast_accum
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias")
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_rocm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
out);
|
||||
#endif
|
||||
}
|
||||
|
||||
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 1,
|
||||
"scale must be a 1D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D, arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(dim) * scale_multiplier,
|
||||
"scale must have the same length as mat for arg ",
|
||||
arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 2,
|
||||
"scale must be a 2D tensor, but got ",
|
||||
scale.dim(),
|
||||
"D for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.stride(1) == 1,
|
||||
"scale must be contiguous in the last dimension for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == mat.size(0),
|
||||
"scale must have the same batch dimension as mat for arg ",
|
||||
arg_idx);
|
||||
TORCH_CHECK(
|
||||
scale.size(1) == mat.size(1 + dim),
|
||||
"scale must have the same first dimension as mat for arg ",
|
||||
arg_idx);
|
||||
}
|
||||
}
|
||||
|
||||
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
|
||||
// that are converted to blocked padded format individually,
|
||||
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim(),
|
||||
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
|
||||
|
||||
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
|
||||
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
|
||||
// * weight is transposed prior to the call, scale stays non-transposed.
|
||||
bool LHS = arg_idx == 0;
|
||||
int scale_dim_to_check = 0;
|
||||
int mat_dim_to_check = LHS ? 0 : 1;
|
||||
TORCH_CHECK(
|
||||
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
|
||||
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
|
||||
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
|
||||
} else {
|
||||
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
|
||||
// so we can check the exact expected scale sizes here without a d2h sync.
|
||||
auto round_up = [](auto x, auto y) {
|
||||
return ((x + y - 1) / y) * y;
|
||||
};
|
||||
|
||||
// TODO: this is for 3d tensor in 2d-3d case specifically.
|
||||
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
|
||||
int64_t G = mat.size(0);
|
||||
int64_t K = mat.size(1);
|
||||
int64_t N = mat.size(2);
|
||||
int64_t blocked_scale_K = round_up(K/32, 4);
|
||||
int64_t blocked_scale_N = round_up(N, 128);
|
||||
|
||||
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim() - 1,
|
||||
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
|
||||
);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
|
||||
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
|
||||
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
|
||||
if (using_fp8_rowwise) {
|
||||
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
|
||||
} else if (using_mxfp8) {
|
||||
_check_scales_mxfp8(mat, scale, dim, arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
std::optional<c10::ScalarType> out_dtype,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(
|
||||
(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
|
||||
(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
|
||||
"For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
|
||||
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
|
||||
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
|
||||
// MXFP8 grouped GEMM dispatching
|
||||
bool is_mx8mx8bf16 = (
|
||||
mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
|
||||
scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
|
||||
);
|
||||
#else
|
||||
bool is_mx8mx8bf16 = false;
|
||||
#endif
|
||||
|
||||
if (is_mx8mx8bf16) {
|
||||
// Note: Passing implied SwizzleType here, correctness of scale previously checked
|
||||
// in `check_scale` call
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
scale_b,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
|
||||
// If we're not MXFP8, then we're row-wise scaling.
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda_v2(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
if (contraction_dim.size() > 0) {
|
||||
const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]);
|
||||
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
|
||||
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
|
||||
mat_b.size(dim_b));
|
||||
// Note: only (-1, -2) is currently supported
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
|
||||
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
|
||||
"No gemm implementation was found");
|
||||
|
||||
switch (gemm_impl) {
|
||||
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
|
||||
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
swizzle_a_enum[0],
|
||||
scale_b[0],
|
||||
swizzle_b_enum[0],
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
std::optional<c10::ScalarType> out_dtype) {
|
||||
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
|
||||
bool a_b_and_out_are_bf16 = (
|
||||
mat_a.dtype() == at::kBFloat16 &&
|
||||
mat_b.dtype() == at::kBFloat16 &&
|
||||
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
|
||||
);
|
||||
#ifndef USE_ROCM
|
||||
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
|
||||
#else
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
if (use_fast_path) {
|
||||
// fast path, no d2h sync needed
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
@ -6,7 +6,7 @@
|
||||
#endif
|
||||
|
||||
// ROCm 6.3 is planned to have these functions, but until then here they are.
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60201
|
||||
#if defined(USE_ROCM)
|
||||
#include <device_functions.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
@ -115,9 +115,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
|
||||
index_t index,
|
||||
const index_t numel,
|
||||
scalar_t value) {
|
||||
#if ( \
|
||||
(defined(USE_ROCM) && ROCM_VERSION < 60201) || \
|
||||
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))
|
||||
gpuAtomicAddNoReturn(
|
||||
reinterpret_cast<at::Half*>(tensor) + index,
|
||||
static_cast<at::Half>(value));
|
||||
@ -160,9 +158,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
|
||||
index_t index,
|
||||
const index_t numel,
|
||||
scalar_t value) {
|
||||
#if ( \
|
||||
(defined(USE_ROCM) && ROCM_VERSION < 60201) || \
|
||||
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
gpuAtomicAddNoReturn(
|
||||
reinterpret_cast<at::BFloat16*>(tensor) + index,
|
||||
static_cast<at::BFloat16>(value));
|
||||
|
||||
@ -154,7 +154,7 @@ REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel)
|
||||
|
||||
// Old style dispatches
|
||||
// torch_cuda_linalg dynamic library should have a global constructor
|
||||
// that calls regiserLinaglDispatch so in order ot lazy bind
|
||||
// that calls registerLinalgDispatch so in order ot lazy bind
|
||||
// old style dispatch all one have to do is to load library and call disp.func_name
|
||||
// Protect from infinite recursion by initializing dispatch to self and checking
|
||||
// that values are different after linalg library were loaded
|
||||
|
||||
@ -121,7 +121,7 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
|
||||
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
|
||||
tmp[u] = op(batch, plane, std::min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
if (x+u*blockDim.x < tensor.size(2))
|
||||
@ -306,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel(
|
||||
stat_accscalar_t var_n = 0;
|
||||
int n = 0;
|
||||
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int UNRL = 4;
|
||||
stat_accscalar_t v_[UNRL];
|
||||
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
|
||||
for (int u = 0; u < UNRL; u++)
|
||||
v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
|
||||
for (int u = 0; u < UNRL; u++) {
|
||||
if (x+u*blockDim.x < input.size(2)) {
|
||||
stat_accscalar_t d1 = v_[u] - avg;
|
||||
n++;
|
||||
avg += d1 / n;
|
||||
var_n += d1 * (v_[u] - avg);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
|
||||
stat_accscalar_t v = input[batch][plane][x];
|
||||
stat_accscalar_t d1 = v - avg;
|
||||
@ -313,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel(
|
||||
avg += d1 / n;
|
||||
var_n += d1 * (v - avg);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// first warpSum to get one value per thread to
|
||||
|
||||
@ -43,6 +43,12 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
|
||||
TORCH_CHECK(k >= 1 && k <= slicesize,
|
||||
"kthvalue(): selected number k out of range for dimension ", dim);
|
||||
|
||||
TORCH_CHECK(
|
||||
slicesize <= std::numeric_limits<int32_t>::max(),
|
||||
"kthvalue(): dimension ", dim, " is too large (", slicesize,
|
||||
"). The current CUDA implementation supports dimension sizes up to ",
|
||||
std::numeric_limits<int32_t>::max());
|
||||
|
||||
at::assert_no_overlap(self, values);
|
||||
|
||||
_reduction_with_indices_allocate_or_resize_output(
|
||||
@ -163,10 +169,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
|
||||
bool keepdim,
|
||||
Tensor& values,
|
||||
Tensor& indices) {
|
||||
// See note [Writing Nondeterministic Operations]
|
||||
// If there are duplicate elements of the kth value, the procedure for choosing which
|
||||
// of the duplicates to use for the indices output is nondeterministic.
|
||||
at::globalContext().alertNotDeterministic("kthvalue CUDA");
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
// `kthvalue_out_impl_cuda` expects contiguous in input `self`.
|
||||
|
||||
@ -65,25 +65,34 @@ __global__ void gatherKthValue(
|
||||
&kValue);
|
||||
|
||||
// Find the index of the k-th highest element
|
||||
index_t kValueIndex = 0;
|
||||
bool foundKValue = false;
|
||||
__shared__ int32_t minIndexFound;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
minIndexFound = static_cast<int32_t>(inputSliceSize);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
|
||||
bool inRange = (i < inputSliceSize);
|
||||
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
|
||||
: static_cast<scalar_t>(0);
|
||||
bool isKValue = inRange &&
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
if (isKValue) {
|
||||
kValueIndex = i;
|
||||
foundKValue = true;
|
||||
break;
|
||||
}
|
||||
// Early exit based on best-so-far
|
||||
if (i >= minIndexFound) {
|
||||
break;
|
||||
}
|
||||
|
||||
scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
|
||||
bool isKValue =
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
|
||||
if (isKValue) {
|
||||
atomicMin(&minIndexFound, static_cast<int32_t>(i));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundKValue) {
|
||||
kthValueSliceStart[0] = kValue;
|
||||
indicesSliceStart[0] = kValueIndex;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
indicesSliceStart[0] = static_cast<index_t>(minIndexFound);
|
||||
kthValueSliceStart[0] = kValue;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// Helper function to compute output pixel range that can contribute to input pixel
|
||||
template <typename accscalar_t>
|
||||
__device__ __forceinline__ void compute_output_range(
|
||||
int input_pos,
|
||||
accscalar_t scale,
|
||||
int output_size,
|
||||
bool align_corners,
|
||||
int& min_output,
|
||||
int& max_output) {
|
||||
accscalar_t lo, hi;
|
||||
if (align_corners) {
|
||||
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
|
||||
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
|
||||
} else {
|
||||
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
}
|
||||
min_output = max(0, static_cast<int>(std::ceil(lo)));
|
||||
max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
C10_LAUNCH_BOUNDS_1(1024)
|
||||
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
const bool align_corners,
|
||||
scalar_t* __restrict__ idata,
|
||||
const scalar_t* __restrict__ odata) {
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
|
||||
const size_t i_numel = nc * width1 * height1;
|
||||
#ifdef USE_ROCM
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
// Decode input pixel coordinates
|
||||
size_t index_temp = index;
|
||||
const int w1 = index_temp % width1;
|
||||
index_temp /= width1;
|
||||
const int h1 = index_temp % height1;
|
||||
const size_t nc_idx = index_temp / height1;
|
||||
|
||||
accscalar_t grad_sum = 0;
|
||||
|
||||
// Find range of output pixels that could interpolate from this input pixel
|
||||
int h2_min, h2_max, w2_min, w2_max;
|
||||
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
|
||||
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
|
||||
|
||||
// Iterate over potential output pixels
|
||||
for (int h2 = h2_min; h2 <= h2_max; h2++) {
|
||||
for (int w2 = w2_min; w2 <= w2_max; w2++) {
|
||||
// Compute source coordinates for this output pixel
|
||||
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rheight, h2, align_corners, /*cubic=*/false);
|
||||
const int h1_base = (int)h1r;
|
||||
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
|
||||
const accscalar_t h1lambda = h1r - h1_base;
|
||||
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
|
||||
|
||||
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rwidth, w2, align_corners, /*cubic=*/false);
|
||||
const int w1_base = (int)w1r;
|
||||
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
|
||||
const accscalar_t w1lambda = w1r - w1_base;
|
||||
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
|
||||
|
||||
// Check if our input pixel participates in this interpolation and accumulate all weights
|
||||
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
|
||||
// to the same pixel, so we need to accumulate weights from all matching positions
|
||||
accscalar_t weight = 0;
|
||||
|
||||
// Check all four interpolation positions and accumulate weights
|
||||
if (h1 == h1_base && w1 == w1_base) {
|
||||
weight += h0lambda * w0lambda; // top-left
|
||||
}
|
||||
if (h1 == h1_base && w1 == w1_base + w1p) {
|
||||
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base) {
|
||||
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
|
||||
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
|
||||
}
|
||||
|
||||
if (weight > 0) {
|
||||
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
|
||||
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write accumulated gradient (no atomics needed)
|
||||
idata[index] = static_cast<scalar_t>(grad_sum);
|
||||
}
|
||||
#else
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
size_t index_temp = index;
|
||||
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
|
||||
true);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
// threads are not covering the whole input tensor.
|
||||
grad_input.zero_();
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
const int num_threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
constexpr bool use_input = true;
|
||||
#else
|
||||
constexpr bool use_input = false;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half, at::ScalarType::BFloat16,
|
||||
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
|
||||
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
|
||||
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
|
||||
input_height,
|
||||
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
|
||||
|
||||
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
|
||||
num_threads,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
@ -133,7 +133,7 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
|
||||
#define CDNA2_OR_LATER 0
|
||||
#endif
|
||||
|
||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
// TODO: Support RDNA
|
||||
@ -1161,7 +1161,7 @@ at::Tensor _weight_int4pack_mm_cuda(
|
||||
auto C_final = at::empty(
|
||||
{m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device()));
|
||||
|
||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
|
||||
do { \
|
||||
@ -1327,7 +1327,7 @@ at::Tensor _convert_weight_to_int4pack_cuda(
|
||||
{nTilesTensor, kSuperTiles, 32, innerKTiles / 2},
|
||||
at::TensorOptions().dtype(at::kInt).device(in.device()));
|
||||
|
||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
dim3 grid(kSuperTiles, nTiles);
|
||||
|
||||
|
||||
@ -1532,7 +1532,7 @@ NvrtcFunction jit_pwise_function(
|
||||
|
||||
std::string file_path;
|
||||
if (cache_dir.has_value()) {
|
||||
// Attemps to read from the cache.
|
||||
// Attempts to read from the cache.
|
||||
// Cubin name is <kernel name>_arch<major>.<minor>_nvrtc<major>.<minor>_<ptx or sass>_<program length>_<string hash>
|
||||
// Note that the SHA1 hash used in the file name is NOT the SHA1 hash of the file's contents,
|
||||
// because we hash on the CUDA code, but we save the compiled ptx or sass
|
||||
|
||||
@ -1346,7 +1346,7 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info)
|
||||
});
|
||||
|
||||
if (input.dim() > 2) {
|
||||
// if upper=true we need to tranpose and conjugate the result tensor
|
||||
// if upper=true we need to transpose and conjugate the result tensor
|
||||
// because the cholesky decomposition is stored in the lower triangular part
|
||||
if (upper) {
|
||||
input.copy_(result.mH());
|
||||
@ -1857,7 +1857,7 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) {
|
||||
|
||||
auto preferred_backend = at::globalContext().linalgPreferredBackend();
|
||||
switch (preferred_backend) {
|
||||
// TODO Investigate whether the following magma bug is still occuring.
|
||||
// TODO Investigate whether the following magma bug is still occurring.
|
||||
// It may be the case that geqrf followed by orgqr is wrong for the magma backend
|
||||
// geqrf_magma currently uses geqrf2_gpu
|
||||
//
|
||||
|
||||
@ -82,7 +82,7 @@ void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const T
|
||||
#if defined(BUILD_LAZY_CUDA_LINALG)
|
||||
namespace cuda { namespace detail {
|
||||
// This is only used for an old-style dispatches
|
||||
// Please do not add any new entires to it
|
||||
// Please do not add any new entries to it
|
||||
struct LinalgDispatch {
|
||||
Tensor (*cholesky_solve_helper)(const Tensor& self, const Tensor& A, bool upper);
|
||||
};
|
||||
|
||||
@ -147,7 +147,7 @@ static void check_shape_forward(const Tensor& input,
|
||||
// blocked format will propagate between layers. Input, output will be in blocked format.
|
||||
//
|
||||
// For inference case, weight can be prepacked into blocked format by
|
||||
// (so as to save weight reoder overhead):
|
||||
// (so as to save weight reorder overhead):
|
||||
// model = torch.utils.mkldnn.to_mkldnn(model)
|
||||
//
|
||||
// For training case, grad_output can be CPU tensor or MKLDNN tensor,
|
||||
@ -723,7 +723,7 @@ Tensor _mkldnn_convolution_transpose(
|
||||
ideep::tensor w = itensor_from_tensor(weight, /*from_const_data_ptr*/true);
|
||||
if (!weight.is_mkldnn()) {
|
||||
// mkldnn transposed convolution has weight in logical order of OIHW or OIDHW,
|
||||
// while PyTorch has IOHW or IODHW, `._tranpose()` switches strides (no memory copy).
|
||||
// while PyTorch has IOHW or IODHW, `._transpose()` switches strides (no memory copy).
|
||||
w.transpose_(0, 1);
|
||||
}
|
||||
|
||||
|
||||
@ -540,7 +540,7 @@ static void _mkldnn_matmul_i8i8i32_with_primitive(
|
||||
args.insert({DNNL_ARG_WEIGHTS, expected_weight});
|
||||
args.insert({DNNL_ARG_DST, dst});
|
||||
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
|
||||
// Create primitve and execute
|
||||
// Create primitive and execute
|
||||
auto primitive = dnnl::matmul(prim_desc);
|
||||
primitive.execute(ideep::stream::default_stream(), args);
|
||||
}
|
||||
|
||||
@ -439,7 +439,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la
|
||||
// I. Memory Formats
|
||||
// a. mkldnn will use plain formats for input, hx/cx, output, hy/cy
|
||||
// and possibly use blocked formats for weights depending shape info.
|
||||
// b. All mkldnn memorys are created (in plain format) as views on ATen tensor,
|
||||
// b. All mkldnn memories are created (in plain format) as views on ATen tensor,
|
||||
// the weight reorder(if any) is handed automatically inside ideep (mkldnn bridge)
|
||||
//
|
||||
// II. MKLDNN Primitive Mapping
|
||||
|
||||
@ -39,7 +39,7 @@ void check_mkldnn_binary_fusion_inputs(
|
||||
inline std::vector<int64_t> padding_r(
|
||||
IntArrayRef padding, IntArrayRef output_padding)
|
||||
{
|
||||
// ConvTranpose padding adjustment
|
||||
// ConvTranspose padding adjustment
|
||||
//
|
||||
// PyTorch uses padding/output_padding:
|
||||
// osize = (isize - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
|
||||
|
||||
@ -75,7 +75,7 @@ bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
|
||||
}
|
||||
|
||||
bool can_use_flash_attention(sdp::sdp_params const& params, bool debug) {
|
||||
// Currently, XPU fallbacks flash attention to overrideable
|
||||
// Currently, XPU fallbacks flash attention to overridable
|
||||
return can_use_overrideable_attention(params, debug);
|
||||
}
|
||||
|
||||
@ -115,7 +115,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
|
||||
// 1. Flash Attention
|
||||
// 2. Math fallback
|
||||
auto& ctx = at::globalContext();
|
||||
// use overrideable linked to onednn as overrideable implementation
|
||||
// use overridable linked to onednn as overridable implementation
|
||||
if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP() &&
|
||||
!ctx.userEnabledFlashSDP()) {
|
||||
return sdp::SDPBackend::error;
|
||||
@ -165,7 +165,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
|
||||
}
|
||||
}
|
||||
// If we have gotten to this point then two things have happened:
|
||||
// 1. can_use_overrideable_attention did not satisfy the constraints to be ran
|
||||
// 1. can_use_overridable_attention did not satisfy the constraints to be ran
|
||||
// 2. The user has explicitly disabled the math kernel
|
||||
// We then re-run the kernel checks with debug enabled to print out the
|
||||
// reason why the kernel was not selected
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
#include <ATen/native/xpu/Blas.h>
|
||||
#include <torch/library.h>
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
|
||||
@ -50,9 +51,13 @@ Tensor& addmm_out(
|
||||
mat1.dtype(),
|
||||
" != ",
|
||||
mat2.dtype())
|
||||
|
||||
// complex case
|
||||
TORCH_CHECK(
|
||||
!mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
|
||||
if (self.is_complex()) {
|
||||
at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
|
||||
result.resize_(result_shape);
|
||||
@ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
|
||||
if (self.is_complex()) {
|
||||
at::native::mm_complex_out_xpu(self, mat2, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
|
||||
return result;
|
||||
@ -208,9 +216,12 @@ Tensor& baddbmm_out(
|
||||
input.sizes());
|
||||
|
||||
// complex case
|
||||
TORCH_CHECK(
|
||||
!batch1.is_complex(),
|
||||
"Complex datatype matmul is not supported in oneDNN");
|
||||
if (input.is_complex()) {
|
||||
at::native::baddbmm_complex_out_xpu(
|
||||
input, batch1, batch2, beta, alpha, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// general case
|
||||
onednn::Attr attr;
|
||||
@ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
|
||||
// complex case
|
||||
if (self.is_complex()) {
|
||||
at::native::bmm_complex_out_xpu(self, batch2, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -215,7 +215,7 @@ partition create_sdpa_graph_partition(
|
||||
// For optional additive mask
|
||||
std::optional<op> mask_add;
|
||||
|
||||
// For optional implicite causal mask
|
||||
// For optional implicit causal mask
|
||||
std::optional<op> mask_gen_idx_row;
|
||||
std::optional<logical_tensor> mask_row_idx;
|
||||
std::optional<op> mask_gen_idx_col;
|
||||
@ -556,7 +556,7 @@ partition create_sdpa_backward_graph_partition(
|
||||
// For optional additive mask
|
||||
std::optional<op> mask_add;
|
||||
|
||||
// For optional implicite causal mask
|
||||
// For optional implicit causal mask
|
||||
std::optional<op> mask_gen_idx_row;
|
||||
std::optional<logical_tensor> mask_row_idx;
|
||||
std::optional<op> mask_gen_idx_col;
|
||||
|
||||
@ -345,7 +345,7 @@ class Attr {
|
||||
dnnl::memory binary_m;
|
||||
auto binary = ops_params_[i].binary_;
|
||||
auto md = ops_params_[i].meta_;
|
||||
// qeury expected_md to achieve peak performance
|
||||
// query expected_md to achieve peak performance
|
||||
auto expected_md = pd.query_md(
|
||||
dnnl::query::exec_arg_md,
|
||||
DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1);
|
||||
|
||||
@ -301,7 +301,7 @@ bool is_onednn_matmul_strides(const at::Tensor& tensor) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// the overlaped cases are not supported
|
||||
// the overlapped cases are not supported
|
||||
dnnl::memory::dims strides = get_onednn_strides(tensor);
|
||||
int64_t storage_size = 1;
|
||||
for (size_t dim = 0; dim < tensor_dim; ++dim)
|
||||
|
||||
@ -29,7 +29,7 @@
|
||||
secondaryTensor:(MPSGraphTensor*)secondaryTensor
|
||||
name:(NSString*)name {
|
||||
// As of MacOS-15.1 m..imumWithNanPropagation is only defined for floating types and calling it with integral
|
||||
// agruments results in
|
||||
// arguments results in
|
||||
// /AppleInternal/Library/BuildRoots/c7c74b64-74b4-11ef-aeda-9635a580fe0d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSKernelDAG.mm:805:
|
||||
// failed assertion `Error getting visible function: (null) Function isNaN_u8_i8 was not found in the library'
|
||||
if (([primaryTensor dataType] & MPSDataTypeFloatBit) == 0) {
|
||||
@ -42,7 +42,7 @@
|
||||
secondaryTensor:(MPSGraphTensor*)secondaryTensor
|
||||
name:(NSString*)name {
|
||||
// As of MacOS-15.1 m..imumWithNanPropagation is only defined for floating types and calling it with integral
|
||||
// agruments results in
|
||||
// arguments results in
|
||||
// /AppleInternal/Library/BuildRoots/c7c74b64-74b4-11ef-aeda-9635a580fe0d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSKernelDAG.mm:805:
|
||||
// failed assertion `Error getting visible function: (null) Function isNaN_u8_i8 was not found in the library'
|
||||
if (([primaryTensor dataType] & MPSDataTypeFloatBit) == 0) {
|
||||
@ -539,7 +539,7 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
|
||||
|
||||
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
|
||||
// Use gather kernel to solve strides for macOS < 15.0
|
||||
// Starting with macOS 15.0, MPS supports native strides direclty in the kernels
|
||||
// Starting with macOS 15.0, MPS supports native strides directly in the kernels
|
||||
if (!is_macOS_15_0_or_newer || !useMPSStridedAPI) {
|
||||
if ((!src.is_contiguous() || src.storage_offset()) && gatherTensorData) {
|
||||
Tensor emptyShell = Tensor();
|
||||
|
||||
@ -222,6 +222,13 @@ struct nextafter_functor {
|
||||
}
|
||||
};
|
||||
|
||||
struct hypot_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b) {
|
||||
return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b));
|
||||
}
|
||||
};
|
||||
|
||||
// Complex binary functors
|
||||
struct polar_functor {
|
||||
template <typename U>
|
||||
@ -362,6 +369,7 @@ struct igammac_functor {
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
|
||||
|
||||
REGISTER_FLOAT_BINARY_OP(hypot);
|
||||
REGISTER_FLOAT_BINARY_OP(copysign);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(copysign);
|
||||
REGISTER_FLOAT_BINARY_OP(fmax);
|
||||
|
||||
16
aten/src/ATen/native/mps/kernels/LinearAlgebra.h
Normal file
16
aten/src/ATen/native/mps/kernels/LinearAlgebra.h
Normal file
@ -0,0 +1,16 @@
|
||||
#pragma onces
|
||||
#include <c10/metal/common.h>
|
||||
|
||||
template <unsigned N = c10::metal::max_ndim>
|
||||
struct OrgqrParams {
|
||||
int32_t num_batch_dims;
|
||||
|
||||
uint32_t m;
|
||||
uint32_t n;
|
||||
uint32_t k;
|
||||
|
||||
::c10::metal::array<uint32_t, N> A_strides;
|
||||
::c10::metal::array<uint32_t, N> tau_strides;
|
||||
::c10::metal::array<uint32_t, N> H_strides;
|
||||
::c10::metal::array<uint32_t, N> H_sizes;
|
||||
};
|
||||
@ -1,3 +1,4 @@
|
||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
|
||||
#include <c10/metal/utils.h>
|
||||
#include <metal_array>
|
||||
#include <metal_simdgroup>
|
||||
@ -640,6 +641,164 @@ kernel void applyPivots(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T bool_to_float(bool b) {
|
||||
return static_cast<T>(b);
|
||||
}
|
||||
|
||||
template <>
|
||||
half2 bool_to_float(bool b) {
|
||||
return half2(b ? 1 : 0, 0);
|
||||
}
|
||||
|
||||
template <>
|
||||
float2 bool_to_float(bool b) {
|
||||
return float2(b ? 1 : 0, 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T calc_H_irc(
|
||||
device T* A,
|
||||
uint32_t A_stride_r,
|
||||
uint32_t A_stride_c,
|
||||
constant T* tau,
|
||||
uint32_t tau_stride,
|
||||
uint32_t r,
|
||||
uint32_t c,
|
||||
uint32_t i) {
|
||||
T I_val = bool_to_float<T>(r == c);
|
||||
T tau_val = tau[i * tau_stride];
|
||||
|
||||
T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]);
|
||||
T A_ri = A[r * A_stride_r + i * A_stride_c];
|
||||
|
||||
T c_eq_i = bool_to_float<T>(c == i);
|
||||
T r_eq_i = bool_to_float<T>(r == i);
|
||||
|
||||
T A_ci_ = (c > i) ? A_ci : c_eq_i;
|
||||
T A_ri_ = (r > i) ? A_ri : r_eq_i;
|
||||
|
||||
return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_));
|
||||
}
|
||||
|
||||
// Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the
|
||||
// result of matrix multiplying A and B together. A and B must be size m-by-m
|
||||
// and have the same strides. The formula for this operation, written in Python
|
||||
// syntax, is:
|
||||
// (A @ B)[r, c] = A[r, :].dot(B[:, c])
|
||||
template <typename T>
|
||||
static T calc_matmul_rc(
|
||||
device T* A,
|
||||
device T* B,
|
||||
uint32_t stride_r,
|
||||
uint32_t stride_c,
|
||||
uint32_t m,
|
||||
uint32_t r,
|
||||
uint32_t c) {
|
||||
T AB_rc = 0;
|
||||
auto A_row_offset = r * stride_r;
|
||||
auto B_col_offset = c * stride_c;
|
||||
|
||||
uint32_t A_col_offset = 0;
|
||||
uint32_t B_row_offset = 0;
|
||||
|
||||
for (uint32_t j = 0; j < m;
|
||||
j++, A_col_offset += stride_c, B_row_offset += stride_r) {
|
||||
AB_rc += c10::metal::mul(
|
||||
A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]);
|
||||
}
|
||||
return AB_rc;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
kernel void orgqr(
|
||||
device T* A [[buffer(0)]],
|
||||
constant T* tau [[buffer(1)]],
|
||||
device T* H [[buffer(2)]],
|
||||
device T* H_prod [[buffer(3)]],
|
||||
constant OrgqrParams<>& params [[buffer(4)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
constant auto& A_strides = params.A_strides;
|
||||
constant auto& tau_strides = params.tau_strides;
|
||||
constant auto& H_strides = params.H_strides;
|
||||
constant auto& H_sizes = params.H_sizes;
|
||||
|
||||
auto num_batch_dims = params.num_batch_dims;
|
||||
auto m = params.m;
|
||||
auto n = params.n;
|
||||
auto k = params.k;
|
||||
|
||||
auto m2 = m * m;
|
||||
auto batch_idx = tid / m2;
|
||||
|
||||
// Find the matrices for this thread's batch index
|
||||
uint32_t A_offset = 0;
|
||||
uint32_t tau_offset = 0;
|
||||
uint32_t H_offset = 0;
|
||||
|
||||
for (auto dim = num_batch_dims - 1; dim >= 0; dim--) {
|
||||
auto dim_size = H_sizes[dim];
|
||||
auto dim_idx = batch_idx % dim_size;
|
||||
|
||||
A_offset += dim_idx * A_strides[dim];
|
||||
tau_offset += dim_idx * tau_strides[dim];
|
||||
H_offset += dim_idx * H_strides[dim];
|
||||
|
||||
batch_idx /= dim_size;
|
||||
}
|
||||
|
||||
A += A_offset;
|
||||
tau += tau_offset;
|
||||
H += H_offset;
|
||||
H_prod += H_offset;
|
||||
|
||||
auto matrix_idx = tid % m2;
|
||||
auto r = matrix_idx / m;
|
||||
auto c = matrix_idx % m;
|
||||
auto A_stride_r = A_strides[num_batch_dims];
|
||||
auto A_stride_c = A_strides[num_batch_dims + 1];
|
||||
auto tau_stride = tau_strides[num_batch_dims];
|
||||
auto H_stride_r = H_strides[num_batch_dims];
|
||||
auto H_stride_c = H_strides[num_batch_dims + 1];
|
||||
|
||||
// Find the element of H and H_prod that this thread will calculate
|
||||
device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c);
|
||||
device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c);
|
||||
|
||||
for (uint32_t i = 0; i < k; i++) {
|
||||
// Calculate and write H_i
|
||||
|
||||
T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i);
|
||||
|
||||
// Calculate element [r, c] of prod(H_0, ..., H_i)
|
||||
if (i == 0) {
|
||||
*H_prod_elem_ptr = H_irc;
|
||||
} else {
|
||||
*H_elem_ptr = H_irc;
|
||||
|
||||
// Need this sync because the below matmul requires all threads to finish
|
||||
// writing their entries to `H_prod` and `H`.
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
T H_prod_0_to_i_rc =
|
||||
calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);
|
||||
|
||||
// Need this sync because the above matmul uses the current values in
|
||||
// `H_prod`, and we don't want to overwrite those until all threads are
|
||||
// finished using them.
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
*H_prod_elem_ptr = H_prod_0_to_i_rc;
|
||||
}
|
||||
}
|
||||
|
||||
device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c);
|
||||
|
||||
if (c < n) {
|
||||
*A_elem_ptr = *H_prod_elem_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MM_OPS(DTYPE) \
|
||||
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
|
||||
constant DTYPE * mat1Data [[buffer(0)]], \
|
||||
@ -679,3 +838,19 @@ INSTANTIATE_MM_OPS(int);
|
||||
INSTANTIATE_MM_OPS(short);
|
||||
INSTANTIATE_MM_OPS(char);
|
||||
INSTANTIATE_MM_OPS(uchar);
|
||||
|
||||
#define REGISTER_ORGQR(T) \
|
||||
template [[host_name("orgqr_" #T)]] \
|
||||
kernel void orgqr<T>( \
|
||||
device T * A [[buffer(0)]], \
|
||||
constant T * tau [[buffer(1)]], \
|
||||
device T * H [[buffer(2)]], \
|
||||
device T * H_prod [[buffer(3)]], \
|
||||
constant OrgqrParams<> & params [[buffer(4)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
REGISTER_ORGQR(float);
|
||||
REGISTER_ORGQR(half);
|
||||
REGISTER_ORGQR(bfloat);
|
||||
REGISTER_ORGQR(float2);
|
||||
REGISTER_ORGQR(half2);
|
||||
|
||||
@ -5,6 +5,21 @@
|
||||
using namespace metal;
|
||||
using namespace c10::metal;
|
||||
|
||||
struct angle_functor {
|
||||
template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
|
||||
inline T operator()(const T x) {
|
||||
return T(atan2(x.y, x.x), 0);
|
||||
}
|
||||
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
|
||||
inline T operator()(const T x) {
|
||||
return T(isnan(x) ? x : x < 0 ? M_PI_F : 0.0);
|
||||
}
|
||||
template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
|
||||
inline float operator()(const T x) {
|
||||
return x < 0 ? M_PI_F : 0.0;
|
||||
}
|
||||
};
|
||||
|
||||
// Implement exp wrapper for both real and complex types
|
||||
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
|
||||
inline T exp_(const T x) {
|
||||
@ -545,6 +560,7 @@ REGISTER_UNARY_OP(abs, float, float);
|
||||
REGISTER_UNARY_OP(abs, half, half);
|
||||
|
||||
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
|
||||
REGISTER_UNARY_OP(angle, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0); \
|
||||
@ -583,6 +599,7 @@ INSTANTIATE_UNARY_KERNELS2(float, int);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, long);
|
||||
|
||||
#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE) \
|
||||
REGISTER_UNARY_OP(angle, DTYPE##2, DTYPE##2); \
|
||||
REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2); \
|
||||
REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2); \
|
||||
REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2); \
|
||||
|
||||
@ -202,6 +202,10 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "igammac");
|
||||
}
|
||||
|
||||
static void hypot_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "hypot");
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
|
||||
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
|
||||
REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel)
|
||||
@ -229,4 +233,5 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
|
||||
REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
|
||||
REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel)
|
||||
REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel)
|
||||
REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel)
|
||||
} // namespace at::native
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
#include <ATen/ops/eq_native.h>
|
||||
#include <ATen/ops/ge_native.h>
|
||||
#include <ATen/ops/gt_native.h>
|
||||
#include <ATen/ops/hypot_native.h>
|
||||
#include <ATen/ops/le_native.h>
|
||||
#include <ATen/ops/logaddexp2_native.h>
|
||||
#include <ATen/ops/logaddexp_native.h>
|
||||
@ -278,22 +277,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
|
||||
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
|
||||
secondaryTensor:twoTensor
|
||||
name:nil]
|
||||
secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor
|
||||
secondaryTensor:twoTensor
|
||||
name:nil]
|
||||
name:nil];
|
||||
return [mpsGraph squareRootWithTensor:sumTensor name:nil];
|
||||
};
|
||||
mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
|
||||
@ -8,6 +8,9 @@
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/kernels/LinearAlgebra.h>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -28,6 +31,7 @@
|
||||
#include <ATen/ops/linalg_solve_triangular_native.h>
|
||||
#include <ATen/ops/lu_unpack_native.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/orgqr_native.h>
|
||||
#include <ATen/ops/slice.h>
|
||||
#include <ATen/ops/stack.h>
|
||||
#include <ATen/ops/triangular_solve_native.h>
|
||||
@ -1235,6 +1239,69 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper
|
||||
}
|
||||
}
|
||||
|
||||
static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) {
|
||||
if (self.numel() == 0) {
|
||||
return self;
|
||||
}
|
||||
|
||||
auto m = self.size(-2);
|
||||
auto n = self.size(-1);
|
||||
auto k = tau.size(-1);
|
||||
|
||||
if (tau.numel() == 0) {
|
||||
auto I = eye(m, self.scalar_type(), std::nullopt, self.device());
|
||||
return self.copy_(I.slice(-1, 0, n));
|
||||
}
|
||||
|
||||
auto num_batch_dims = self.dim() - 2;
|
||||
auto batch_sizes = self.sizes().slice(0, num_batch_dims);
|
||||
|
||||
std::vector<int64_t> H_sizes(num_batch_dims + 2);
|
||||
for (auto dim : c10::irange(num_batch_dims)) {
|
||||
H_sizes[dim] = self.size(dim);
|
||||
}
|
||||
H_sizes[num_batch_dims] = m;
|
||||
H_sizes[num_batch_dims + 1] = m;
|
||||
|
||||
auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous));
|
||||
auto H_prod = at::empty_like(H);
|
||||
|
||||
OrgqrParams params;
|
||||
|
||||
params.num_batch_dims = num_batch_dims;
|
||||
params.m = m;
|
||||
params.n = n;
|
||||
params.k = k;
|
||||
|
||||
for (const auto dim : c10::irange(self.dim())) {
|
||||
params.A_strides[dim] = self.stride(dim);
|
||||
|
||||
if (dim < tau.dim()) {
|
||||
params.tau_strides[dim] = tau.stride(dim);
|
||||
}
|
||||
|
||||
params.H_strides[dim] = H.stride(dim);
|
||||
params.H_sizes[dim] = H.size(dim);
|
||||
}
|
||||
|
||||
auto num_threads = H.numel();
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder();
|
||||
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self)));
|
||||
getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau});
|
||||
[compute_encoder setComputePipelineState:pipeline_state];
|
||||
mtl_setArgs(compute_encoder, self, tau, H, H_prod, params);
|
||||
mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads);
|
||||
getMPSProfiler().endProfileKernel(pipeline_state);
|
||||
}
|
||||
});
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
|
||||
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
|
||||
@ -1471,4 +1538,6 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
|
||||
REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl);
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -158,7 +158,7 @@ static void reduction_out_mps(const Tensor& input_t,
|
||||
IntArrayRef dim = opt_dim.value();
|
||||
for (const auto dim_val : dim) {
|
||||
auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
|
||||
// canSqueeze logic is broken when dim is negative, it introduces off-by-one-erros or crashes
|
||||
// canSqueeze logic is broken when dim is negative, it introduces off-by-one-errors or crashes
|
||||
// See https://github.com/pytorch/pytorch/issues/136132#issuecomment-2354482608
|
||||
if (wrap_dim >= 4 || dim_val < 0) {
|
||||
canSqueezeLastDim = false;
|
||||
@ -1282,7 +1282,7 @@ static void all_any_common_impl_mps(const Tensor& input_t,
|
||||
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
|
||||
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
|
||||
// reductionOrWithTensor:axis: will throw an internal assert if number of dimentions is more than 4
|
||||
// reductionOrWithTensor:axis: will throw an internal assert if number of dimensions is more than 4
|
||||
// See https://github.com/pytorch/pytorch/issues/95538
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
if (input_t.ndimension() > 4) {
|
||||
@ -1352,7 +1352,7 @@ TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
|
||||
// reductionOrWithTensor:axes: will throw an internal assert if number of dimentions is more than 4
|
||||
// reductionOrWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
|
||||
// See https://github.com/pytorch/pytorch/issues/95538
|
||||
if (input_t.dim() > 4) {
|
||||
castInputTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil];
|
||||
@ -1400,7 +1400,7 @@ TORCH_IMPL_FUNC(all_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
|
||||
// reductionAndWithTensor:axes: will throw an internal assert if number of dimentions is more than 4
|
||||
// reductionAndWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
|
||||
// See https://github.com/pytorch/pytorch/issues/95538
|
||||
if (input_t.ndimension() > 4) {
|
||||
castInputTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil];
|
||||
|
||||
@ -34,6 +34,7 @@ REGISTER_UNARY_TI_DISPATCH(sinc);
|
||||
REGISTER_UNARY_TI_DISPATCH(sinh);
|
||||
REGISTER_UNARY_TI_DISPATCH(cosh);
|
||||
REGISTER_UNARY_TI_DISPATCH(tanh);
|
||||
REGISTER_UNARY_TI_DISPATCH(angle);
|
||||
REGISTER_UNARY_TI_DISPATCH(abs);
|
||||
REGISTER_UNARY_TI_DISPATCH(sin);
|
||||
REGISTER_UNARY_TI_DISPATCH(cos);
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
#include <ATen/ops/_copy_from_and_resize.h>
|
||||
#include <ATen/ops/acos_native.h>
|
||||
#include <ATen/ops/acosh_native.h>
|
||||
#include <ATen/ops/angle_native.h>
|
||||
#include <ATen/ops/asin_native.h>
|
||||
#include <ATen/ops/asinh_native.h>
|
||||
#include <ATen/ops/atan_native.h>
|
||||
@ -204,23 +203,6 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
|
||||
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
|
||||
auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
|
||||
return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor angle_mps(const Tensor& self) {
|
||||
const auto float_type = c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)
|
||||
? c10::typeMetaToScalarType(c10::get_default_dtype())
|
||||
: c10::toRealValueType(self.scalar_type());
|
||||
Tensor result = at::empty({0}, self.options().dtype(float_type));
|
||||
return angle_out_mps(self, result);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types");
|
||||
mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
|
||||
@ -19,7 +19,7 @@ namespace at::native::mps {
|
||||
|
||||
// For both scatter and gather kernels, there are 4 specized ones (for 1D to 4D tensor)
|
||||
// and one generic, for 5+D ones. Assumption (to be tested) about specialized kernels
|
||||
// is that reduction of n-dimentional vector, where n is 2, should be slower
|
||||
// is that reduction of n-dimensional vector, where n is 2, should be slower
|
||||
// than reduction of 2D one, as n is not known at compiler time, therefore compiler
|
||||
// could not do loop unrolls, that is
|
||||
// float sum(float* v, int n) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user