Compare commits

...

99 Commits

Author SHA1 Message Date
ace9adfaf0 [Inductor-FX] Don't flatten constant args (#166144)
Summary:
Fallback kernels are created with flattened constant args and an `unflatten` utility to unflatten them when needed. Apply it in the FXConverter to preserve the original structure.
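
For context, the flatten/unflatten round trip in question follows the standard pytree pattern; a minimal sketch (illustrative only, not the FXConverter code):

```python
# Minimal pytree round-trip sketch; illustrative only, not the FXConverter code.
import torch.utils._pytree as pytree

constant_args = {"alpha": 2, "shape": (3, 4), "flags": [True, False]}

flat_args, spec = pytree.tree_flatten(constant_args)  # flattened form passed around
restored = pytree.tree_unflatten(flat_args, spec)     # original structure recovered
assert restored == constant_args
```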


Test Plan:
added new CI tests





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

Differential Revision: D85347589

Pulled By: chenmillie
2025-10-27 09:13:17 -07:00
f89a7e9fe8 [1/N][Fix] Fix typo in aten folder (#166126)
Fix typo in aten folder.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166126
Approved by: https://github.com/cyyever, https://github.com/slayton58
2025-10-27 15:34:39 +00:00
f2c81635c8 [DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh (#165492)
With the concatenate API, we can directly combine two meshes together rather than getting the SPMD mesh from the root mesh.

Differential Revision: [D85409698](https://our.internmc.facebook.com/intern/diff/D85409698)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165492
Approved by: https://github.com/fegin
ghstack dependencies: #163358
2025-10-27 15:33:21 +00:00
e214af6ae8 [Pytorch] Improve float32 erf() on aarch64 (#166262)
Summary:
The float32 data type has a vectorized routine that computes erf(). That function currently calls std::exp() individually for each float in the vector being processed.
We now use sleef's vectorized routine to compute exp, improving the performance of erf.

AVX2/AVX512 also have a custom erf implementation, which uses sleef to compute exp.

We've observed a throughput increase of 25% when tested on tensors containing 1M elements.

Before:
f32 erf: 3175.977us

After:
f32 erf: 2539.446us
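
For reference, a throughput check along these lines can be approximated in a few lines of Python (a hedged sketch; the numbers above come from the operator benchmark, not this script):

```python
# Hedged sketch of a float32 erf microbenchmark on a 1M-element tensor;
# not the operator_benchmark harness referenced in the Test Plan below.
import timeit
import torch

x = torch.randn(1_000_000, dtype=torch.float32)

def run():
    torch.erf(x)

run()  # warm-up
iters = 200
elapsed = timeit.timeit(run, number=iters)
print(f"f32 erf: {elapsed / iters * 1e6:.3f}us")
```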

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Differential Revision: D85522651

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166262
Approved by: https://github.com/fadara01, https://github.com/jgong5, https://github.com/aditew01
2025-10-27 14:55:38 +00:00
7ce723d21c [AOTI] Remove c10 as linked library (#165489)
Summary: AOTI compilation doesn't depend on c10 now. It should only depend on C shim symbols which live in libtorch_cpu or libtorch_cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165489
Approved by: https://github.com/yushangdi
2025-10-27 13:53:44 +00:00
4295a9a158 [xla hash update] update the pinned xla hash (#165895)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned xla hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165895
Approved by: https://github.com/pytorchbot
2025-10-27 11:47:29 +00:00
90d7be35e9 Update slow tests (#165894)
This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/weekly.yml).
Update the list of slow tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165894
Approved by: https://github.com/pytorchbot
2025-10-27 11:42:14 +00:00
8d4e48831e Remove JITFunction constexpr and some arg_names (#166280)
https://github.com/triton-lang/triton/pull/8536 breaks torch.compile integration. This PR attempts to fix it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166280
Approved by: https://github.com/jansel
2025-10-27 09:29:03 +00:00
90b30ebf7e Update torch-xpu-ops commit pin (#166129)
Update the torch-xpu-ops commit to [intel/torch-xpu-ops@8d373b](8d373ba272), which includes:

- Add CONFIGURE_DEPENDS in the install_xpu_headers macro to track these headers
- Add a check to ensure P2P tensors are dense
- Switch philox_engine_inputs usage to philox_xpu_state per the XPU graph request
- Add a vectorization path for maxpool backward channels-last
- Make the SYCL_PRINT macro usable on Windows
- Eliminate an unnecessary warning if AOT is not enabled

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166129
Approved by: https://github.com/EikanWang
2025-10-27 08:17:03 +00:00
173bcda436 Quick fix of torch.save memory leak (#165204)
Fix the memory leak shown in https://github.com/pytorch/pytorch/issues/149846#issuecomment-3392634572
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165204
Approved by: https://github.com/ezyang
2025-10-27 07:50:58 +00:00
6530bc70fb [DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case (#163358)
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note: all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.
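
A hedged usage sketch of the direction described above (the concatenate entry point below is hypothetical; only the submesh slicing uses confirmed public APIs):

```python
# Hypothetical usage sketch: the concatenate call is an assumption based on this
# PR's description, not a confirmed public API; the submesh slicing is real.
from torch.distributed.device_mesh import init_device_mesh

root = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh = root["dp"]   # submeshes sliced from the same root mesh
tp_mesh = root["tp"]

# Instead of reaching back to the root mesh, combine the submeshes directly
# into one SPMD mesh (hypothetical API from this PR):
# spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh])
```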

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163358
Approved by: https://github.com/fegin
2025-10-27 07:39:21 +00:00
4c38887346 [rfc] add debug mode to print meta in fx graphs (#165874)
quite useful in debugging things like unbacked bindings (and presumably other mechanisms that depend on meta, including activation checkpointing and stack trace printing)

<img width="3996" height="748" alt="CleanShot 2025-10-21 at 09 41 54@2x" src="https://github.com/user-attachments/assets/8b885a36-54a5-48b4-a23c-80b39ac7eb12" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165874
Approved by: https://github.com/ezyang
ghstack dependencies: #165893
2025-10-27 07:20:28 +00:00
81fa4a204c Enable Intel GPU on 4 unit test cases (#165405)
For https://github.com/pytorch/pytorch/issues/114850, we will port some ATen unit tests to Intel GPU. We enable Intel GPU with the following methods, trying our best to keep the original code style (see the sketch after this list):

1. Replaced onlyCUDA with onlyOn(['cuda', 'xpu']) for supported tests.
2. Added allow_xpu=True for supported test classes in test parametrization.
3. Used torch.accelerator to extend CUDA-specific tests to XPU where needed.
4. Enabled 'xpu' for some test paths.
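
A hedged sketch of what these changes look like in a test file (the decorator and `allow_xpu` usage follow the list above; exact signatures may differ across PyTorch versions):

```python
# Illustrative sketch only; decorator/parametrization usage follows the list
# above, and exact signatures may differ across PyTorch versions.
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyOn,
)
from torch.testing._internal.common_utils import TestCase, run_tests

class TestFoo(TestCase):
    @onlyOn(["cuda", "xpu"])  # was: @onlyCUDA
    def test_sum(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual(x.sum().item(), 4)

# allow_xpu=True opts the XPU device into the parametrized test class.
instantiate_device_type_tests(TestFoo, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()
```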

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165405
Approved by: https://github.com/guangyey, https://github.com/ezyang
2025-10-27 06:06:07 +00:00
4e6afa8c07 [BE][Opinfo] Mark [c]double as unsupported for MPS (#166213)
Test plan: Run `python ../test/test_ops.py -v -k test_dtypes___radd___mps` when TestCommon parametrization is enabled for MPS
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166213
Approved by: https://github.com/kulinseth, https://github.com/Skylion007
2025-10-27 05:38:36 +00:00
79aa88cc5d Remove old ROCm version checks and branches (#166111)
This PR removes outdated ROCm version checks and their branches. While there is no explicit mention of a minimum supported version, ROCm 6.4 is listed on the installation page and in the CI YAML files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166111
Approved by: https://github.com/ezyang
2025-10-27 05:32:54 +00:00
fa4cb91846 add support for ir scalar literal parsing for inf/-inf/True/False (#163924)
Currently the IR parser doesn't support parsing IR like
```
graph():
  %12 : float = prim::Constant[value=-inf]()
  %13 : float = prim::Constant[value=inf]()
  %14 : bool = prim::Constant[value=True]()
  %15 : bool = prim::Constant[value=False]()
  return (%12)
```

So the Python script below will throw an error.

```
#!/bin/env python
import torch

def test():
    return [True, False]
f = torch.jit.script(test)
torch._C._jit_pass_constant_propagation(f.graph)
ts_str = f.graph.__repr__()
print(ts_str)
ts = torch.parse_ir(ts_str)
func = torch._C._create_function_from_graph("forward", ts)
ret = func()
assert ret == [True, False]

def test():
    return [float("inf"), float("-inf")]
f = torch.jit.script(test)
torch._C._jit_pass_constant_propagation(f.graph)
ts_str = f.graph.__repr__()
print(ts_str)
ts = torch.parse_ir(ts_str)
func = torch._C._create_function_from_graph("forward", ts)
ret = func()
assert ret == [float("inf"), float("-inf")]
```

I add "inf" and bool cases for IRParser::parseScalarLiteral in irparser.cpp.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163924
Approved by: https://github.com/ezyang
2025-10-27 05:10:21 +00:00
c58d0ad85d Propose Out-of-tree Backend Integration (PrivateUse1) as a module and FFFrog as the maintainer (#165958)
I'd like to propose a new module, `Out-of-tree Backend Integration` via the `PrivateUse1` device key. Out-of-tree backend integration via the `PrivateUse1` device key has been the recommended mechanism for plugging third-party accelerator devices into PyTorch. There are already quite a few documents/tutorials on its usage, with the primary one being https://docs.pytorch.org/docs/main/accelerator/index.html.

We have also seen more and more HW vendors leverage the `PrivateUse1` mechanism to support their accelerators. For example:
1. Ascend NPU
2. Microsoft MAIA
3. MooreThreads MUSA
4. Cambricon MLU

The scope of `PrivateUse1`-based out-of-tree backend integration is composed of two parts:
1. `PrivateUse1` device as an out-of-tree backend, which involves:
    (a) making `PrivateUse1` a function-complete device like other in-tree devices: i.e., device runtime, autograd, autocast, profiling, distributed, quantization, etc.
    (b) a pluggable design to allow out-of-tree integrations to extend the functionality of `PrivateUse1`, such as a backend registration mechanism that allows user-friendly device naming, runtime extension points in either C++ or Python for third parties to plug in their runtime implementation, and customizable tensor implementations for third parties to add extra info/functionality to the tensor and its serialization (see the sketch after this list).
2. OpenReg: a test suite and documentation effort to guarantee the functional correctness of the `PrivateUse1` mechanism and to guide HW vendors toward the right implementation.
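
As an illustration of the registration mechanism in 1(b), a minimal sketch built on the existing PrivateUse1 hooks (`my_device` and the stub module are stand-ins, not a real backend):

```python
# Minimal sketch of the PrivateUse1 registration hooks; "my_device" and the
# stub module below are stand-ins, not a real accelerator backend.
import types
import torch

# Give the PrivateUse1 dispatch key a user-friendly device name.
torch.utils.rename_privateuse1_backend("my_device")

# Register a (stub) device module so torch.my_device.* becomes reachable.
my_device_module = types.ModuleType("torch.my_device")
my_device_module.is_available = lambda: True
torch._register_device_module("my_device", my_device_module)

# Generate helpers such as Tensor.is_my_device / Tensor.my_device().
torch.utils.generate_methods_for_privateuse1_backend()
```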

I'm also proposing @FFFrog as the module maintainer for this new module due to his continuous contributions to the design and implementation of both parts of the module. Below are the RFCs/Feature Proposals @FFFrog has been working on:
1. [An improvement of PrivateUse1 mechanism, facilitating third-party backend integration](https://docs.google.com/document/d/1_2EO5A2Ww3xDwqbhIvs9Nk65-jV0oNYg3XAmNUsHdAY/edit?tab=t.0#heading=h.5vt8c1vo4dc7)
2. [The interoperability Standard of Third-party Backend Integration Mechanism](9bd181e742/RFC-0037-Interoperability-Standard-of-3rd-Backend-Integration-Mechanism.md)
3. [PyTorch Backend Accelerator Integration Verification and Guidance](f6048cbd4f/RFC-0045-PyTorch-Accelerator-Integration-Enhancements.md)

@FFFrog has contributed 240+ PRs, and a majority of them are related to `PrivateUse1` (https://github.com/pytorch/pytorch/pulls?q=is%3Apr+author%3Afffrog+). He has also reviewed 50+ PRs related to this area. He is also the primary author of OpenReg.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165958
Approved by: https://github.com/albanD, https://github.com/malfet, https://github.com/ezyang
2025-10-27 05:00:15 +00:00
000f49551b [DeviceMesh] Use _flatten_rank_map to replace _flatten_mesh_list so that we don't need to compare root mesh (#166003) (#166264)
Summary:

Since we already share a flattened tensor `_rank_map` across all meshes from the same root mesh, we can just use a flattened list of it to replace the comparison of root_mesh and flattened_mesh_list (because with the same _rank_map and layout, the mesh tensor is guaranteed to be the same). This way we can also give back the CPU overhead added in https://github.com/pytorch/pytorch/pull/164510 and further simplify the code.

We do have a more ambitious universe-based change here: https://github.com/pytorch/pytorch/pull/165680, but it needs more discussion and would be BC-breaking. We might eventually merge that PR, but probably not now; this change is not BC-breaking and will help concatenate and the 2D integration with concatenate.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

imported-using-ghimport

Test Plan: Imported from OSS

Differential Revision: D85526705

Pulled By: fduwjj

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166264
Approved by: https://github.com/XilunWu
2025-10-27 03:15:15 +00:00
9940e894ea Fix pyrefly ignore syntax in _inductor (#166247)
Ensures pyrefly ignores only ignore the intended error code.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166247
Approved by: https://github.com/oulgen
2025-10-27 02:48:42 +00:00
27302a4932 Fix error suppression syntax in onnx, jit, _dynamo (#166249)
Ensures pyrefly will only silence one specific error code

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166249
Approved by: https://github.com/oulgen
2025-10-27 02:01:54 +00:00
507614ba43 Add GraphModule.recompile_submodules, use for regional inductor (#166002)
This makes it so that `GraphModule.recompile()` will also recompile any submodules that are also graph modules, which allows us to pass all existing regional inductor tests without skipping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166002
Approved by: https://github.com/oulgen
ghstack dependencies: #165996
2025-10-27 01:40:51 +00:00
86f9f1d0ab Enable local tensor model for DTensor redistribute tests (#166081)
The redistribute tests extensively exercise various sharding schemes and
redistribution between them. These tests uncovered more edge cases
that were not supported by the local tensor, primarily different flavors
of uneven sharding. In order to handle these cases, this change implements
missing functional collectives and adds support for the uneven sharding
case where the sharding group (ranks) is larger than the size of the dimension
being sharded. In the latter case the "missing" shards are represented
by zero-sized tensors so that the rest of the local tensor machinery
can stay oblivious to this special case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166081
Approved by: https://github.com/ezyang
2025-10-26 22:21:43 +00:00
154e4d36e9 Fix pyrefly ignore syntax in distributions and ao (#166248)
Ensures existing pyrefly ignores only ignore the intended error code

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166248
Approved by: https://github.com/oulgen
2025-10-26 22:13:48 +00:00
a2b6afeac5 [dynamo][guards] CLASS_MATCH guard for readability (#166217)
We were using FUNCTION_MATCH guard for classes. This was very confusing
(although correct).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166217
Approved by: https://github.com/jansel
2025-10-26 18:35:27 +00:00
262830d86c [dynamo] Repro for 166238 (#166252)
xfail repro for https://github.com/pytorch/pytorch/issues/166238

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166252
Approved by: https://github.com/XuehaiPan, https://github.com/jansel
2025-10-26 18:34:22 +00:00
e4c01011c2 Mark FlexAttentionBackward as cacheable (#165996)
This probably should have been marked cacheable a long time ago, no reason that it isn't.

Test Plan:
New regional inductor tests for test_flex_attention now are serializable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165996
Approved by: https://github.com/oulgen, https://github.com/zou3519, https://github.com/drisspg
2025-10-26 14:39:17 +00:00
a60d9e1f6d Fix flake8 B028 warnings (#166224)
This PR fixes flake8 B028 warnings by specifying stacklevel=2 in `warnings.warn`. The advantage is that users get more contextual information about where PyTorch warnings originate.
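
For illustration, the change amounts to passing stacklevel so the warning points at the caller:

```python
# Illustrative example of the B028 fix: report the caller's file/line
# (stacklevel=2) instead of the line inside the library helper itself.
import warnings

def deprecated_helper():
    warnings.warn(
        "deprecated_helper() is deprecated, use new_helper() instead",
        DeprecationWarning,
        stacklevel=2,
    )

deprecated_helper()
```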

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166224
Approved by: https://github.com/ezyang
2025-10-26 06:18:55 +00:00
f863550192 [dtensor] fix incorrect norm calculation for Partial DTensors (#159856)
The sharding strategies for `aten.linalg_vector_norm` and the optimized `aten._foreach_norm.Scalar` incorrectly assume the norm operation is always "reduction linear" with respect to its inputs. This bug causes the norm to be computed on local, incomplete data for DTensors with a `Partial(sum)` placement, leading to an inflated result (a sum of norms, rather than the correct norm of the sum).

The error can be reproduced with the following script:
```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard

def setup_distributed():
    """Initializes the distributed environment."""
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)

    print(f"Initialized process {rank}/{world_size} on GPU {local_rank}")
    return rank, world_size

rank, world_size = setup_distributed()
assert world_size == 2, "Please run with exactly 2 GPUs for this minimal repro."

mesh = init_device_mesh("cuda", (world_size,))

if rank == 0:
    local_partial = torch.tensor([1.0, 3.0], dtype=torch.float32)
else:
    local_partial = torch.tensor([2.0, 1.0], dtype=torch.float32)

partial_dtensor = DTensor.from_local(local_partial, mesh, [Partial("sum")])
partial_result = torch.linalg.vector_norm(partial_dtensor)
print(
    f"[Rank {rank}] partial_result: {partial_result}, full_tensor: {partial_result.full_tensor()}"
)

shard_dtensor = partial_dtensor.redistribute(mesh, [Shard(0)])
shard_result = torch.linalg.vector_norm(shard_dtensor)
print(
    f"[Rank {rank}] shard_result: {shard_result}, full_tensor {shard_result.full_tensor()}"
)

replicate_dtensor = partial_dtensor.redistribute(mesh, [Replicate()])
replicate_result = torch.linalg.vector_norm(replicate_dtensor)
print(
    f"[Rank {rank}] replicate_result: {replicate_result}, full_tensor {replicate_result.full_tensor()}"
)

full_tensor = partial_dtensor.full_tensor()
full_result = torch.linalg.vector_norm(full_tensor)
print(f"[Rank {rank}] correct_result: {full_result}")
```

Run results show that the norm is `sqrt(1**2 + 3**2) + sqrt(2**2 + 1**2) = sqrt(10) + sqrt(5) = 5.398` instead of `sqrt(3**2 + 4**2) = 5`.
```
$ torchrun --local-ranks-filter 0 --nproc-per-node 2 script.py
Initialized process 0/2 on GPU 0
[Rank 0] partial_result: DTensor(local_tensor=3.1622776985168457, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),)), full_tensor: 5.398345947265625
[Rank 0] shard_result: DTensor(local_tensor=3.0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(_NormPartial(reduce_op='sum', norm_type=2),)), full_tensor 5.0
[Rank 0] replicate_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Replicate(),)), full_tensor 5.0
[Rank 0] correct_result: 5.0
$ torchrun --local-ranks-filter 1 --nproc-per-node 2 script.py
Initialized process 1/2 on GPU 1
[Rank 1] partial_result: DTensor(local_tensor=2.2360680103302, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),)), full_tensor: 5.398345947265625
[Rank 1] shard_result: DTensor(local_tensor=4.0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(_NormPartial(reduce_op='sum', norm_type=2),)), full_tensor 5.0
[Rank 1] replicate_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Replicate(),)), full_tensor 5.0
[Rank 1] correct_result: 5.0
```

This fix simply forces `reduction_linear=False` for partial placements. The output becomes:
```
$ python -m torch.distributed.run --local-ranks-filter 0 --nproc-per-node 2 script.py
Initialized process 0/2 on GPU 0
[Rank 0] partial_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Replicate(),)), full_tensor: 5.0
[Rank 0] shard_result: DTensor(local_tensor=3.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(_NormPartial(reduce_op='sum', norm_type=2),)), full_tensor 5.0
[Rank 0] replicate_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Replicate(),)), full_tensor 5.0
[Rank 0] correct_result: 5.0
$ python -m torch.distributed.run --local-ranks-filter 1 --nproc-per-node 2 script.py
Initialized process 1/2 on GPU 1
[Rank 1] partial_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Replicate(),)), full_tensor: 5.0
[Rank 1] shard_result: DTensor(local_tensor=4.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(_NormPartial(reduce_op='sum', norm_type=2),)), full_tensor 5.0
[Rank 1] replicate_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Replicate(),)), full_tensor 5.0
[Rank 1] correct_result: 5.0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159856
Approved by: https://github.com/ezyang
2025-10-26 05:58:44 +00:00
84b14f3a10 Fix error suppression syntax in utils and nn (#166242)
Fixes syntax for pyrefly : ignores so they only ignore a specific category. No functional changes

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166242
Approved by: https://github.com/oulgen, https://github.com/cyyever
2025-10-26 05:21:07 +00:00
5121499f6b Fix pyrefly ignore syntax in /tools/... (#166240)
Second PR for this - only adjusts the syntax used for the ignores so the suppressions hide only one category of pyrefly errors.

test:
pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166240
Approved by: https://github.com/oulgen
2025-10-26 04:20:16 +00:00
8f80892359 Use correct pyrefly syntax in suppressions distributed/... (#166241)
Updates the pyrefly ignores in the torch/distributed directory to use the correct syntax. No functional changes.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166241
Approved by: https://github.com/oulgen
2025-10-26 04:16:41 +00:00
cdb60e44eb [Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels, giving a 4x improvement for some kernels in an internal workload. More improvements can surely be made here in the future. Removes num_warps from the definition to enable autotune support in the generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
2025-10-26 02:36:15 +00:00
25909d2629 Simplify SingletonOrSharedTypePtr (#166183)
@neildhar pointed out at PTC yesterday that the assumption SingletonOrSharedTypePtr makes about shared_ptr's pointers being either both null or both non-null is incorrect because of the aliasing constructor, and furthermore that SingletonOrSharedTypePtr needn't be as fancy as it is because said constructor exists. (See also https://github.com/pytorch/pytorch/issues/166152 .)

Differential Revision: [D85458769](https://our.internmc.facebook.com/intern/diff/D85458769/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166183
Approved by: https://github.com/Skylion007, https://github.com/cyyever
2025-10-26 01:25:24 +00:00
c7eee49525 Fix pyrefly ignores 1/n (#166239)
First diff adjusting the syntax for pyrefly: ignore suppressions so they only hide one class of type error.

Test:
lintrunner
pyrefly check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166239
Approved by: https://github.com/oulgen
2025-10-26 00:44:10 +00:00
621ba05107 [cuDNN][SDPA] Handle c10:Error when checking device capability for prefer-cuDNN SDPA check (#166201)
A fake-device test can execute this function when the number of visible CUDA devices is 0; fix to unblock #165922.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166201
Approved by: https://github.com/Skylion007
2025-10-25 23:00:06 +00:00
39a70cead1 [feat]: add optimized exp_u20 implementation from Arm Optimized Routi… (#161049)
This patch adds an optimized exp_u20() implementation, based on Arm Optimized Routines (AOR). The legacy svexp_f32_z function is removed, and internal uses (such as in tanh) now leverage the new exp_u20() logic. Unit tests have been updated to cover all scenarios.
The implementation ensures correct handling of edge cases by falling back to exp() for extreme inputs (x ≥ 0x1.5d5e2ap+6f or x ≤ -0x1.5d5e2ap+6f).

Performance:
**Performance Improvements for `aten::scaled_dot_product_attention` (Neoverse-V2, `OMP_NUM_THREADS=16`)**

| Configuration | Current | With Changes (F32) | Speedup |
| -- | -- | -- | -- |
| Batch 1 · 16 Heads · Seq 512 · Q 128 | 654.102 µs | 551.031 µs | 1.19× faster (≈ 19%) |
| Batch 8 · 64 Heads · Seq 2048 · Q 128 | 30.308 ms | 17.142 ms | 1.77× faster (≈ 43%) |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161049
Approved by: https://github.com/fadara01, https://github.com/jgong5

Co-authored-by: Fadi Arafeh <Fadi.Arafeh@arm.com>
2025-10-25 20:44:11 +00:00
d97f6550a2 [Intel GPU] Xpu matmul implementation for complex dtype (#160867)
Enabling complex datatype support for 4 ops on XPU: `mm`, `bmm`, `addmm`, `baddbmm`. From now on, the implementation will call functions created in: https://github.com/intel/torch-xpu-ops/pull/1992.

Additionally added complex datatype tests for matmul operators. More detailed tests are going to be enabled in: https://github.com/intel/torch-xpu-ops/pull/1993

CI runs have found that `test_comprehensive_linalg_eig_xpu` tests were internally calling matmul with a complex datatype. With this PR the test starts to pass, so linalg.eig was removed from `inductor_expected_failures_single_sample["xpu"]`, as otherwise it was failing with an `Unexpected success` message.

Part of: https://github.com/intel/torch-xpu-ops/issues/1853

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160867
Approved by: https://github.com/guangyey, https://github.com/ZhiweiYan-96, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/Silv3S, https://github.com/CuiYifeng, https://github.com/jansel
2025-10-25 17:13:13 +00:00
516e58965a Revert "Export flex attention with kwargs and DTensor (#166045)"
This reverts commit de7fdfe41ad12aec719e3662be58ce9e9bf255a8.

Reverted https://github.com/pytorch/pytorch/pull/166045 on behalf of https://github.com/malfet due to Broke distributed tests, see b55b779ad3/1 ([comment](https://github.com/pytorch/pytorch/pull/166045#issuecomment-3446850955))
2025-10-25 15:47:32 +00:00
b55b779ad3 Add file size limits to linters and refactor grep_linter (#166202)
- Add 1GB file size limits to grep_linter, newlines_linter, codespell_linter
- Refactor grep_linter
  - process files once instead of per-line
  - Extract allowlist check to separate function
  - Add 512KB limit for computing replacements, 100 match limit per file
  - Detect duplicate arguments
- Fix .lintrunner.toml: RAWCUDADEVICE used --pattern twice
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166202
Approved by: https://github.com/Skylion007
2025-10-25 14:57:19 +00:00
74e53d0761 [TorchScript] clearer debug for ConcreteModuleType::findSubmoduleConcreteType (#166192)
Summary:
right now the log is just
```
RuntimeError: it != data_.modules_.end() INTERNAL ASSERT FAILED at "fbcode/caffe2/torch/csrc/jit/frontend/concrete_module_type.cpp":207, please report a bug to PyTorch.
```
we have no clue where the error happens
https://fb.workplace.com/groups/gpuinference/posts/789257990578348/?comment_id=789284783909002&reply_comment_id=789415260562621

Test Plan: UT

Reviewed By: jcwchen

Differential Revision: D80020093

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166192
Approved by: https://github.com/gmagogsfm
2025-10-25 14:07:54 +00:00
798a6d2be1 [Inductor][Autotune] Gracefully restart the autotune process after ULF failure (#166073)
This PR partially fixes https://github.com/pytorch/torchtitan/issues/1791, as it only works with the `TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1` setting.

The core of the problem: in `max-autotune` mode Inductor runs multiple benchmarks to determine the best config. If one of these benchmarks fails with `cudaErrorLaunchFailure`, all other CUDA calls within the same process will fail, including the rest of the benchmarks.

The solution: restart the child process gracefully and continue benchmarking.

Unfortunately, if autotuning is done in the main process, the whole program falls into an unrecoverable state. In this case, the only way to execute successfully would be preventing the ULF in the first place.

Here is some info from [CUDA documentation](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html):
>cudaErrorLaunchFailure = 719
An exception occurred on the device while executing a kernel. ... . This leaves the process in an inconsistent state and any further CUDA work will return the same error. To continue using CUDA, the process must be terminated and relaunched.
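
A hedged sketch of opting into the subprocess autotuning path that this fix relies on (the config attribute mirrors the environment variable above; treat the exact name as an assumption):

```python
# Hedged sketch: run autotuning benchmarks in a subprocess so a
# cudaErrorLaunchFailure can be recovered by restarting the child process.
# Equivalent to setting TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 in the environment.
import torch
import torch._inductor.config as inductor_config

inductor_config.autotune_in_subproc = True

@torch.compile(mode="max-autotune")
def matmul(a, b):
    return a @ b

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
matmul(a, b)
```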

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166073
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
2025-10-25 10:40:59 +00:00
b0e9c86971 [MPS] Move hypot to Metal (#166216)
This also prevents crashes when invoked for integer types. For example, before this change the following crashed:
```
python -c "import torch; print(torch.hypot(torch.randint(0, 10, (3,), device='mps'), torch.randint(0, 10, (3,), device='mps')))"
*** Terminating app due to uncaught exception 'NSInvalidArgumentException', reason: '*** -[__NSDictionaryM setObject:forKey:]: object cannot be nil (key: squareRoot_i64)'
*** First throw call stack:
(
	0   CoreFoundation                      0x0000000194d33ae0 __exceptionPreprocess + 176
	1   libobjc.A.dylib                     0x00000001947f6b90 objc_exception_throw + 88
	2   CoreFoundation                      0x0000000194c7d884 -[__NSDictionaryM setObject:forKey:] + 1288
	3   MPSCore                             0x00000001a1187d0c _ZN12MPSKernelDAG15duodenaryCoreOpEP10BaseTensorS1_S1_S1_S1_S1_S1_S1_S1_S1_S1_S1_RKNSt3__16vectorIlNS2_9allocatorIlEEEE11MPSDataTypePKc + 37044
	4   MPSCore                             0x00000001a113fab0 _ZN12MPSKernelDAGD0Ev + 4256
	5   MPSCore                             0x00000001a1139f6c _ZN12MPSKernelDAG13getDAGAndHashEPU21objcproto10MTLLibrary11objc_objectP14MPSDAGKernelOpP19NSMutableDictionaryIP8NSStringPU22objcproto11MTLFunction11objc_objectEP14NSMutableArrayIS6_ERDv4_yPb + 8	6   MPSCore                             0x00000001a113c7a4 _ZN12MPSKernelDAG13getDAGAndHashEPU21objcproto10MTLLibrary11objc_objectP14MPSDAGKernelOpP19NSMutableDictionaryIP8NSStringPU22objcproto11MTLFunction11objc_objectEP14NSMutableArrayIS6_ERDv4_yPb + 1	7   MPSCore                             0x00000001a11c03c8 _ZN10MPSLibrary19CreateUberShaderKeyEP8NSStringRK23MPSFunctionConstantListyPFPU22objcproto11MTLFunction11objc_objectPU21objcproto10MTLLibrary11objc_objectPK13MPSKernelInfoS4_RK33MPSFunctionConstr	8   MPSNDArray                          0x00000001a27b546c MPSSetResourcesOnCommandEncoder + 154176
	9   MPSNDArray                          0x00000001a27967d8 MPSSetResourcesOnCommandEncoder + 28076
	10  MPSNDArray                          0x00000001a2798ec8 MPSSetResourcesOnCommandEncoder + 38044
	11  MetalPerformanceShadersGraph        0x00000001f97689ac _ZN3GPU17IdentityOpHandler15encodeNDArrayOpEPNS_16EncodeDescriptorEP7NSArray + 436
	12  MetalPerformanceShadersGraph        0x00000001f977f93c _ZN3GPU17StitchedOpHandler8encodeOpEPNS_16EncodeDescriptorE + 924
	13  MetalPerformanceShadersGraph        0x00000001f9544898 _ZN16GPURegionRuntime5runOpIN3GPU23AbsoluteSquareOpHandlerEEEvPN4mlir9OperationEPNS1_16EncodeDescriptorE + 120
	14  MetalPerformanceShadersGraph        0x00000001f9543894 _ZN16GPURegionRuntime8encodeOpEPN4mlir9OperationEPN3GPU16EncodeDescriptorE + 4700
	15  MetalPerformanceShadersGraph        0x00000001f954251c _ZN16GPURegionRuntime29encodeOpWithCommitAndContinueEPN4mlir9OperationEPN3GPU16EncodeDescriptorE + 92
	16  MetalPerformanceShadersGraph        0x00000001f954189c _ZN16GPURegionRuntime11evaluateOpsEPN3GPU16EncodeDescriptorEP7NSArrayIP18MPSGraphTensorDataES7_ + 3572
	17  MetalPerformanceShadersGraph        0x00000001f953f7b4 _ZN10MPSRuntime11evaluateOpsEN4mlir4func6FuncOpEP21RuntimeSpecializationP7NSArrayIP18MPSGraphTensorDataES9_P37MPSGraphExecutableExecutionDescriptorP16MPSCommandBufferbbbPb + 824
	18  MetalPerformanceShadersGraph        0x00000001f988dd38 -[MPSGraphExecutable runInternalWithDevice:commandBuffer:feeds:results:executableExecutionDescriptor:mpsGraphOwnedCommandBuffer:] + 3848
	19  MetalPerformanceShadersGraph        0x00000001f988ca04 -[MPSGraphExecutable runInternalWithDevice:commandBuffer:feedsDictionary:resultsDictionary:executableExecutionDescriptor:mpsGraphOwnedCommandBuffer:] + 608
	20  MetalPerformanceShadersGraph        0x00000001f9728aa0 -[MPSGraph runInternalWithMPSCommandBuffer:feeds:targetTensors:targetOperations:resultsDictionary:executionDescriptor:mpsGraphOwnedCommandBuffer:] + 320
	21  MetalPerformanceShadersGraph        0x00000001f9727b58 -[MPSGraph encodeToCommandBuffer:feeds:targetOperations:resultsDictionary:executionDescriptor:] + 188
	22  libtorch_cpu.dylib                  0x00000001556c9478 ___ZN2at3mps9MPSStream15executeMPSGraphEP8MPSGraphP12NSDictionaryS5_NS0_8SyncTypeE_block_invoke + 128
	23  libdispatch.dylib                   0x0000000194a3985c _dispatch_client_callout + 16
	24  libdispatch.dylib                   0x0000000194a2f7a8 _dispatch_lane_barrier_sync_invoke_and_complete + 56
	25  libtorch_cpu.dylib                  0x00000001556c93e0 _ZN2at3mps9MPSStream15executeMPSGraphEP8MPSGraphP12NSDictionaryS5_NS0_8SyncTypeE + 160
	26  libtorch_cpu.dylib                  0x00000001556fd0f4 _ZN2at6native3mpsL14binaryOpTensorERKNS_6TensorES4_S4_NSt3__112basic_stringIcNS5_11char_traitsIcEENS5_9allocatorIcEEEEU13block_pointerFP14MPSGraphTensorPNS1_19BinaryOpCachedGraphESD_SD_E + 3040
	27  libtorch_cpu.dylib                  0x00000001556ff680 _ZN2at6native24structured_hypot_out_mps4implERKNS_6TensorES4_S4_ + 84
	28  libtorch_cpu.dylib                  0x00000001522682e4 _ZN2at12_GLOBAL__N_117wrapper_MPS_hypotERKNS_6TensorES3_ + 216
	29  libtorch_cpu.dylib                  0x0000000153a1378c _ZN3c104impl28wrap_kernel_functor_unboxed_INS0_6detail24WrapFunctionIntoFunctor_INS_26CompileTimeFunctionPointerIFN2at6TensorENS_14DispatchKeySetERKS6_S9_EXadL_ZN5torch8autograd12VariableType12_G	30  libtorch_cpu.dylib                  0x0000000151241714 _ZN2at4_ops5hypot4callERKNS_6TensorES4_ + 304
	31  libtorch_python.dylib               0x0000000105d9a848 _ZN5torch8autogradL17THPVariable_hypotEP7_objectS2_S2_ + 752
	32  Python                              0x00000001036afa7c cfunction_call + 72
	33  Python                              0x000000010365db08 _PyObject_MakeTpCall + 124
	34  Python                              0x0000000103750f40 _PyEval_EvalFrameDefault + 23304
	35  Python                              0x000000010374b1c8 PyEval_EvalCode + 184
	36  Python                              0x00000001037ab8bc run_eval_code_obj + 88
	37  Python                              0x00000001037a9994 run_mod + 132
	38  Python                              0x00000001037a8fdc PyRun_StringFlags + 124
	39  Python                              0x00000001037a8f08 PyRun_SimpleStringFlags + 64
	40  Python                              0x00000001037cd464 Py_RunMain + 716
	41  Python                              0x00000001037cd950 pymain_main + 304
	42  Python                              0x00000001037cd9f0 Py_BytesMain + 40
	43  dyld                                0x0000000194836b98 start + 6076
)
libc++abi: terminating due to uncaught exception of type NSException
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166216
Approved by: https://github.com/Skylion007
ghstack dependencies: #166210
2025-10-25 08:51:38 +00:00
661a56002f [AI Codemod][DevmateFBSourceTestFailureBot] Fix for T241916639 ("Your diff, D84932408, broke one test") (#166168)
Reviewed By: XilunWu

Differential Revision: D84983164

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166168
Approved by: https://github.com/H-Huang, https://github.com/fduwjj
2025-10-25 06:46:23 +00:00
c9bc00f016 Split grouped_mm methods into their own file (#166140)
Summary:

`Blas.cpp` was getting a little full and hard to work with, split out
the `*_grouped_mm` methods into their own file

Test Plan:

```
pytest -svv -k group test/test_matmul_cuda.py
pytest -svv -k group test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166140
Approved by: https://github.com/drisspg
ghstack dependencies: #166139
2025-10-25 05:40:31 +00:00
ec51b139e1 Factor out shared scaled mm routines (#166139)
Summary:

In preparation for splitting out scaled grouped mm functions, factor out
scaled-specific routines into their own file(s)

Test Plan:

```
pytest -svv test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166139
Approved by: https://github.com/drisspg
2025-10-25 05:40:31 +00:00
eb83c3ca23 Clean up unused Pyrefly suppressions (#166178)
Cleaning up ignores that are no longer needed in the repo and adding select suppressions so the main branch is clean.

test plan:
`lintrunner -a`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166178
Approved by: https://github.com/oulgen
2025-10-25 05:32:21 +00:00
7924e3aacf Remove likely unnecessary _EXPAND trick for non-windows in HIDDEN_NAMESPACE_BEGIN (#166203)
I've learned that the EXPAND trick is needed mostly for an MSVC quirk to properly expand arguments. I tested on Linux locally and suspect that we don't need the _EXPAND for non-windows.  This PR is BE to minimalize what we need and remove what we don't, but I'm also okay not landing this if @malfet tells me that this quirk goes beyond MSVC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166203
Approved by: https://github.com/malfet
ghstack dependencies: #166076, #166077, #166078, #166079
2025-10-25 04:44:07 +00:00
78bcfcf870 [fx] Optimize torch.fx.Node.replace_all_uses_with (#165889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165889
Approved by: https://github.com/aorenste
2025-10-25 03:44:41 +00:00
1e2e7cb18b Add doc for Symmetric Memory (#166148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166148
Approved by: https://github.com/fduwjj
2025-10-25 03:41:15 +00:00
003601a70d Set prefer_deferred_runtime_asserts_over_guards to True (#165820)
Set prefer_deferred_runtime_asserts_over_guards to True and allow a flag to control the behavior, just in case.

This option has enabled the gemma3 model export with transformers==4.57. I am not sure how best to test it, though.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165820
Approved by: https://github.com/titaiwangms
2025-10-25 03:38:19 +00:00
1d58d5fe25 [hops] fix unbacked runtime asserts for cond higher order op (#165893)
At a high level after this fix we get the following nice tlparse https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/bobren/54a57665-7dcc-41e0-8ca7-df01393cd4aa/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

As seen in this doc, previously we were simply dropping asserts post-dynamo: https://docs.google.com/document/d/1nRQwvw_gWL0_9T3VKb5Ly3_tNI1fgqG9WtryeD6qaZI/edit?tab=t.0

The fixes are a couple of things:

1) Actually run the runtime assertion FX graph pass on subgraphs.
2) Reset the fake mode unbacked memo across speculate_subgraph invocations,
   since the memos actually break runtime assertion insertion: calls
   like nonzero end up not allocating new unbacked symints and hence not
   populating pending_unbacked, which then results in incorrect
   unbacked_bindings on FX nodes in subgraphs.

This is a first step in hardening runtime asserts across all phases of
the compiler (eager, aot_eager, inductor, etc.). I will continue kicking
the tires and fixing bugs until we get runtime assert generation in a good
place. One obvious next step: the added test case in this PR fails
when compiled with inductor with the following error (NB: it fails before this PR as well):

```
  File "/data/users/bobren/a/pytorch/torch/_inductor/ir.py", line 659, in get_dtype
    return self.dtype
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
  target: cond
  args[0]: Eq(Mod(s77, 4), 0)
  args[1]: Subgraph(name='true_graph_0', graph_module=<lambda>(), graph=<torch._inductor.graph.SubgraphLowering object at 0x7fbcbb11e110>)
  args[2]: Subgraph(name='false_graph_0', graph_module=<lambda>(), graph=<torch._inductor.graph.SubgraphLowering object at 0x7fbcbb21cf70>)
  args[3]: (s77, TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('cuda:0', torch.float32, size=[s77, s77], stride=[s77, 1]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float32, inner_fn=<function make_pointwise.<locals>.inner.<locals>.inner_fn at 0x7fbcbb2f37f0>, ranges=[s77, s77]))
  )))
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165893
Approved by: https://github.com/zou3519
2025-10-25 03:25:36 +00:00
de7fdfe41a Export flex attention with kwargs and DTensor (#166045)
Fixes #165948

Adding registration of the MaskBlock makes flex attention with kwargs exportable.

Also modified unittests to accept kwargs

```
python test/distributed/tensor/test_dtensor_export.py -k test_flex_attention_dtensor_export

python test/inductor/test_flex_attention.py -k test_pytree_
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166045
Approved by: https://github.com/drisspg
2025-10-25 03:17:22 +00:00
b31bad1b8f [Pytorch] Enable autovec on aarch64 for type conversion (#166049)
Summary:
Implementing autovec template for type conversions on aarch64-NEON

Generated code can be seen here: https://godbolt.org/z/1K6T1d9TE

We've seen significant performance improvements for converting to and from bytes, compiling using clang with -march=armv9-a+sve2:

Before:
float->uint8->float ===> 683.212us
float->int8->float ===> 687.846us
int32->uint8->int32 ===> 497.121us
int32->int8->int32 ===> 481.889us

After:
float->uint8->float ===> 198.204us  ----> 245% higher throughput
float->int8->float ===> 200.241us ----> 244% higher throughput
int32->uint8->int32 ===> 197.970us ----> 151% higher throughput
int32->int8->int32 ===> 198.206us ----> 143% higher throughput
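
A rough Python-level approximation of this round-trip measurement (a hedged sketch; the numbers above come from the internal benchmark, not this script):

```python
# Hedged sketch of a dtype round-trip microbenchmark; the numbers quoted above
# come from the internal benchmark, not this script.
import timeit
import torch

x = torch.randn(1_000_000, dtype=torch.float32)

def roundtrip():
    x.to(torch.uint8).to(torch.float32)

roundtrip()  # warm-up
iters = 200
elapsed = timeit.timeit(roundtrip, number=iters)
print(f"float->uint8->float: {elapsed / iters * 1e6:.3f}us")
```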

Test Plan:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Differential Revision: D85213420

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166049
Approved by: https://github.com/ezyang, https://github.com/mcfi, https://github.com/aditew01
2025-10-25 02:55:50 +00:00
2efcf3ca98 Reverts #163712 and forces allgather/scatter inputs/outputs to be contiguous (#166181)
Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166181
Approved by: https://github.com/kwen2501
2025-10-25 02:43:10 +00:00
761f946043 [ROCm] new implementation of upsample_bilinear2d_backward (#164572)
Changed the implementation from an output-based approach to an input-based one to remove `atomicAdd` operations, and it appears to deliver at least a 20× speedup.

The changes are from Yu-Yun <YuYun.Chang@amd.com>.

# Summary: Refactor of the implementation of the `upsample_bilinear2d_backward` operation on MI300X/MI325X
- The original "scatter-add" approach
  - Each thread, representing an output pixel, scattered gradient contributions to four input pixels, using costly atomic operations on MI300X/MI325X GPUs.
- The new "gather-sum" approach
  - Each thread is responsible for a single input pixel and gathers all relevant gradient contributions from a small, calculated region of the output tensor (done by the `compute_output_range` device function).
# Breakdown of the code changes
- Inversion of the parallelization strategy of the kernel function `upsample_bilinear2d_backward_out_frame`
  - Originally, the main kernel loop was parallelized over the number of elements in the output gradient tensor (`const size_t o_numel = nc * width2 * height2;`).
    - Each thread processed one output pixel.
  - The new loop is parallelized over the number of elements in the input gradient tensor (`const size_t i_numel = nc * height1 * width1;`).
    - Each thread is responsible for calculating the final gradient for a single input pixel.
  - The kernel launch changes accordingly in the function `upsample_bilinear2d_backward_out_cuda_template`.
- Added a device function for calculating the range of output pixels that could have possibly used the input pixel (`input_pos`) during the forward-pass interpolation
  - This is essentially the mathematical inverse of the forward pass.
  - This function tries to prune a thread's search space so that it only needs to inspect a small, local window of the output tensor.
- Gradient calculation approach switching from "scatter-add" to "gather-sum"
  - Scatter-add
    - For each output pixel, the thread calculated 4 gradient contributions and use `fastAtomicAdd` 4 times to add these values to 4 different (and potentially highly contended) memory locations in the input gradient tensor.
  - Gather-sum
    - A thread responsible for one input pixel calls `compute_output_range` to determine the small rectangular region of output pixels that influence the input's final gradient value.
    - The thread iterates through this region, and for each output pixel in the region, it re-calculates the interpolation weights to determine the exact contribution to its specific input pixel.
    - All these contributions are accumulated into a private, per-thread register variable (`accscalar_t grad_sum = 0;`).
      - Without any global memory access, this accumulation is extremely fast.
    - When the loops are done, the thread performs a single, direct write (non-atomic) of the final summed gradient to its designated location in global memory (`idata[index] = static_cast<scalar_t>(grad_sum);`).
# Why performance gets boosted
- Analysis of the root cause of performance drop
  - Ref. (internal only) - https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1140493327/PyTorch__upsample_bilinear2d_backward
- First and foremost, elimination of the contention of atomic operations
  - Many parallel threads called `atomicAdd` frequently attempting to update the exact same memory location in the input gradient tensor at the same time.
    - The GPU's memory controller has to serialize these operations, effectively nullifying the benefit of parallel capability at those contention points.
  - The MI300X/MI325X chiplet-based CDNA 3 architecture amplified the issue.
    - When contending threads reside on different XCDs, resolving the atomic operation requires high-latency coherence traffic across the Infinity Fabric interconnect.
  - The implementation change eliminates hardware-level serialization and cross-chiplet coherence traffic caused by many `atomicAdd`.
- Improved memory access pattern and locality
  - Write coalescing
    - The regular sum writes `idata[index] = static_cast<scalar_t>(grad_sum);` can be perfectly coalesced by GPUs.
  - Read locality
    - Even though there are many (potentially repeated) reads from the output tensor (`static_cast<accscalar_t>(odata[output_idx])`), these are highly cache-friendly, meaning the data for one thread is likely to be in the L1 or L2 cache already due to an access from a neighboring thread.
- Trade-off: computation for memory synchronization
  - The recalculation of interpolation weights fits well on high-computational-throughput modern GPUs like MI300X/MI325X.
  - Removal of atomic operations avoids expensive memory synchronization.
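
A hedged sketch of how the determinism and correctness of the new gather-sum backward can be checked from Python (this exercises the kernel through autograd; it does not reimplement it):

```python
# Hedged verification sketch: run the bilinear upsample backward twice on the GPU
# and compare bitwise (determinism), then compare against the CPU reference.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 17, 23, device="cuda", requires_grad=True)
grad_out = torch.randn(2, 3, 34, 46, device="cuda")

def backward_once():
    y = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
    (g,) = torch.autograd.grad(y, x, grad_out)
    return g

g1, g2 = backward_once(), backward_once()
print(torch.equal(g1, g2))  # no atomics -> repeated runs should match bitwise

x_cpu = x.detach().cpu().requires_grad_(True)
y_cpu = F.interpolate(x_cpu, scale_factor=2, mode="bilinear", align_corners=False)
(g_cpu,) = torch.autograd.grad(y_cpu, x_cpu, grad_out.cpu())
print(torch.allclose(g1.cpu(), g_cpu, atol=1e-5))  # matches the CPU reference
```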

---

Optimizations of `grid_sampler_2d_backward` will be addressed in a separate PR.
Doc for reference: (internal only) https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1162750701/PyTorch__grid_sampler_2d_backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164572
Approved by: https://github.com/jeffdaily

Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com>
2025-10-25 02:39:24 +00:00
8aa465f18e [MPS] Migrate angle to Metal ops (#166210)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166210
Approved by: https://github.com/Skylion007
2025-10-25 01:52:33 +00:00
0a5d68d92d [dynamo] Remove unnecessary NAME_MATCH guard (#166112)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166112
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166155
2025-10-25 01:27:42 +00:00
42bd210fff [dynamo] Avoid ID_MATCH on methods - use CLOSURE_MATCH on functions (#166155)
`id` on methods can change from invocation to invocation. Here we guard on
`__code__` objects, which do not change.
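
A quick illustration of why identity-based guards on bound methods are unstable while the code object is stable:

```python
# Bound methods are re-created on each attribute access, so identity-based
# guards on them are unstable; the underlying code object stays the same.
class Foo:
    def bar(self):
        return 42

f = Foo()
m1, m2 = f.bar, f.bar
print(m1 is m2)                    # False: two distinct bound-method objects
print(m1.__code__ is m2.__code__)  # True: the code object does not change
```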

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166155
Approved by: https://github.com/jansel
2025-10-25 01:27:42 +00:00
1d13c314b3 [OpenReg] Remove the Unnecessary Fallback Implementation for AutogradPrivate1 (#165316)
As the title stated.

The fallback for AutogradPrivateUse1 is built into PyTorch, so there is no need to register a general implementation for out-of-tree backends.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165316
Approved by: https://github.com/ezyang, https://github.com/albanD
ghstack dependencies: #165315
2025-10-25 01:27:27 +00:00
0c9763a5a0 [Autograd] Add Default Autograd Fallback for PrivateUse1 in PyTorch (#165315)
Please refer to this [link](https://github.com/pytorch/pytorch/issues/163979) for more background.

- Allow registering a fallback for AutogradPrivateUse1 multiple times.
- Add an Autograd fallback implementation for AutogradPrivateUse1.

PyTorch can provide a common implementation for AutogradPrivateUse1, and the user can override it based on the needs of a specific accelerator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165315
Approved by: https://github.com/albanD
2025-10-25 01:27:27 +00:00
79a4a9c02e Fix race condition and make CUDA kthvalue deterministic (#165762)
The gatherKthValue kernel had a race condition where multiple threads could write to the same output location without synchronization when duplicate k-th values exist, resulting in non-deterministic output.

Changes:
- aten/src/ATen/native/cuda/Sorting.cu: Use atomicMin with shared memory to deterministically find minimum index. Add early termination and remove redundant inRange checks. (We have to cast the index to `int32_t`, but this is already assumed to fit earlier in the kernel.)
- aten/src/ATen/native/cuda/Sorting.cpp: Remove non-deterministic alert since kthvalue is now deterministic on CUDA.
- torch/__init__.py: Remove kthvalue from non-deterministic operations list and remove kthvalue example from use_deterministic_algorithms() docstring.
- test/test_torch.py: Remove test_nondeterministic_alert_kthvalue since kthvalue no longer raises alerts on CUDA.

Benefits:
- Deterministic: always returns minimum index when duplicates exist
- Potential performance improvement on large arrays with repetitions

Test Results:
- All existing PyTorch tests pass (test_kthvalue)
- Custom determinism tests confirm consistent results
- Custom CUDA vs CPU correctness validated across 50+ scenarios
- Custom performance benchmarks show improvements with no visible regressions

Addresses #165227
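
For example, ties now resolve the same way on every run (a small sketch, assuming a CUDA build):

```python
# Small sketch, assuming a CUDA build: with duplicate k-th values, the kernel
# now deterministically returns the smallest index among the ties.
import torch

x = torch.tensor([3.0, 1.0, 1.0, 2.0], device="cuda")
results = [torch.kthvalue(x, 2) for _ in range(3)]
print([(r.values.item(), r.indices.item()) for r in results])
# Expected: three identical (1.0, 1) pairs.
```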

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165762
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-25 00:45:57 +00:00
9d0b77f4cd [10/N] Apply ruff UP035 rule (#165709)
This is a follow-up of #165515. ruff `UP035` rules are applied to  dynamo code to use Py 3.10+ typing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165709
Approved by: https://github.com/ezyang
2025-10-25 00:20:13 +00:00
d486eee234 Hide APIs in torch::headeronly (#166079)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166079
Approved by: https://github.com/malfet, https://github.com/cyyever
ghstack dependencies: #166076, #166077, #166078
2025-10-25 00:18:26 +00:00
cddd5f74ab Hide stable Library structs instead of using anon namespace (#166078)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166078
Approved by: https://github.com/malfet
ghstack dependencies: #166076, #166077
2025-10-25 00:18:26 +00:00
dfdb68e51f Hide all APIs in torch::stable (#166077)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166077
Approved by: https://github.com/malfet
ghstack dependencies: #166076
2025-10-25 00:18:26 +00:00
98c818320a Add HIDDEN_NAMESPACE_BEGIN and END macros for hiding header APIs (#166076)
Spurred by the conversation started in https://github.com/pytorch/pytorch/issues/163343.

Context:
* Header implementations may be inlined _but_ are not necessarily inlined, even when using the `inline` keyword.
* When someone wants to use multiple extensions in the same runtime, e.g., with FA3 and AO, then 2 `.so`s are loaded that may have been built with different libtorch versions. Thus, if an API is not inlined and is implemented differently, one implementation will be arbitrarily picked up and used across the runtime, depending on link order. This is bad!
* Consequently, we need to be very good at guaranteeing that we don't modify header implementations within a namespace. This is easy to mess up by accident, which would be a dire mistake.

Solution:
In essence, we want APIs in torch::headeronly and torch::stable to be visible in each individual extension only, and nowhere else. We want to hide these symbols! Thankfully, pybind already solved this problem (thanks @malfet for bringing that to my attention). This PR is heavily inspired by the code in pybind here: e6984c805e/include/pybind11/detail/pybind11_namespace_macros.h (L73-L82).

In this PR, we introduce the macros for defining hidden namespaces in PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166076
Approved by: https://github.com/malfet
2025-10-25 00:18:26 +00:00
cc20b7ad72 [FlexFlash] update names (#166193)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166193
Approved by: https://github.com/BoyuanFeng
2025-10-25 00:07:11 +00:00
bc11a42b3f [inductor][ez] fix score fusion memory typo (#166029)
Fix https://github.com/pytorch/pytorch/issues/165724 .
The typo does not affect the compilation result. It just affects compilation time a little bit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166029
Approved by: https://github.com/eellison
2025-10-24 23:48:05 +00:00
4fc06f2e0a Use std::min for #165927 (#166199)
Summary: Like D85463674 (pr https://github.com/pytorch/pytorch/pull/166195) but for D85357351 (https://github.com/pytorch/pytorch/pull/165927)

Differential Revision: D85464917

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166199
Approved by: https://github.com/Camyll, https://github.com/malfet, https://github.com/Skylion007
2025-10-24 23:19:00 +00:00
82473c3d59 [torch.export] Add original module type to UnflattenedModule class (#166145)
Summary: Currently, all submodules of an UnflattenedModule carry their original type name. This diff adds the original module type to the UnflattenedModule class itself.

Test Plan:
```
buck test mode/opt caffe2/test:test_export
```
https://www.internalfb.com/intern/testinfra/testrun/17732923654320197

Differential Revision: D85373454

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166145
Approved by: https://github.com/angelayi
2025-10-24 22:47:29 +00:00
b6a4236e5d [label_to_label] minor updates (#166172)
vllm-compile implies "module: vllm" and "oncall: pt2".
The Flex -> HigherOrderOperators auto-labeling is too noisy,
plus we have a different set of folks looking at each, so I'm going to
make that not automatic anymore. We can still manually label flex issues
as higher order operator issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166172
Approved by: https://github.com/angelayi
2025-10-24 22:47:23 +00:00
b04173be9b [ONNX] Add a test to backed_size_oblivious patch in onnx (#166196)
Follow-up https://github.com/pytorch/pytorch/pull/166151

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166196
Approved by: https://github.com/justinchuby
2025-10-24 22:47:10 +00:00
32ac38f85d [lint] workflow consistency linter to look at all files instead of just changed files (#165171)
As in title

If you change only one workflow file, lintrunner (with the default args, which is also what CI uses since it only passes changed files) won't look at other files in the repo, but the sync-tag might come from those other files.

This makes the linter look at all workflow files, so it will catch those failures.

Also change the output line so it prints which file and which job it differs from.

Pros:
catches errors

Cons:
unusual behavior (getting around what lintrunner says the linter should run on)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165171
Approved by: https://github.com/malfet, https://github.com/izaitsevfb, https://github.com/atalman
2025-10-24 21:43:18 +00:00
c9b49e506e [MPS] Add linalg.householder_product for MPS (#166090)
Fixes #166089
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166090
Approved by: https://github.com/malfet
2025-10-24 21:13:56 +00:00
6038e476e8 [Dynamo][Logging] Fix regression on stack adding to latest bytecode by adding verbose check (#165926) (#165946)

[ghstack-poisoned]

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165946
Approved by: https://github.com/williamwen42
2025-10-24 20:36:50 +00:00
2c851c16e5 [FX][ez] fix the split_module tutorial code (#166154)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166154
Approved by: https://github.com/BoyuanFeng
2025-10-24 20:16:04 +00:00
31584f2d91 Add a Claude skill for writing docstrings. (#166175)
Generated with prompt:

> torch/_tensor_docs.py and torch/nn/functional.py contain the "gold standard" for docstrings in the PyTorch project. Write a skill describing how to write a docstring for a function/method in the PyTorch project. Note that add_docstring is specifically for C binded functions; a native Python function can just be a direct docstring. Sphinx is used to generate docs.

Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166175
Approved by: https://github.com/Skylion007
2025-10-24 20:05:44 +00:00
0442125362 [Inductor] Restore original dtype for rank-0 CPU tensors (#166118)
# Problem
Inductor implicitly upcasts certain rank-0 kernel arguments from float16 to float32. Currently, this happens only on the `"cpu"` device, which appears to be related to float16 support in CPU Triton. However, it can also affect the behavior of GPU kernels, when a model contains tensors from multiple devices. Upcasting may be undesirable on some platforms, so users can typically disable it with the `config.triton.codegen_upcast_to_fp32` flag. However, this flag was not respected by the rank-0 kernel argument codepath.

Through an improbable series of events, float32 upcasting caused an internal model to fail compilation on MTIA. (Internal reviewers see T242444110.)

# Fix
If `config.triton.codegen_upcast_to_fp32` evaluates to `False`, cast the kernel argument to the original dtype.
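
A small sketch of the user-facing knob involved; the mixed-device call below is an illustrative assumption rather than the internal repro:

```python
import torch
import torch._inductor.config as inductor_config

# Opt out of the implicit fp16 -> fp32 upcast in generated Triton code; with this
# fix the rank-0 CPU argument below keeps (or is cast back to) float16 as well.
inductor_config.triton.codegen_upcast_to_fp32 = False

def fn(x, scalar):
    return x * scalar

compiled = torch.compile(fn)
out = compiled(
    torch.randn(8, device="cuda", dtype=torch.float16),
    torch.tensor(2.0, dtype=torch.float16),  # rank-0 CPU tensor
)
```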

# Test plan
Added a new CI test checking for the downcast iff the config flag is false. The test mixes GPU and CPU tensors to generate a GPU kernel with the implicit float32 upcast and explicit float16 downcast.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166118
Approved by: https://github.com/jfix71, https://github.com/jansel, https://github.com/kundaMwiza
2025-10-24 19:59:25 +00:00
fdcf402d82 vllm test build (#166146)
Fix the vllm test build; it's broken due to the flashinfer dependency.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166146
Approved by: https://github.com/huydhn
2025-10-24 19:18:10 +00:00
13cda9b89e Allow BlockDescriptorOptions classes to be overridden In TritonKernel (#165899)
By allowing the options classes (`BlockPtrOptions`/`TensorDescriptorOptions`) to be overridden in `TritonKernel`, subclasses with custom behaviour can be used in place of them, which provides greater flexibility.
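
A rough sketch of the kind of subclassing this enables; the import path and the `block_ptr_options_cls` attribute below are assumptions, since the actual override hook isn't shown in this log:

```python
# Hypothetical sketch: module path and override attribute are assumptions.
from torch._inductor.codegen.triton import BlockPtrOptions, TritonKernel

class MyBlockPtrOptions(BlockPtrOptions):
    """Subclass where custom block-pointer codegen behaviour would live."""
    pass

class MyTritonKernel(TritonKernel):
    # Hypothetical override point: have the kernel build the subclass
    # instead of the stock BlockPtrOptions.
    block_ptr_options_cls = MyBlockPtrOptions
```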

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165899
Approved by: https://github.com/jansel
2025-10-24 18:59:59 +00:00
fa6d911dda [MPS] Sparse mul enable tests and fix on MPS (#166164)
Apparently the mul tests in test_sparse were disabled. The dense representation (i.e. when nnz is not a scalar) was broken on MPS. This PR fixes it and enables the tests in test_sparse.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166164
Approved by: https://github.com/malfet
2025-10-24 18:30:30 +00:00
0db6bcc015 Fix accuracy for layernorm/rmsnorm benchmarking (#166005)
Example command:
    python benchmarks/dynamo/genai_layers/benchmark.py --exit-on-accuracy-failure --tolerance=1e-2 rmsnorm_backward

Fix the accuracy problem for layernorm/rmsnorm fwd/bwd.
Also fix some quack calls (maybe due to a quack API change).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166005
Approved by: https://github.com/BoyuanFeng
2025-10-24 18:14:51 +00:00
60ac039998 [CUDA][Grouped Gemm] remove xFail on Group GEMM tests after fallback was added (#165378)
https://github.com/pytorch/pytorch/pull/162059 means we get unexpected successes now on e.g., SM 12.0

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165378
Approved by: https://github.com/Skylion007
2025-10-24 17:42:40 +00:00
380d440d1c Revert "inductor: avoid unrolling argmin/argmax reductions to preserve index … (#164040)"
This reverts commit 9038a30cee56e0d577a666fffa32e990732572d4.

Reverted https://github.com/pytorch/pytorch/pull/164040 on behalf of https://github.com/karthickai due to Kindly add the test case mentioned in the issue ([comment](https://github.com/pytorch/pytorch/pull/164040#issuecomment-3444137989))
2025-10-24 17:14:45 +00:00
9038a30cee inductor: avoid unrolling argmin/argmax reductions to preserve index semantics on views; add regression test for transposed mutation (fixes #163929) (#164040)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164040
Approved by: https://github.com/ngimel, https://github.com/jansel
2025-10-24 16:37:43 +00:00
690c8c13b9 Revert "Export should use aot_export_joint_with_descriptors (#165931)"
This reverts commit 882b834082719afd8ee41769c2cb224bc9031632.

Reverted https://github.com/pytorch/pytorch/pull/165931 on behalf of https://github.com/clee2000 due to breaking internal tests D85084301 for test_auto_functionalize? I checked that they did run on OSS CI so I'm not entirely sure what's going on, I assume it's the IS_FBCODE stuff ([comment](https://github.com/pytorch/pytorch/pull/165931#issuecomment-3443887361))
2025-10-24 16:02:20 +00:00
28ee6b62ed Revert "[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case (#163358)"
This reverts commit 5a4997dcae47acf69c929ac5b081321143bfbf11.

Reverted https://github.com/pytorch/pytorch/pull/163358 on behalf of https://github.com/clee2000 due to probably need to revert this one too, it's stacked with https://github.com/pytorch/pytorch/pull/166003#issuecomment-3443668389 ([comment](https://github.com/pytorch/pytorch/pull/163358#issuecomment-3443874910))
2025-10-24 15:58:54 +00:00
81577bdb3f Revert "[DeviceMesh] Use _flatten_rank_map to replace _flatten_mesh_list so that we don't need to compare root mesh (#166003)"
This reverts commit 8625ffbd45884464f736cfc61300c14f47633641.

Reverted https://github.com/pytorch/pytorch/pull/166003 on behalf of https://github.com/clee2000 due to failing internal tests D85405179 I believe there are uses of _flatten_mesh_list internally that need to be updated ([comment](https://github.com/pytorch/pytorch/pull/166003#issuecomment-3443668389))
2025-10-24 15:14:23 +00:00
e67e3d95f3 Simplify the CUPTI CMake check for kineto (#161370)
Simplify the CUPTI check because kineto has used `CUDA::cupti`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161370
Approved by: https://github.com/Skylion007
2025-10-24 08:13:17 +00:00
27af8480ea Refactor api and configs of overlapping (#166130)
- pass important config values directly into the class
- migrate those configs from `test_configs` to another class
- add an (off by default) config to enable inside inductor, instead of requiring a custom post pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166130
Approved by: https://github.com/bdhirsh
2025-10-24 07:03:54 +00:00
6494cdc40c [DebugMode] add nn.Module tracking (#165498)
Uses ModTracker to record nn.Module entries, much like CommDebugMode.

Can be switched on with `DebugMode(record_nn_module=True)`:
```
    [nn.Mod] Bar
      [nn.Mod] Bar.abc
        [nn.Mod] Bar.abc.l1
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
        [nn.Mod] Bar.abc.l2
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
      [nn.Mod] Bar.xyz
        aten::t(t: f32[4, 4])
        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])"""
```
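
A minimal usage sketch mirroring the trace above (the `torch.utils._debug_mode` import path and the `debug_string()` accessor are assumptions, not taken from this log):

```python
import torch
from torch.utils._debug_mode import DebugMode  # import path is an assumption

class Abc(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Abc()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))

with DebugMode(record_nn_module=True) as debug_mode:
    Bar()(torch.randn(4, 4))

# Dumps the nested [nn.Mod] / aten:: trace shown above.
print(debug_mode.debug_string())  # accessor name is an assumption
```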

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165498
Approved by: https://github.com/SherlockNoMad
2025-10-24 05:08:33 +00:00
ac7074efa2 [CUDA][cuBLAS] Fix a compilation issue in #163955 when CUDA_VERSION < 12010 (#166137)
Summary:
This PR fixes a compilation issue when `CUDA_VERSION < 12010`. Even if we might drop old CUDA support, let's correct the code itself.

## Issue
When `CUDA_VERSION` is less than `12010`, the following does not compile.
```
      mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
      mat2_sizes[0] > 1 && mat2_sizes[1] > 1
      #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
      // "&&" is missing here
      mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
```
This patch adds the missing "&&".

Test Plan: CI

Differential Revision: D85356831

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166137
Approved by: https://github.com/ngimel, https://github.com/cyyever
2025-10-24 04:06:03 +00:00
263901cec4 [pytorch/kineto] Update Kineto Submodule (#166150)
Summary: Update to include some race condition fixes.

Test Plan: n/a

Differential Revision: D85390799

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166150
Approved by: https://github.com/sraikund16, https://github.com/cyyever
2025-10-24 04:03:13 +00:00
c12293dcbe [ONNX] Cover all FX passes into backed size oblivious (#166151)
Found a bug where, after `run_decompositions()`, the shape could be fixed to 1. It's caused by the fact that all FX graph surgery (related to shape inference) should happen inside the backed_size_oblivious patch.

```python
import torch
from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm

# Previous to this PR, this will generate a fixed batch size
op = torch.onnx.export(
    Phi3RMSNorm(256).eval(),
    args=(),
    kwargs={"hidden_states": torch.rand((1, 32, 256))},
    dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
)

# It is dynamic when it's only in torch.export
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    ep = torch.onnx.export(
    Phi3RMSNorm(256).eval(),
    args=(),
    kwargs={"hidden_states": torch.rand((1, 32, 256))},
    dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
)
# But when run_decomposition is called outside of the patch, it is static.
# ep = ep.run_decompositions()
print(ep)

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166151
Approved by: https://github.com/justinchuby
2025-10-24 03:25:16 +00:00
5a4997dcae [DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case (#163358)
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a big mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note: every submesh needs to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.
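
A rough sketch of the intended usage; the exact entry point isn't spelled out in this log, so `_concatenate` below is a placeholder name:

```python
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# Assumes a distributed launch (e.g. torchrun) with 8 ranks.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

# Hypothetical call: stitch the submeshes back into one SPMD mesh directly,
# instead of re-slicing it from the root mesh. Both inputs must come from the
# same root mesh.
spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh])  # placeholder API name
```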

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163358
Approved by: https://github.com/fegin
ghstack dependencies: #166003
2025-10-23 23:31:17 +00:00
47f638eae7 [ROCm] deserialize loads in planer sum portion of stats() of norm (#166021)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166021
Approved by: https://github.com/jeffdaily
2025-10-23 22:47:42 +00:00
882b834082 Export should use aot_export_joint_with_descriptors (#165931)
This diff moves export run_decompositions to use aot_export_joint_with_descriptors instead of aot_export_module. Doing so, I ran into 2 main bugs:
1) aot_export_joint_with_descriptors doesn't correctly pass in the record_nn_module_stack flag that is needed to populate nn_module_stack when switching the internal tracer.
2) When creating a symint with negative inputs, we need to pass in positive=False. This didn't matter before because aot_autograd directly returns integer inputs instead of creating symints.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165931
Approved by: https://github.com/zhxchen17
2025-10-23 22:42:11 +00:00
b146ea411e Save GitHub env variables on ROCm (#165821)
As `.github/actions/setup-rocm/action.yml` is now used on `linux_job_v2` to set up ROCm, we need to have this step here to save the list of GitHub env variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165821
Approved by: https://github.com/atalman
2025-10-23 22:13:37 +00:00
8625ffbd45 [DeviceMesh] Use _flatten_rank_map to replace _flatten_mesh_list so that we don't need to compare root mesh (#166003)
Since we already share a flattened tensor `_rank_map` across all meshes from the same root mesh, we can just use a flattened list of it to replace the comparison of root_mesh and flattened_mesh_list (because with the same _rank_map and layout, the mesh tensor is guaranteed to be the same). This way we can also claw back the CPU overhead added in https://github.com/pytorch/pytorch/pull/164510 and further simplify the code.

We do have a more ambitious universe-based change here: https://github.com/pytorch/pytorch/pull/165680, but it needs more discussion and would be BC-breaking. We might eventually merge that PR, but probably not now; this change is not BC-breaking and will help concatenate and the 2D integration with concatenate.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166003
Approved by: https://github.com/Skylion007, https://github.com/fegin
2025-10-23 20:49:59 +00:00
866 changed files with 8013 additions and 4894 deletions

View File

@ -0,0 +1,354 @@
# PyTorch Docstring Writing Guide
This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.
## General Principles
- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
- Follow **Sphinx/reStructuredText** (reST) format for documentation
- Be **concise but complete** - include all essential information
- Always include **examples** when possible
- Use **cross-references** to related functions/classes
## Docstring Structure
### 1. Function Signature (First Line)
Start with the function signature showing all parameters:
```python
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
```
**Notes:**
- Include the function name
- Show positional and keyword-only arguments (use `*` separator)
- Include default values
- Show return type annotation
- This line should NOT end with a period
### 2. Brief Description
Provide a one-line description of what the function does:
```python
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
Applies a 2D convolution over an input image composed of several input
planes.
```
### 3. Mathematical Formulas (if applicable)
Use Sphinx math directives for mathematical expressions:
```python
.. math::
    \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
```
Or inline math: `:math:\`x^2\``
### 4. Cross-References
Link to related classes and functions using Sphinx roles:
- `:class:\`~torch.nn.ModuleName\`` - Link to a class
- `:func:\`torch.function_name\`` - Link to a function
- `:meth:\`~Tensor.method_name\`` - Link to a method
- `:attr:\`attribute_name\`` - Reference an attribute
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)
**Example:**
```python
See :class:`~torch.nn.Conv2d` for details and output shape.
```
### 5. Notes and Warnings
Use admonitions for important information:
```python
.. note::
    This function doesn't work directly with NLLLoss,
    which expects the Log to be computed between the Softmax and itself.
    Use log_softmax instead (it's faster and has better numerical properties).

.. warning::
    :func:`new_tensor` always copies :attr:`data`. If you have a Tensor
    ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
    or :func:`torch.Tensor.detach`.
```
### 6. Args Section
Document all parameters with type annotations and descriptions:
```python
Args:
    input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)`
    bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
      tuple `(sH, sW)`. Default: 1
```
**Formatting rules:**
- Parameter name in **lowercase**
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
- Description follows the type
- For optional parameters, include "Default: ``value``" at the end
- Use double backticks for inline code: ``` ``None`` ```
- Indent continuation lines by 2 spaces
### 7. Keyword Args Section (if applicable)
Sometimes keyword arguments are documented separately:
```python
Keyword args:
    dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
      Default: if None, same :class:`torch.dtype` as this tensor.
    device (:class:`torch.device`, optional): the desired device of returned tensor.
      Default: if None, same :class:`torch.device` as this tensor.
    requires_grad (bool, optional): If autograd should record operations on the
      returned tensor. Default: ``False``.
```
### 8. Returns Section (if needed)
Document the return value:
```python
Returns:
    Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
    If ``hard=True``, the returned samples will be one-hot, otherwise they will
    be probability distributions that sum to 1 across `dim`.
```
Or simply include it in the function signature line if obvious from context.
### 9. Examples Section
Always include examples when possible:
```python
Examples::

    >>> inputs = torch.randn(33, 16, 30)
    >>> filters = torch.randn(20, 16, 5)
    >>> F.conv1d(inputs, filters)

    >>> # With square kernels and equal stride
    >>> filters = torch.randn(8, 4, 3, 3)
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> F.conv2d(inputs, filters, padding=1)
```
**Formatting rules:**
- Use `Examples::` with double colon
- Use `>>>` prompt for Python code
- Include comments with `#` when helpful
- Show actual output when it helps understanding (indent without `>>>`)
### 10. External References
Link to papers or external documentation:
```python
.. _Link Name:
https://arxiv.org/abs/1611.00712
```
Reference them in text: ```See `Link Name`_```
## Method Types
### Native Python Functions
For regular Python functions, use a standard docstring:
```python
def relu(input: Tensor, inplace: bool = False) -> Tensor:
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise. See
    :class:`~torch.nn.ReLU` for more details.
    """
    # implementation
```
### C-Bound Functions (using add_docstr)
For C-bound functions, use `_add_docstr`:
```python
conv1d = _add_docstr(
torch.conv1d,
r"""
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
Applies a 1D convolution over an input signal composed of several input
planes.
See :class:`~torch.nn.Conv1d` for details and output shape.
Args:
input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
weight: filters of shape :math:`(\text{out\_channels} , kW)`
...
""",
)
```
### In-Place Variants
For in-place operations (ending with `_`), reference the original:
```python
add_docstr_all(
"abs_",
r"""
abs_() -> Tensor
In-place version of :meth:`~Tensor.abs`
""",
)
```
### Alias Functions
For aliases, simply reference the original:
```python
add_docstr_all(
"absolute",
r"""
absolute() -> Tensor
Alias for :func:`abs`
""",
)
```
## Common Patterns
### Shape Documentation
Use LaTeX math notation for tensor shapes:
```python
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
```
### Reusable Argument Definitions
For commonly used arguments, define them once and reuse:
```python
common_args = parse_kwargs(
"""
dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
Default: if None, same as this tensor.
"""
)
# Then use with .format():
r"""
...
Keyword args:
{dtype}
{device}
""".format(**common_args)
```
### Template Insertion
Insert reproducibility notes or other common text:
```python
r"""
{tf32_note}
{cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)
```
## Complete Example
Here's a complete example showing all elements:
```python
def gumbel_softmax(
    logits: Tensor,
    tau: float = 1,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
) -> Tensor:
    r"""
    Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits (Tensor): `[..., num_features]` unnormalized log probabilities
        tau (float): non-negative scalar temperature
        hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
          but will be differentiated as if it is the soft sample in autograd. Default: ``False``
        dim (int): A dimension along which softmax will be computed. Default: -1

    Returns:
        Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
        If ``hard=True``, the returned samples will be one-hot, otherwise they will
        be probability distributions that sum to 1 across `dim`.

    .. note::
        This function is here for legacy reasons, may be removed from nn.Functional in the future.

    Examples::
        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)

    .. _Link 1:
        https://arxiv.org/abs/1611.00712
    """
    # implementation
```
## Quick Checklist
When writing a PyTorch docstring, ensure:
- [ ] Use raw string (`r"""`)
- [ ] Include function signature on first line
- [ ] Provide brief description
- [ ] Document all parameters in Args section with types
- [ ] Include default values for optional parameters
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
- [ ] Add mathematical formulas if applicable
- [ ] Include at least one example in Examples section
- [ ] Add warnings/notes for important caveats
- [ ] Link to related module class with `:class:`
- [ ] Use proper math notation for tensor shapes
- [ ] Follow consistent formatting and indentation
## Common Sphinx Roles Reference
- `:class:\`~torch.nn.Module\`` - Class reference
- `:func:\`torch.function\`` - Function reference
- `:meth:\`~Tensor.method\`` - Method reference
- `:attr:\`attribute\`` - Attribute reference
- `:math:\`equation\`` - Inline math
- `:ref:\`label\`` - Internal reference
- ``` ``code`` ``` - Inline code (use double backticks)
## Additional Notes
- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
- **Line length**: Try to keep lines under 100 characters when possible
- **Periods**: End sentences with periods, but not the signature line
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.

View File

@ -124,3 +124,10 @@ runs:
id: login-ecr
continue-on-error: true
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Preserve github env variables for use in docker
shell: bash
run: |
env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"

View File

@ -1 +1 @@
0fa6e3129e61143224663e1ec67980d12b7ec4eb
df6798dfb931ce7c7fe5bed2447cd1092a5981af

View File

@ -283,6 +283,9 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
fi
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system --pre apache-tvm-ffi==0.1.0b15
# Install the vllm wheel from previous stage
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system /wheels/vllm/*.whl --verbose
@ -295,6 +298,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
ARG FLASHINFER_GIT_REF="v0.2.14.post1"

View File

@ -15,6 +15,11 @@
- "module: reinplacing"
then:
- "module: pt2-dispatcher"
- any:
- "vllm-compile"
then:
- "module: vllm"
- "oncall: pt2"
- any:
- "module: vmap"
then:
@ -27,10 +32,6 @@
- "module: pt2 optimizer"
then:
- "module: dynamo"
- any:
- "module: flex attention"
then:
- "module: higher order operators"
- any:
- "module: aotinductor"
then:

View File

@ -833,8 +833,7 @@ exclude_patterns = [
command = [
'python3',
'tools/linter/adapters/grep_linter.py',
'--pattern=cudaSetDevice(',
'--pattern=cudaGetDevice(',
'--pattern=(cudaSetDevice|cudaGetDevice)\\(',
'--linter-name=RAWCUDADEVICE',
'--error-name=raw CUDA API usage',
"""--error-description=\

View File

@ -38,7 +38,7 @@ set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT)
configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
# TODO: Do not generate CUDAConfig.h for ROCm BUILDS
# At the moment, `jit_macors.h` include CUDAConfig.h for both CUDA and HIP builds
# At the moment, `jit_macros.h` include CUDAConfig.h for both CUDA and HIP builds
if(USE_CUDA OR USE_ROCM)
configure_file(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h")
endif()

View File

@ -122,7 +122,7 @@ void FunctionalTensorWrapper::freeze_storage() const {
// | have their own storages, but backends like functorch |
// \/ are allowed to re-alias underneath the pass \/
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
// | underyling_storage | | underyling_storage |
// | underlying_storage | | underlying_storage |
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
//
// This constructor is only used by view ops.

View File

@ -1534,7 +1534,7 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
// XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
// Nothing beyond this point is important for meta functions, so it's fine to exit early here.
// Extend the condition to MAIA tesnors as MAIA tensors also don't have storage.
// Extend the condition to MAIA tensors as MAIA tensors also don't have storage.
if (privateuse1_without_storage ||
common_device_.type() == DeviceType::XLA ||
common_device_.type() == DeviceType::IPU ||

View File

@ -94,11 +94,11 @@ struct PinnedReserveSegment {
struct TORCH_API HostStats {
// COUNT: total allocations (active)
Stat active_requests;
// SUM: bytes allocated/reserved by this memory alocator. (active)
// SUM: bytes allocated/reserved by this memory allocator. (active)
Stat active_bytes;
// COUNT: total allocations (active + free)
Stat allocations;
// SUM: bytes allocated/reserved by this memory alocator. This accounts
// SUM: bytes allocated/reserved by this memory allocator. This accounts
// for both free and in-use blocks.
Stat allocated_bytes;
@ -127,7 +127,7 @@ struct alignas(hardware_destructive_interference_size) HostStatsStaged {
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
Stat allocations;
// SUM: bytes allocated/reserved by this memory alocator. This accounts
// SUM: bytes allocated/reserved by this memory allocator. This accounts
// for both free and in-use blocks.
Stat allocated_bytes;
// COUNT: number of allocations per bucket (active)
@ -455,7 +455,7 @@ struct CachingHostAllocatorImpl {
}
void resetAccumulatedStats() {
// Reseting accumulated memory stats requires concurrently holding both the
// Resetting accumulated memory stats requires concurrently holding both the
// free list mutexes and the blocks mutex. Previously, this was only done in
// empty_cache function.
for (size_t i = 0; i < free_list_.size(); ++i) {
@ -482,7 +482,7 @@ struct CachingHostAllocatorImpl {
}
void resetPeakStats() {
// Reseting peak memory stats requires concurrently holding both the
// Resetting peak memory stats requires concurrently holding both the
// free list mutexes and the blocks mutex. Previously, this was only done in
// empty_cache function.
for (size_t i = 0; i < free_list_.size(); ++i) {

View File

@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
m.fallback(AUTOGRAD_FALLBACK);
}
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
m.fallback(AUTOGRAD_FALLBACK);
}
#undef AUTOGRAD_FALLBACK
} // namespace

View File

@ -148,7 +148,7 @@ struct TORCH_API ClassType : public NamedType {
void checkNotExist(const std::string& name, const std::string& what) const;
// Attributes are stored in a specific slot at runtime for effiency.
// Attributes are stored in a specific slot at runtime for efficiency.
// When emitting instructions we specify the slot so that attribute access is
// a constant lookup
std::optional<size_t> findAttributeSlot(const std::string& name) const {
@ -412,7 +412,7 @@ struct TORCH_API ClassType : public NamedType {
// Holds method attributes
std::weak_ptr<CompilationUnit> compilation_unit_;
// Holds all atrributes, attribute details are found on ClassAttribute
// Holds all attributes, attribute details are found on ClassAttribute
std::vector<ClassAttribute> attributes_;
// Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef.
// Never fill this without using the appropriate provideNewClassAttribute method

View File

@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
TORCH_CHECK(
!backendFallbackKernels_[idx].kernel.isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
backendFallbackKernels_[idx].debug, ", new registration ", debug
);
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
!backendFallbackKernels_[idx].kernel.isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ",
dispatchKey,
"; previous registration ",
backendFallbackKernels_[idx].debug,
", new registration ",
debug);
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
// cannot be unboxed
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
@ -531,7 +537,7 @@ int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchK
// Note: this records a sequence number for both Autograd keys, and for
// non-Autograd keys where the dispatchKeySet still contains an autograd key.
// This means that we might collect the same sequence nubmer two different
// This means that we might collect the same sequence number two different
// events if they all occurred above Autograd and still had the Autograd
// dispatch key in the dispatch key set.
// However, this usually doesn't happen: normally the first call will

View File

@ -585,7 +585,7 @@ class TORCH_API OperatorHandle {
// We need to store this iterator in order to make
// Dispatcher::cleanup() fast -- it runs a lot on program
// termination (and presuambly library unloading).
// termination (and presumably library unloading).
std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};

View File

@ -365,7 +365,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration
// to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd].
// For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of
// its backends and ask backend extender to request a decicated Autograd key for the backend.
// its backends and ask backend extender to request a dedicated Autograd key for the backend.
// See Note [Ambiguity in AutogradOther kernel] for more details.
// A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't
// cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available)

View File

@ -261,7 +261,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
//
// There are 2 cases
// 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'.
// without the extra parenthesis, the c++ schem parser can not parse it.
// without the extra parenthesis, the c++ scheme parser can not parse it.
// 2. something like '-> ((str, str))'. Need extra parenthesis so the return
// type is a single tuple rather than two strings.
// PR (https://github.com/pytorch/pytorch/pull/23204) has more context about

View File

@ -1176,7 +1176,7 @@ struct TORCH_API IValue final {
using HashIdentityIValueMap =
std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;
// Chechs if this and rhs has a subvalues in common.
// Checks if this and rhs has a subvalues in common.
// [t1,t2] and [t2, t3] returns true.
bool overlaps(const IValue& rhs) const;

View File

@ -1501,7 +1501,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
// However, the CompilationUnit holds ownership of the type's graphs, so
// inserting a constant object into a Graph would create a reference cycle if
// that constant object held a shared_ptr to its CU. For these objects we
// instatiate them with non-owning references to its CU
// instantiate them with non-owning references to its CU
Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
slots_.resize(numSlots);
}

View File

@ -373,7 +373,7 @@ struct TORCH_API SymbolicShape {
// Unranked shape constructor.
SymbolicShape() : dims_(std::nullopt) {}
// Known rank but unknown dimentions.
// Known rank but unknown dimensions.
SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) {
if(!rank) {
return;
@ -884,9 +884,9 @@ struct TORCH_API ListType
// global singleton
// Given an inner type T and an identifier,
// this function wil return the global singleton type pointer
// this function will return the global singleton type pointer
// the type List<T>.
// The extra "identifier" argument is needed beccause we have multiple container types
// The extra "identifier" argument is needed because we have multiple container types
// that all re-use this function (List<T>, array<T, N>, etc.)
static TypePtr get(const std::string& identifier, TypePtr inner);

View File

@ -185,11 +185,11 @@ struct TORCH_API Type {
: repr_(nullptr) {}
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
: repr_(p) {}
: repr_(makeSingletonSharedPtr(p.get())) {}
template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
: repr_(SingletonTypePtr<T>(p.get())) {}
: repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {}
// We need to support construction from T* for pybind. The problem
@ -202,8 +202,8 @@ struct TORCH_API Type {
// Case 2: if T is exactly Type, we need to do a dynamic_cast to
// check if it's a SharedType and do the right thing.
//
// Case 3: Otherwise, T is not a SharedType. (debug-check this
// assumption!) Use a singleton pointer.
// Case 3: Otherwise, T is not a SharedType. Use a singleton
// pointer.
template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
/* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
@ -211,15 +211,15 @@ struct TORCH_API Type {
template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
/* implicit */ SingletonOrSharedTypePtr(T* p) {
if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
repr_ = Repr(shared_p->shared_from_this());
repr_ = shared_p->shared_from_this();
} else {
repr_ = Repr(p);
repr_ = makeSingletonSharedPtr(p);
}
}
template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
/* implicit */ SingletonOrSharedTypePtr(T* p)
: repr_(p) {
: repr_(makeSingletonSharedPtr(p)) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
}
@ -230,19 +230,19 @@ struct TORCH_API Type {
~SingletonOrSharedTypePtr() = default;
T* get() const {
return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
return repr_.get();
}
operator bool() const {
return repr_.isNonNull();
return repr_ != nullptr;
}
bool operator==(std::nullptr_t) const {
return !repr_.isNonNull();
return repr_ == nullptr;
}
bool operator!=(std::nullptr_t) const {
return repr_.isNonNull();
return repr_ != nullptr;
}
template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
@ -255,138 +255,14 @@ struct TORCH_API Type {
}
private:
// NOTE: SharedPtrWrapper exists to work around a baffling bug in
// nvcc; see comment in destroy() below.
struct SharedPtrWrapper {
SharedPtrWrapper(std::shared_ptr<T> &&x)
: repr_(std::move(x)) {}
std::shared_ptr<T> repr_;
};
union Repr {
Repr() : Repr(nullptr) {}
// Use shared_ptr's aliasing constructor to create a non-owning pointer
// to a singleton. The lifetime is tied to the null shared_ptr, so there's
// no reference counting overhead for the singleton itself.
static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) {
return std::shared_ptr<T>(std::shared_ptr<T>(), ptr);
}
explicit Repr(std::shared_ptr<T> x)
: shared_(std::move(x)) {}
explicit Repr(std::nullptr_t)
: singletonRepr_(nullptr) {}
explicit Repr(SingletonTypePtr<T> p)
: singletonRepr_(p.get()) {}
~Repr() {
destroy();
}
// NOTE: the only non-UB way to access our null state is through
// rawRepr(), because our copy operation doesn't preserve which
// union member is active for null pointers.
Repr(const Repr& rhs) {
if (rhs.isSharedAndNonNull()) {
new (&shared_) SharedPtrWrapper(rhs.shared_);
} else {
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
singletonRepr_.unused_ = nullptr;
}
}
Repr(Repr&& rhs) noexcept {
if (rhs.isSharedAndNonNull()) {
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
} else {
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
singletonRepr_.unused_ = nullptr;
}
}
Repr& operator=(const Repr& rhs) {
if (&rhs == this) {
return *this;
}
if (rhs.isSharedAndNonNull()) {
if (isSharedAndNonNull()) {
shared_ = rhs.shared_;
} else {
new (&shared_) SharedPtrWrapper(rhs.shared_);
}
} else {
if (isSharedAndNonNull()) {
destroy();
}
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
singletonRepr_.unused_ = nullptr;
}
return *this;
}
Repr& operator=(Repr&& rhs) noexcept {
if (&rhs == this) {
return *this;
}
if (rhs.isSharedAndNonNull()) {
if (isSharedAndNonNull()) {
shared_ = std::move(rhs.shared_);
} else {
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
}
} else {
if (isSharedAndNonNull()) {
destroy();
}
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
singletonRepr_.unused_ = nullptr;
}
return *this;
}
SharedPtrWrapper shared_;
struct SingletonRepr {
explicit SingletonRepr(T* s) : singleton_(s) {}
T* singleton_;
void* unused_ = nullptr;
} singletonRepr_;
struct RawRepr {
void* first;
void* nullIfSingleton_;
};
// It is UB to read the singleton part of Repr if it was
// constructed as a shared_ptr and vice versa, but memcpying out
// the representation is always OK, so here's an accessor to obey
// the letter of the law.
RawRepr rawRepr() const {
RawRepr repr{};
memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
return repr;
}
bool isNonNull() const {
auto repr = rawRepr();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
return repr.first != nullptr;
}
bool isSharedAndNonNull() const {
return rawRepr().nullIfSingleton_ != nullptr;
}
private:
void destroy() {
if (isSharedAndNonNull()) {
// Without SharedPtrWrapper, this line would read
// `shared_.~shared_ptr()` and nvcc would complain with
// "error: expected primary-expression before '>' token"
// referring to the "t" in "shared_ptr". SharedPtrWrapper
// exists to work around this compiler bug.
shared_.~SharedPtrWrapper();
}
}
} repr_;
std::shared_ptr<T> repr_;
};
using TypePtr = SingletonOrSharedTypePtr<Type>;

View File

@ -21,7 +21,7 @@ namespace c10 {
namespace detail {
// The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
// We do this because every argument in a function schema is expected to be convertable
// We do this because every argument in a function schema is expected to be convertible
// to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of.
// See Note [Plumbing Keys Through The Dispatcher]
template<class KernelFunctor>

View File

@ -251,7 +251,7 @@ TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnbox
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA));
EXPECT_TRUE(called_kernel_cpu);
// Ensure that disptach key from tensor is not used here.
// Ensure that dispatch key from tensor is not used here.
called_kernel_cpu = false;
expectThrows<c10::Error>([&] {
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU));

View File

@ -172,7 +172,7 @@ VaryingShape<Stride> TensorType::computeStrideProps(
// The logic below follows what TensorIterator uses in its logic:
// 1. Fast_set_up is the short-cut to identify a. channels_last and
// b. contiguous format, which is what we have in the below logic.
// 2. In more generla cases, it does best effort to preserve permutatoin.
// 2. In more general cases, it does best effort to preserve permutatoin.
if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) {
// case 1.a. short cut channels last
std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);

View File

@ -104,71 +104,6 @@ class Vectorized<float> {
}
return b;
}
// Implementation is picked from
// https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
const auto c1 =
svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
const auto c2 =
svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
const auto c3 =
svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
const auto c4 =
svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
const auto c5 =
svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
const auto shift = svreinterpret_f32_u32(
svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
const auto inv_ln2 = svreinterpret_f32_u32(
svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
const auto zero = svdup_n_f32(0.f);
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
// Range reduction:
// e^x = 2^n * e^r
// where:
// n = floor(x / ln(2))
// r = x - n * ln(2)
//
// By adding x / ln(2) with 2^23 + 127 (shift):
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
// forces decimal part
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e.
// n) + 127 will occupy the whole fraction part of z in FP32 format.
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
// of x / ln(2) (i.e. n) because the decimal part has been pushed out
// and lost.
// * The addition of 127 makes the FP32 fraction part of z ready to be
// used as the exponent
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
const auto n = svsub_f32_z(pg, z, shift);
const auto scale = svreinterpret_f32_u32(
svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in
// term of accuracy and performance.
const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
// Compute the truncated Taylor series of e^r.
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
const auto r2 = svmul_f32_z(pg, r, r);
const auto p1 = svmul_f32_z(pg, c1, r);
const auto p23 = svmla_f32_z(pg, c2, c3, r);
const auto p45 = svmla_f32_z(pg, c4, c5, r);
const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
auto poly = svmla_f32_z(pg, scale, p12345, scale);
// Handle underflow and overflow.
poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
return poly;
}
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
@ -313,11 +248,41 @@ class Vectorized<float> {
return USE_SLEEF(
Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
}
// Implementation copied from Arm Optimized Routines:
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
Vectorized<float> exp_u20() const {
return exp();
// special case to handle special inputs that are too large or too small
// i.e. where there's at least one element x, s.t. |x| >= 87.3...
svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f);
if (svptest_any(svptrue_b32(), is_special_case)) {
return exp();
}
const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
const svfloat32_t c1 = svdup_n_f32(0.5f);
const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
const float shift = 0x1.803f8p17f;
/* n = round(x/(ln2/N)). */
svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift);
svfloat32_t n = svsub_x(svptrue_b32(), z, shift);
/* r = x - n*ln2/N. */
svfloat32_t r = values;
r = svmls_x(svptrue_b32(), r, n, ln2_hi);
r = svmls_x(svptrue_b32(), r, n, ln2_lo);
/* scale = 2^(n/N). */
svfloat32_t scale = svexpa(svreinterpret_u32(z));
/* poly(r) = exp(r) - 1 ~= r + 0.5 r^2. */
svfloat32_t r2 = svmul_x(svptrue_b32(), r, r);
svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
return svmla_x(svptrue_b32(), scale, scale, poly);
}
Vectorized<float> fexp_u20() const {
return exp();
return exp_u20();
}
Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
@ -453,9 +418,11 @@ class Vectorized<float> {
ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
// Step 2: Calculate exp(2 * x), where x is the clamped value.
// svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
// the result.
svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
// svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of
// the result (via Vectorized<float>, then auto-converts back to
// svfloat32_t).
svfloat32_t exp2x =
Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20();
// Step 3: Calculate the numerator of the tanh function, which is exp(2x)
// - 1.

View File

@ -5,6 +5,114 @@
namespace at::vec {
inline namespace CPU_CAPABILITY {
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
// Enable auto-vectorization for GCC-13+ and clang-17+
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
template <typename from_type, typename to_type>
inline void convertImpl(
const from_type* __restrict src,
to_type* __restrict dst,
int64_t n) {
uint64_t len = static_cast<uint64_t>(n);
for (uint64_t i = 0; i < len; i++) {
dst[i] = static_cast<to_type>(src[i]);
}
}
#define CONVERT_TEMPLATE(from_type, to_type) \
template <> \
inline void convert(const from_type* src, to_type* dst, int64_t n) { \
return convertImpl<from_type, to_type>(src, dst, n); \
}
CONVERT_TEMPLATE(uint8_t, uint8_t)
CONVERT_TEMPLATE(uint8_t, int8_t)
CONVERT_TEMPLATE(uint8_t, int16_t)
CONVERT_TEMPLATE(uint8_t, int32_t)
CONVERT_TEMPLATE(uint8_t, int64_t)
CONVERT_TEMPLATE(uint8_t, float)
CONVERT_TEMPLATE(uint8_t, double)
CONVERT_TEMPLATE(int8_t, uint8_t)
CONVERT_TEMPLATE(int8_t, int8_t)
CONVERT_TEMPLATE(int8_t, int16_t)
CONVERT_TEMPLATE(int8_t, int32_t)
CONVERT_TEMPLATE(int8_t, int64_t)
CONVERT_TEMPLATE(int8_t, float)
CONVERT_TEMPLATE(int8_t, double)
CONVERT_TEMPLATE(int16_t, uint8_t)
CONVERT_TEMPLATE(int16_t, int8_t)
CONVERT_TEMPLATE(int16_t, int16_t)
CONVERT_TEMPLATE(int16_t, int32_t)
CONVERT_TEMPLATE(int16_t, int64_t)
CONVERT_TEMPLATE(int16_t, float)
CONVERT_TEMPLATE(int16_t, double)
CONVERT_TEMPLATE(int32_t, uint8_t)
CONVERT_TEMPLATE(int32_t, int8_t)
CONVERT_TEMPLATE(int32_t, int16_t)
CONVERT_TEMPLATE(int32_t, int32_t)
CONVERT_TEMPLATE(int32_t, int64_t)
CONVERT_TEMPLATE(int32_t, float)
CONVERT_TEMPLATE(int32_t, double)
CONVERT_TEMPLATE(int64_t, uint8_t)
CONVERT_TEMPLATE(int64_t, int8_t)
CONVERT_TEMPLATE(int64_t, int16_t)
CONVERT_TEMPLATE(int64_t, int32_t)
CONVERT_TEMPLATE(int64_t, int64_t)
CONVERT_TEMPLATE(int64_t, float)
CONVERT_TEMPLATE(int64_t, double)
CONVERT_TEMPLATE(float, uint8_t)
CONVERT_TEMPLATE(float, int8_t)
CONVERT_TEMPLATE(float, int16_t)
CONVERT_TEMPLATE(float, int32_t)
CONVERT_TEMPLATE(float, int64_t)
CONVERT_TEMPLATE(float, float)
CONVERT_TEMPLATE(float, double)
CONVERT_TEMPLATE(double, uint8_t)
CONVERT_TEMPLATE(double, int8_t)
CONVERT_TEMPLATE(double, int16_t)
CONVERT_TEMPLATE(double, int32_t)
CONVERT_TEMPLATE(double, int64_t)
CONVERT_TEMPLATE(double, float)
CONVERT_TEMPLATE(double, double)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
CONVERT_TEMPLATE(float16_t, uint8_t)
CONVERT_TEMPLATE(float16_t, int8_t)
CONVERT_TEMPLATE(float16_t, int16_t)
CONVERT_TEMPLATE(float16_t, int32_t)
CONVERT_TEMPLATE(float16_t, int64_t)
CONVERT_TEMPLATE(float16_t, float16_t)
CONVERT_TEMPLATE(float16_t, float)
CONVERT_TEMPLATE(float16_t, double)
CONVERT_TEMPLATE(uint8_t, float16_t)
CONVERT_TEMPLATE(int8_t, float16_t)
CONVERT_TEMPLATE(int16_t, float16_t)
CONVERT_TEMPLATE(int32_t, float16_t)
CONVERT_TEMPLATE(int64_t, float16_t)
CONVERT_TEMPLATE(float, float16_t)
CONVERT_TEMPLATE(double, float16_t)
#endif
#ifdef __ARM_FEATURE_BF16
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
CONVERT_TEMPLATE(bfloat16_t, int8_t)
CONVERT_TEMPLATE(bfloat16_t, int16_t)
CONVERT_TEMPLATE(bfloat16_t, int32_t)
CONVERT_TEMPLATE(bfloat16_t, int64_t)
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
CONVERT_TEMPLATE(bfloat16_t, float)
CONVERT_TEMPLATE(bfloat16_t, double)
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
CONVERT_TEMPLATE(int8_t, bfloat16_t)
CONVERT_TEMPLATE(int16_t, bfloat16_t)
CONVERT_TEMPLATE(int32_t, bfloat16_t)
CONVERT_TEMPLATE(int64_t, bfloat16_t)
CONVERT_TEMPLATE(float, bfloat16_t)
CONVERT_TEMPLATE(double, bfloat16_t)
#endif
#endif
template <typename src_t>
struct VecConvert<
float,

View File

@ -307,11 +307,49 @@ class Vectorized<float> {
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
// Implementation copied from Arm Optimized Routine
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
Vectorized<float> exp_u20() const {
return exp();
// bail out to sleef if it's a special case:
// i.e. there's an input s.t. |input| > 87.3....
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
uint32x4_t cmp = vcagtq_f32(values, special_bound);
if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) {
return exp();
}
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
const float ln2_hi = 0x1.62e4p-1f;
const float ln2_lo = 0x1.7f7d1cp-20f;
const float c0 = 0x1.0e4020p-7f;
const float c2 = 0x1.555e66p-3f;
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
/* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
x = ln2*n + r, with r in [-ln2/2, ln2/2]. */
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0);
r = vfmsq_laneq_f32(r, n, ln2_c02, 1);
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias));
float32x4_t r2 = vmulq_f32(r, r);
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2);
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3);
q = vfmaq_f32(q, p, r2);
p = vmulq_f32(c4, r);
float32x4_t poly = vfmaq_f32(p, q, r2);
return vfmaq_f32(scale, poly, scale);
}
Vectorized<float> fexp_u20() const {
return exp();
return exp_u20();
}
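As a reading aid, here is a scalar sketch of the scheme the block comment above describes, exp(x) = 2^n * exp(r) with n = round(x / ln2) and r = x - n*ln2. It uses plain Taylor terms instead of the tuned minimax constants c0..c4 from the NEON path, so it is illustrative only, not the kernel itself.

```
// Scalar sketch of the range reduction used by exp_u20(); coefficients of the
// polynomial here are simple Taylor terms, not the tuned minimax constants.
#include <cmath>
#include <cstdio>

static float exp_range_reduced(float x) {
  const float inv_ln2 = 0x1.715476p+0f;   // 1/ln2
  const float ln2_hi  = 0x1.62e4p-1f;     // high part of ln2
  const float ln2_lo  = 0x1.7f7d1cp-20f;  // low part of ln2
  float n = std::nearbyint(x * inv_ln2);  // n = round(x / ln2)
  float r = x - n * ln2_hi - n * ln2_lo;  // r roughly in [-ln2/2, ln2/2]
  // exp(r) ~= 1 + r + r^2/2 + r^3/6 + r^4/24 (Horner form)
  float p = 1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6.0f + r * (1.0f / 24.0f))));
  return std::ldexp(p, static_cast<int>(n)); // scale by 2^n
}

int main() {
  std::printf("%f vs %f\n", exp_range_reduced(1.0f), std::exp(1.0f));
}
```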
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
fmod,
@ -540,42 +578,6 @@ inline Vectorized<float> Vectorized<float>::le(
return (*this <= other) & Vectorized<float>(1.0f);
}
template <>
inline void convert(const float* src, int32_t* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<float>::size());
i += Vectorized<float>::size()) {
vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<int32_t>(src[i]);
}
}
template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<float>::size());
i += Vectorized<float>::size()) {
vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<float>(src[i]);
}
}
template <>
Vectorized<float> inline fmadd(
const Vectorized<float>& a,
@ -632,8 +634,7 @@ inline Vectorized<float> Vectorized<float>::erf() const {
// - exp(- x * x)
auto pow_2 = (*this) * (*this);
auto neg_pow_2 = pow_2 ^ neg_zero_vec;
auto tmp4 = neg_pow_2.map(
std::exp); // This can be swapped for a faster implementation of exp.
auto tmp4 = neg_pow_2.exp();
auto tmp5 = tmp4 ^ neg_zero_vec;
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
auto tmp6 = t * tmp5;
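For reference, the structure above (a polynomial r in t, multiplied by t and exp(-x*x), then subtracted from 1 with the sign restored) matches the classic Abramowitz–Stegun style rational approximation; this is inferred from the code shape, and the exact coefficients are defined earlier in this function:

$$\operatorname{erf}(x) \approx \operatorname{sign}(x)\left(1 - \left(a_1 t + a_2 t^2 + a_3 t^3 + a_4 t^4 + a_5 t^5\right) e^{-x^2}\right), \qquad t = \frac{1}{1 + p\,|x|}.$$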

View File

@ -234,7 +234,7 @@ class Vectorized<c10::Half> : public Vectorized16<
vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
return vaddvq_u16(bits_vec);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// use known working implmentation.
// use known working implementation.
__at_align__ value_type tmp[size()];
store(tmp);
int mask = 0;
@ -569,46 +569,6 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
return (*this <= other) & Vectorized<c10::Half>(1);
}
// These are global functions, so the defaults in vec_base.h should
// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available.
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
i += Vectorized<c10::Half>::size()) {
vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<int16_t>(src[i]);
}
}
template <>
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
i += Vectorized<c10::Half>::size()) {
vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<float16_t>(src[i]);
}
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
Vectorized<c10::Half> inline fmadd(
const Vectorized<c10::Half>& a,

View File

@ -1740,7 +1740,7 @@ Vectorized<int16_t> inline shift_256_16(
// Control masks for shuffle operation, treating 256 bits as an
// array of 16-bit elements, and considering pairs of neighboring
// elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and
// elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
// M!=N) is set so that shuffle will move element with index M from
// input pair into element with index N in output pair, and element
// with index M in output pair will be set to all 0s.
@ -1875,7 +1875,7 @@ Vectorized<T> inline shift_256_8(
// Control masks for shuffle operation, treating 256 bits as an
// array of 8-bit elements, and considering quadruples of
// neighboring elements. Specifially, a mask named "ctl_M_N" (M,N
// neighboring elements. Specifically, a mask named "ctl_M_N" (M,N
// in [0,1,2,3], and M!=N) is set so that shuffle will move element
// with index M from input quadruple into element with index N in
// output quadruple, and other elements in output quadruple will be

View File

@ -143,7 +143,7 @@ class Vectorized<double> {
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& mask) {
// the mask used here returned by comparision of vec256
// the mask used here returned by comparison of vec256
return {
vec_sel(a._vec0, b._vec0, mask._vecb0),

View File

@ -142,7 +142,7 @@ class Vectorized<float> {
const Vectorized<float>& a,
const Vectorized<float>& b,
const Vectorized<float>& mask) {
// the mask used here returned by comparision of vec256
// the mask used here returned by comparison of vec256
// assuming this we can use the same mask directly with vec_sel
return {
vec_sel(a._vec0, b._vec0, mask._vecb0),

View File

@ -202,7 +202,7 @@ class Vectorized<int16_t> {
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b,
const Vectorized<int16_t>& mask) {
// the mask used here returned by comparision of vec256
// the mask used here returned by comparison of vec256
// assuming this we can use the same mask directly with vec_sel
// warning intel style mask will not work properly
return {

View File

@ -155,7 +155,7 @@ class Vectorized<int32_t> {
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b,
const Vectorized<int32_t>& mask) {
// the mask used here returned by comparision of vec256
// the mask used here returned by comparison of vec256
// assuming this we can use the same mask directly with vec_sel
// warning intel style mask will not work properly
return {

View File

@ -119,7 +119,7 @@ class Vectorized<int64_t> {
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b,
const Vectorized<int64_t>& mask) {
// the mask used here returned by comparision of vec256
// the mask used here returned by comparison of vec256
return {
vec_sel(a._vec0, b._vec0, mask._vecb0),

View File

@ -397,7 +397,7 @@ inline Vectorized<bool> operator&&(
const __m512i* other_ = reinterpret_cast<const __m512i*>(other.as_bytes());
__m512i out = _mm512_and_si512(*self_, *other_);
Vectorized<bool> ret;
// We do not have a constructer that takes __m512i, so we need to memcpy
// We do not have a constructor that takes __m512i, so we need to memcpy
std::memcpy(ret, &out, ret.size() * sizeof(bool));
return ret;
}

View File

@ -1852,7 +1852,7 @@ Vectorized<T> inline shift_512_8(
// Control masks for shuffle operation, treating 512 bits as an
// array of 8-bit elements, and considering pairs of neighboring
// elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and
// elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
// M!=N) is set so that shuffle will move element with index M from
// input pair into element with index N in output pair, and element
// with index M in output pair will be set to all 0s.

View File

@ -634,7 +634,7 @@ struct Vectorized {
}
Vectorized<T> neg() const {
// NB: the trailing return type is needed because we need to coerce the
// return value back to T in the case of unary operator- incuring a
// return value back to T in the case of unary operator- incurring a
// promotion
return map([](T x) -> T { return -x; });
}

View File

@ -1958,7 +1958,7 @@ void scaled_gemm(
ScalarType result_dtype,
bool use_fast_accum,
const std::optional<Tensor>& alpha) {
// Note: see `cublasCommonArgs` for various non-intuitive manupulations
// Note: see `cublasCommonArgs` for various non-intuitive manipulations
// of input arguments to this function.
const auto computeType = CUBLAS_COMPUTE_32F;
const auto scaleType = CUDA_R_32F;

View File

@ -168,11 +168,9 @@ void CUDAGraph::instantiate() {
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
// cudaGraphInstantiateWithFlags
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
int version = 0;
AT_CUDA_CHECK(cudaDriverGetVersion(&version));
if (version < 11040) {
#endif
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
// who prefer not to report error message through these arguments moving forward
// (they prefer return value, or errors on api calls internal to the capture)
@ -183,13 +181,11 @@ void CUDAGraph::instantiate() {
#endif
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
} else {
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch));
}
#endif
has_graph_exec_ = true;
}
@ -311,7 +307,7 @@ CUDAGraph::~CUDAGraph() {
// There are recent HIP changes where hipGraphExecDestroy doesn't immediately free memory.
// They wait for next sync point in order to free the memory, this is to ensure that all
// hipGraphLaunch are finished before we release any memory. This feature was enabled in rocm6.2.
// We need to ensure all async opreations finish before deleting the object.
// We need to ensure all async operations finish before deleting the object.
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200)
if (capture_dev_ != UNDEFINED_DEVICE) // check if capture_dev_ contains the real device id
{

View File

@ -0,0 +1,270 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace at::cuda::scaled {
/**
* Both inputs must be fp8,
* Each needs a single scale, {Tensorwise (float)}
*/
bool check_tensorwise_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {Tensorwise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {TensorWise, float} for A & B
if (recipe_a[0] != ScalingType::TensorWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::TensorWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Both inputs must be fp8,
* Each needs scales, {Rowwise (float)}
*/
bool check_rowwise_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {RowWise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {RowWise, fp32} for A & B
if (recipe_a[0] != ScalingType::RowWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::RowWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Two-level scaling, canonical NVFP4
* Both inputs must be fp4
* A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
*/
bool check_nvfp4_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 2 scales, 2 recipes for each input
if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
return false;
}
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Single-level scaling, the form PyTorch currently understands
* Both inputs must be fp4
* A, B need 1 scale, {Blockwise_1x16 (e4m3)}
*/
bool check_nvfp4_recipe_single_scale
(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 1 scale, 1 recipe for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x16, e4m3} for the single scale
if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
return true;
}
/**
* Both inputs must be fp8
* A, B must have exactly 1 scale each, e.g. A: {Blockwise_1x128 (float)}, B: {Blockwise_128x128 (float)}
*/
bool check_deepseek_recipe(ScalingType expected_recipe_a,
ScalingType expected_recipe_b,
c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
return false;
}
// 1 scale, 1 recipe for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
if (recipe_a[0] != expected_recipe_a) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != expected_recipe_b) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Both inputs must be fp8
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
*/
bool check_mxfp8_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
return false;
}
// 1 scale, 1 recipe for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
return true;
}
/**
* Both inputs must be fp4
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
*/
bool check_mxfp4_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 1 scale, 1 recipe for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
return true;
}
} // namespace at::cuda::scaled
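A hedged usage sketch for the acceptance checks above, assuming an ATen build where this file's header (CUDAScaledBlas.h) is on the include path; the tensor contents are irrelevant, only dtypes, recipe enums and scale dtypes are inspected.

```
// Sketch only: ask whether a pair of matrices and their scales would be
// accepted by the tensorwise recipe. The ArrayRef arguments must be lvalues
// because the API takes ArrayRef<Tensor>& by non-const reference.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAScaledBlas.h>
#include <vector>

bool would_use_tensorwise(const at::Tensor& a, const at::Tensor& b,
                          const at::Tensor& scale_a, const at::Tensor& scale_b) {
  using at::blas::ScalingType;
  std::vector<ScalingType> recipe_a{ScalingType::TensorWise};
  std::vector<ScalingType> recipe_b{ScalingType::TensorWise};
  std::vector<at::Tensor> sa{scale_a}, sb{scale_b};
  at::ArrayRef<at::Tensor> scales_a(sa), scales_b(sb);
  return at::cuda::scaled::check_tensorwise_recipe(
      a.scalar_type(), recipe_a, scales_a,
      b.scalar_type(), recipe_b, scales_b);
}
```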

View File

@ -0,0 +1,174 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace at::cuda::scaled {
static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
"gfx942",
#if ROCM_VERSION >= 60300
"gfx1200", "gfx1201",
#endif
#if ROCM_VERSION >= 60500
"gfx950"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs);
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (sm90_only || sm100_only) {
return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
} else {
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
}
#endif
}
#ifdef USE_ROCM
static bool _scaled_mm_is_fnuz() {
return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
}
#endif
/**
* Track concrete implementations available
*/
enum class ScaledGemmImplementation {
NONE = 0,
TENSORWISE_TENSORWISE = 1,
ROWWISE_ROWWISE = 2,
BLOCK_128x128_1x128 = 3,
BLOCK_1x128_128x128 = 4,
BLOCK_1x128_1x128 = 5,
MXFP8_MXFP8 = 6,
NVFP4_NVFP4 = 7,
NVFP4_NVFP4_SINGLE_SCALE = 8,
MXFP4_MXFP4 = 9,
};
/**
* Convert passed int (enum) from python back into a
* strictly-typed enum
*/
template <class EnumType, class ArrayType>
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
std::vector<EnumType> converted;
converted.reserve(v.size());
for (auto vi : v) {
converted.push_back(static_cast<EnumType>(vi));
}
return converted;
}
bool check_tensorwise_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_rowwise_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_nvfp4_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_nvfp4_recipe_single_scale
(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_deepseek_recipe(ScalingType,
ScalingType,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_mxfp8_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_mxfp4_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
} // namespace at::cuda::scaled
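A minimal sketch of how convert_int_to_enum is meant to be used (assuming this header's include path and the at::blas::ScalingType enum): the v2 entry points receive scaling recipes and swizzle codes from Python as IntArrayRef and decode them into typed enums.

```
// Sketch: decode raw integer recipe codes into ScalingType values.
#include <ATen/cuda/CUDAScaledBlas.h>
#include <c10/util/ArrayRef.h>
#include <vector>

std::vector<at::blas::ScalingType> decode_recipes(c10::IntArrayRef raw) {
  // convert_int_to_enum takes its argument by non-const reference, so the
  // lvalue parameter `raw` binds directly.
  return at::cuda::scaled::convert_int_to_enum<at::blas::ScalingType>(raw);
}
```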

View File

@ -137,7 +137,7 @@ struct CUDACachingHostAllocatorImpl
void free_block_slowpath(Block* block) {
auto start = std::chrono::steady_clock::now();
// Users may change the allocator config at will. torch unit tests do this.
// However, allocations using cudaHostRegister should use corresonding
// However, allocations using cudaHostRegister should use corresponding
// cudaHostUnregister and similarly for cudaHostAlloc / cudaFreeHost.
void* ptr = block->ptr_;
bool use_register = false;

View File

@ -4,7 +4,7 @@
#include <ATen/cuda/CUDAConfig.h>
// NOTE: These templates are intentionally not defined in this header,
// which aviods re-compiling them for each translation unit. If you get
// which avoids re-compiling them for each translation unit. If you get
// a link error, you need to add an explicit instantiation for your
// types in cub.cu

View File

@ -38,7 +38,7 @@ GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033
```
Note the "Validator" lines. If you change a library verison, or ROCm version, or PyTorch version, TunableOp will detect
Note the "Validator" lines. If you change a library version, or ROCm version, or PyTorch version, TunableOp will detect
this and reject the tunings file because the prior tunings are likely affected by other software changes.
The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of

View File

@ -235,7 +235,7 @@ class TunableOp {
// numeric check option is controlled by non-static env var, so check it once per tuned operator
bool do_numerics_check = ctx->IsNumericsCheckEnabled();
// calcaulte a reference answer for numerical check
// calculate a reference answer for numerical check
if (do_numerics_check) {
reference_params = params->DeepCopy(false);
TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);

View File

@ -12,7 +12,7 @@ namespace at {
// AcceleratorHooksInterface is a shared interface provided by all
// accelerators to allow generic code.
// This inferface is hook-based as it corresponds to all the functions
// This interface is hook-based as it corresponds to all the functions
// that are going to be called in a generic way from the CPU code.
struct TORCH_API AcceleratorHooksInterface {

View File

@ -38,7 +38,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
Generator getNewGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
// TODO(FFFrog): Perserved for BC and will be removed in the future.
// TODO(FFFrog): Preserved for BC and will be removed in the future.
if (at::GetGeneratorPrivate().has_value())
return at::GetGeneratorForPrivateuse1(device_index);

View File

@ -283,7 +283,7 @@ inline void boxed_existing_bdim_all_batch_rule(
// Use when all tensors arguments accept one (normal) batch dim.
// This batching rule expands the batch dim on all Tensors, reshapes it into
// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
// This is not the most efficient thing; if there are alternatives, plese try
// This is not the most efficient thing; if there are alternatives, please try
// to use them. Use this only as a last resort.
#define EXISTING_BDIM_ALL_BOXED(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());

View File

@ -384,7 +384,7 @@ fourOutputs solve_ex_batch_rule(
// NOTE [ solve_ex Batch Rule Contiguity ]
// A determines whether or not linalg_solve takes an optimized path. We need the check on A_ to match the one run on
// A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behvaior
// A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behavior
// differs based on whether or not the optimized path was taken
const auto batched_A_was_contiguous = A_bdim.has_value() ? at::select(A, *A_bdim, 0).is_contiguous() : A.is_contiguous();
if (batched_A_was_contiguous && !A.is_complex()) {

View File

@ -282,7 +282,7 @@ static std::tuple<Tensor, std::optional<int64_t>> _softmax_backward_batch_rule(
dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim);
// Not sure why output_ needs to be marked as .contiguous(). Someting must
// Not sure why output_ needs to be marked as .contiguous(). Something must
// have changed in PyTorch (and output of softmax is probably always contiguous)
return std::make_tuple(at::_softmax_backward_data(grad_output_, output_.contiguous(), dim, input_dtype), 0);
}

View File

@ -224,7 +224,7 @@ static Tensor safeStack(TensorList tensors) {
// is possible for the backward function to return an undefined grad for some
// grad_input for each example. In that case, we return an undefined grad.
//
// It is theoretically posssible for *some* of the examples to produce an
// It is theoretically possible for *some* of the examples to produce an
// undefined grad (a kernel could peek at the gradient values and return an
// undefined tensor if it determines the gradient is full of zeros). We
// could handle this by treating the undefined grad as a zero-filled tensor

View File

@ -113,7 +113,7 @@ SymIntArrayRef BatchedTensorImpl::sym_sizes_custom() const {
return sym_sizes_default();
}
// The following are publically exposed as methods of Tensor
// The following are publicly exposed as methods of Tensor
IntArrayRef BatchedTensorImpl::strides_custom() const {
return strides_default();

View File

@ -37,7 +37,7 @@ namespace at::functorch {
// how to perform the transform.
//
// TODO: we can excise DynamicLayer in favor of Interpreter,
// But I am going to leave it for now as a compatiblity shim to avoid
// But I am going to leave it for now as a compatibility shim to avoid
// needing to refactor a lot of callsites...
struct TORCH_API DynamicLayer {
explicit DynamicLayer(

View File

@ -88,7 +88,7 @@ std::ostream& operator<<(std::ostream& os, const TransformType& t);
// >>> VmapInterpreterPtr(&interpreter).batchSize()
//
// Finally, Interpreter::process switches on the type of the interpreter
// and calls one of {Transform}Intepreter::processImpl under the hood.
// and calls one of {Transform}Interpreter::processImpl under the hood.
// Same for Interpreter::sendToNextInterpreter :)
struct VmapInterpreterMeta {

View File

@ -733,7 +733,7 @@ TORCH_LIBRARY_IMPL(_, FuncTorchBatched, m) {
}
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
// still legacy b/c teturns multiple tensors
// still legacy b/c returns multiple tensors
m.impl("split.Tensor", split_batching_rule);
m.impl("split_with_sizes", split_with_sizes_batching_rule);
m.impl("split_with_sizes_copy", split_with_sizes_copy_batching_rule);

View File

@ -158,7 +158,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t
endKernelCoalescing();
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
// For some reason fillBufferfor stopped working for lengh > 4Gb on MacOS 26
// For some reason fillBufferfor stopped working for length > 4Gb on MacOS 26
// See https://github.com/pytorch/pytorch/issues/163962
// Workaround by batching copy commands into 4Gb chunks
constexpr size_t max_copy_size = 0x100000000; // 4GB

View File

@ -148,7 +148,7 @@ inline void checkInputsSolver(const Tensor& A,
inline bool is_row_or_column_contiguous(const Tensor& t) {
// This could be made more general, similar to how it's checked in matmul, which would allow to
// ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
// elide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
// We choose to be conservative for simplicity
return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
}

View File

@ -21,7 +21,7 @@ enum class fft_norm_mode {
// NOTE [ Fourier Transform Conjugate Symmetry ]
//
// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is,
// assuming X is the transformed K-dimensionsal signal, we have
// assuming X is the transformed K-dimensional signal, we have
//
// X[i_1, ..., i_K] = X[j_i, ..., j_K]*,
//

View File

@ -128,7 +128,7 @@ at::Tensor PackedLinearWeight::apply_impl(
auto* input_tr_ptr =
reinterpret_cast<uint8_t*>(input_tr.data_ptr<c10::quint8>());
// TODO: Activation transpose before and after the kernel can be removed if we
// keep activation tensor always tranposed.
// keep activation tensor always transposed.
fbgemm::transpose_simd<uint8_t>(
batch_size, K, input_ptr, K, input_tr_ptr, batch_size);

View File

@ -520,7 +520,7 @@ cpu_adaptive_avg_pool3d_channels_last(
scalar_t* out = output_data + i * channels;
int64_t size = channels;
// Note: For oridinary usage scenario, each out lane should
// Note: For ordinary usage scenario, each out lane should
// fit in L1 cache; otherwise consider block dim C.
// Pass I: zero the out lane
int64_t d1 = 0;

View File

@ -34,7 +34,7 @@ struct Dist {
// finish : This tells what to do with the aggregated value to compute
// the norm. Generally this is the result of val ^ (1 / p).
// backward : This is the gradient for that norm. Arguments are pretty
// self explanitory.
// self explanatory.
//
// There are a few cases where these aren't used. The 0 norm has no backward,
// because it's always 0, so that's shortcircuited earlier. There's a special

View File

@ -30,7 +30,7 @@ vec::Vectorized<scalar_t> is_nan_vec(vec::Vectorized<scalar_t> vec) {
return vec.isnan();
}
// TODO: use is_integeral/is_same to check the scalar_t and simplify the implementation
// TODO: use is_integral/is_same to check the scalar_t and simplify the implementation
// currently it does not work
template <>
vec::Vectorized<unsigned char> is_nan_vec<unsigned char>(vec::Vectorized<unsigned char> vec) {

View File

@ -74,7 +74,7 @@ it to sum up the entire array into a single value.
`ReduceOpsKernel.cpp` uses the `CPU_CAPABILITY_*` macros to "know" under which
compiler flags it is currently compiled. This allows the programmer to write
generic code, which will be compiled under multipled compilation settings.
generic code, which will be compiled under multiple compilation settings.
`../ReduceOps.cpp` now includes the header `ReduceOpsKernel.h`, which contains
a generic definition of `sumImplAll`. This function allows the user to reduce

View File

@ -889,7 +889,7 @@ void ImagingResampleHorizontalConvolution8u(
_mm_loadu_si128((__m128i *) (lineIn_min + stride * i))),
_mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1);
// Extract lower part of each lane, cast to epi16 and reoder RGBARGBA -> RRGGBBAA
// Extract lower part of each lane, cast to epi16 and reorder RGBARGBA -> RRGGBBAA
// RGBA: pix1 = [
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
// r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0

View File

@ -240,7 +240,7 @@ _PS256_CONST(coscof_p2, 4.166664568298827E-002);
_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
/* evaluation of 8 sines at onces using AVX intrinsics
/* evaluation of 8 sines at once using AVX intrinsics
The code is the exact rewriting of the cephes sinf function.
Precision is excellent as long as x < 8192 (I did not bother to

View File

@ -311,7 +311,7 @@ void GroupNormKernelImplChannelsLastInternal(
const bool gamma_null = (gamma_data == nullptr);
const bool beta_null = beta_data == nullptr;
// NB: About algorithm choosen:
// NB: About algorithm chosen:
//
// On channels last, GroupNorm has a input shape of {N, H, W, GD},
// Mean and rstd are collected per each n and g, which involves reduction

View File

@ -930,7 +930,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
}
};
// Dynamically Quantize the float32 input to 8 bit assymetric
// Dynamically Quantize the float32 input to 8 bit asymmetric
input_quant_pack_8bit_channelwise(m, k, lhs_f32, (int8_t*)lhs_qa8dx);
const size_t lhs_stride =
@ -1163,7 +1163,7 @@ void dyn_quant_matmul_4bit_kernel(
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
if (weight_packed_size == packed_weights.numel()) {
// KleidiAI interface intenally handles the Channelwise and groupwise
// KleidiAI interface internally handles the Channelwise and groupwise
// distinction
kleidiai::kai_quant_pack_lhs_int4_mm(
output, inp, packed_weights, M, N, K, block_size);

View File

@ -13,6 +13,7 @@
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAScaledBlas.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
@ -360,7 +361,7 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
&& (
// filter by dtype
@ -504,7 +505,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
// Handle whether to use the Lt interface {
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
// if lt path fails, we recurse back into this function here and force the lt path to off
// we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent
// we cannot update variable disable_addmm_cuda_lt from above since it is static and would be permanent
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
@ -1628,104 +1629,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
}
namespace {
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
TORCH_CHECK(
scale.dim() == 1,
"scale must be a 1D tensor, but got ",
scale.dim(),
"D, arg ",
arg_idx);
TORCH_CHECK(
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(dim) * scale_multiplier,
"scale must have the same length as mat for arg ",
arg_idx);
} else {
TORCH_CHECK(
scale.dim() == 2,
"scale must be a 2D tensor, but got ",
scale.dim(),
"D for arg ",
arg_idx);
TORCH_CHECK(
scale.stride(1) == 1,
"scale must be contiguous in the last dimension for arg ",
arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(0),
"scale must have the same batch dimension as mat for arg ",
arg_idx);
TORCH_CHECK(
scale.size(1) == mat.size(1 + dim),
"scale must have the same first dimension as mat for arg ",
arg_idx);
}
}
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
// that are converted to blocked padded format individually,
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
TORCH_CHECK(
scale.dim() == mat.dim(),
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
// * weight is transposed prior to the call, scale stays non-transposed.
bool LHS = arg_idx == 0;
int scale_dim_to_check = 0;
int mat_dim_to_check = LHS ? 0 : 1;
TORCH_CHECK(
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
} else {
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
// so we can check the exact expected scale sizes here without a d2h sync.
auto round_up = [](auto x, auto y) {
return ((x + y - 1) / y) * y;
};
// TODO: this is for 3d tensor in 2d-3d case specifically.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
int64_t G = mat.size(0);
int64_t K = mat.size(1);
int64_t N = mat.size(2);
int64_t blocked_scale_K = round_up(K/32, 4);
int64_t blocked_scale_N = round_up(N, 128);
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
TORCH_CHECK(
scale.dim() == mat.dim() - 1,
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
);
TORCH_CHECK(
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
);
}
}
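Worked example of the 3d branch above: for a stacked RHS of shape (G, K, N) = (8, 8192, 4096), blocked_scale_K = round_up(8192/32, 4) = 256 and blocked_scale_N = round_up(4096, 128) = 4096, so the expected scale shape is (8, 256 * 4096) = (8, 1048576).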
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
if (using_fp8_rowwise) {
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
} else if (using_mxfp8) {
_check_scales_mxfp8(mat, scale, dim, arg_idx);
} else {
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
}
}
}
Tensor
_scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a,
@ -1740,261 +1643,26 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}
/**
* Track concrete implementations available
*/
enum class ScaledGemmImplementation {
NONE = 0,
TENSORWISE_TENSORWISE = 1,
ROWWISE_ROWWISE = 2,
BLOCK_128x128_1x128 = 3,
BLOCK_1x128_128x128 = 4,
BLOCK_1x128_1x128 = 5,
MXFP8_MXFP8 = 6,
NVFP4_NVFP4 = 7,
NVFP4_NVFP4_SINGLE_SCALE = 8,
MXFP4_MXFP4 = 9,
};
/**
* Convert passed int (enum) from python back into a
* strictly-typed enum
*/
template <class EnumType, class ArrayType>
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
std::vector<EnumType> converted;
converted.reserve(v.size());
for (auto vi : v) {
converted.push_back(static_cast<EnumType>(vi));
}
return converted;
}
/**
* Both inputs must be fp8,
* Each needs a single scale, {Tensorwise (float)}
*/
bool check_tensorwise_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {Tensorwise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::TensorWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::TensorWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Both inputs must be fp8,
* Each needs scales, {Rowwise (float)}
*/
bool check_rowwise_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {Tensorwise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {RowWise, dp32} for A & B
if (recipe_a[0] != ScalingType::RowWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::RowWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Two-level scaling, canonical NVFP4
* Both inputs must be fp4
* A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
*/
bool check_nvfp4_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 2 scales, 2 recipes for each input
if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
return false;
}
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Single-level scaling, what PyT currently understands
* Both inputs must be fp4
* A, B need 1 scale, {Blockwise_1x16 (e4m3)}
*/
bool check_nvfp4_recipe_single_scale
(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 2 scales, 2 recipes for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
return true;
}
/**
* Both inputs must be fp8
* A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float)
*/
bool check_deepseek_recipe(ScalingType expected_recipe_a,
ScalingType expected_recipe_b,
c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
return false;
}
// 1 scales, 1 recipes for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
if (recipe_a[0] != expected_recipe_a) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != expected_recipe_b) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Both inputs must be fp8
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
*/
bool check_mxfp8_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
return false;
}
// 1 scales, 1 recipes for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
return true;
}
/**
* Both inputs must be fp4
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
*/
bool check_mxfp4_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 1 scales, 1 recipes for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
return true;
}
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
using namespace std::placeholders;
namespace scaled_blas = at::cuda::scaled;
using scaled_blas::ScaledGemmImplementation;
using scaled_blas::convert_int_to_enum;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 9> scale_kernel_dispatch = {{
{ "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),
{ "tensorwise_tensorwise", scaled_blas::check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "block_1x128_128x128", std::bind(scaled_blas::check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),
ScaledGemmImplementation::BLOCK_1x128_128x128},
{ "block_128x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise128x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
{ "block_128x128_1x128", std::bind(scaled_blas::check_deepseek_recipe, ScalingType::BlockWise128x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
ScaledGemmImplementation::BLOCK_128x128_1x128},
{ "block_1x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
{ "block_1x128_1x128", std::bind(scaled_blas::check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
ScaledGemmImplementation::BLOCK_1x128_1x128},
{ "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "mxfp4_mxfp4", check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}};
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
{ "nvfp4_nvfp4_single_scale", scaled_blas::check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "mxfp4_mxfp4", scaled_blas::check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}};
Tensor&
_scaled_tensorwise_tensorwise(
@ -2596,410 +2264,6 @@ _scaled_mm_cuda_v2(
out);
}
// 2d-2d and 2d-3d
// scaling=MXFP8
// CUDA-only
Tensor&
_mx8_mx8_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const SwizzleType& swizzle_a,
const Tensor& scale_b,
const SwizzleType& swizzle_b,
const std::optional<at::Tensor>& offs,
Tensor& out) {
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
bool b_is_3d = mat_b.dim() == 3;
bool is_2d_2d = a_is_2d && b_is_2d;
bool is_2d_3d = a_is_2d && b_is_3d;
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
#ifdef USE_ROCM
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
#else
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
#endif
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
fbgemm_gpu::mx8mx8bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs.value(),
out);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI");
#endif
return out;
}
// 2d-2d and 2d-3d cases
// scaling=rowwise
// CUDA-only
Tensor&
_f8_f8_bf16_rowwise_grouped_mm_cuda(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
at::cuda::detail::f8f8bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
return out;
}
// 2d-2d and 2d-3d cases
// scaling=rowwise
// only being called for rocm
Tensor&
_f8_f8_bf16_rowwise_grouped_mm_rocm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& offs,
Tensor& out) {
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
#if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM)
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
mat_a,
// FBGEMM expects B matrix shape to be (.., N, K)
mat_b.transpose(-2, -1),
scale_a,
scale_b,
offs,
out);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
#endif
return out;
}
// Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda
Tensor&
_f8_f8_bf16_rowwise_grouped_mm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
bool use_fast_accum,
Tensor& out) {
// FP8 per-tensor and per-row scaling expect fp32 scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"For grouped FP8 rowwise, both scales must be float32 tensors");
#ifndef USE_ROCM
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
#else
// NOTE: ignore use_fast_accum
TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias")
return _f8_f8_bf16_rowwise_grouped_mm_rocm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
out);
#endif
}
Tensor
_scaled_grouped_mm_cuda(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
TORCH_CHECK_VALUE(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes(),
").");
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
"Expected mat_b shape to be divisible by 16 ",
"but got mat_b shape: (",
mat_b.sizes(),
").");
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet");
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
// routines
if (offs.has_value()) {
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
}
// FP8 per-tensor and per-row scaling expect fp32 scales.
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(
(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
"For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
check_scale(mat_a, scale_a, 0, 0, scale_multiplier);
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
// MXFP8 grouped GEMM dispatching
bool is_mx8mx8bf16 = (
mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
);
#else
bool is_mx8mx8bf16 = false;
#endif
if (is_mx8mx8bf16) {
// Note: Passing implied SwizzleType here, correctness of scale previously checked
// in `check_scale` call
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a,
SwizzleType::SWIZZLE_32_4_4,
scale_b,
SwizzleType::SWIZZLE_32_4_4,
offs.value(),
out);
}
// If we're not MXFP8, then we're row-wise scaling.
return _f8_f8_bf16_rowwise_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
}
namespace {
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
} // anonymous namespace
Tensor
_scaled_grouped_mm_cuda_v2(
const Tensor& mat_a, const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum) {
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
if (contraction_dim.size() > 0) {
const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
mat_b.size(dim_b));
// Note: only (-1, -2) is currently supported
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
} else {
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
}
TORCH_CHECK_VALUE(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes(),
").");
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
"Expected mat_b shape to be divisible by 16 ",
"but got mat_b shape: (",
mat_b.sizes(),
").");
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
// routines
if (offs.has_value()) {
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
}
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
// Conversion of implicitly-defined enums to explicit
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
// at this point we can start working out what we want to be doing
// Try to do as few steps as possible.
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
bool ok = accept_fn(mat_a.scalar_type(),
scale_recipe_a_enum,
scale_a,
mat_b.scalar_type(),
scale_recipe_b_enum,
scale_b);
if (ok) {
gemm_impl = scaled_gemm_impl;
break;
}
}
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
"No gemm implementation was found");
switch (gemm_impl) {
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
return _f8_f8_bf16_rowwise_grouped_mm(
mat_a,
mat_b,
scale_a[0],
scale_b[0],
offs,
bias,
use_fast_accum,
out);
}
case ScaledGemmImplementation::MXFP8_MXFP8: {
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0],
swizzle_a_enum[0],
scale_b[0],
swizzle_b_enum[0],
offs.value(),
out);
}
default:
TORCH_CHECK_NOT_IMPLEMENTED(false,
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
}
}
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
std::optional<c10::ScalarType> out_dtype) {
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
bool a_b_and_out_are_bf16 = (
mat_a.dtype() == at::kBFloat16 &&
mat_b.dtype() == at::kBFloat16 &&
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
);
#ifndef USE_ROCM
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
#else
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}
return out;
}
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");

View File

@ -494,7 +494,7 @@ void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen)
auto value = static_cast<scalar_t>(rand * range + from);
// reverse the bounds of curand4 from (0, 1] to [0, 1)
// Note that this method is from legacy THCTensorRandom and is likely to give
// you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
// you more 0-s, since, the probability of getting 1-s is higher than 0-s and
// by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
// BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
auto reverse_bound_value = value == to ? from : value;

View File

@ -0,0 +1,574 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAScaledBlas.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace scaled_blas = at::cuda::scaled;
using scaled_blas::ScaledGemmImplementation;
using scaled_blas::convert_int_to_enum;
using scaled_blas::_scaled_mm_allowed_device;
namespace at::native {
namespace {
// 2d-2d and 2d-3d
// scaling=MXFP8
// CUDA-only
Tensor&
_mx8_mx8_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const SwizzleType& swizzle_a,
const Tensor& scale_b,
const SwizzleType& swizzle_b,
const std::optional<at::Tensor>& offs,
Tensor& out) {
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
bool b_is_3d = mat_b.dim() == 3;
bool is_2d_2d = a_is_2d && b_is_2d;
bool is_2d_3d = a_is_2d && b_is_3d;
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
#ifdef USE_ROCM
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
#else
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
#endif
#if defined(USE_FBGEMM_GENAI) && !defined(USE_ROCM)
fbgemm_gpu::mx8mx8bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs.value(),
out);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI");
#endif
return out;
}
// 2d-2d and 2d-3d cases
// scaling=rowwise
// CUDA-only
Tensor&
_f8_f8_bf16_rowwise_grouped_mm_cuda(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_b to be Float8_e4m3 matrix got ", mat_b.scalar_type());
at::cuda::detail::f8f8bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
return out;
}
// 2d-2d and 2d-3d cases
// scaling=rowwise
// only being called for rocm
Tensor&
_f8_f8_bf16_rowwise_grouped_mm_rocm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& offs,
Tensor& out) {
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_b to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
#if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM)
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
mat_a,
// FBGEMM expects B matrix shape to be (.., N, K)
mat_b.transpose(-2, -1),
scale_a,
scale_b,
offs,
out);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
#endif
return out;
}
// Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda
Tensor&
_f8_f8_bf16_rowwise_grouped_mm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
bool use_fast_accum,
Tensor& out) {
// FP8 per-tensor and per-row scaling expect fp32 scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"For grouped FP8 rowwise, both scales must be float32 tensors");
#ifndef USE_ROCM
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
#else
// NOTE: ignore use_fast_accum
TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias")
return _f8_f8_bf16_rowwise_grouped_mm_rocm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
out);
#endif
}
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
TORCH_CHECK(
scale.dim() == 1,
"scale must be a 1D tensor, but got ",
scale.dim(),
"D, arg ",
arg_idx);
TORCH_CHECK(
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(dim) * scale_multiplier,
"scale must have the same length as mat for arg ",
arg_idx);
} else {
TORCH_CHECK(
scale.dim() == 2,
"scale must be a 2D tensor, but got ",
scale.dim(),
"D for arg ",
arg_idx);
TORCH_CHECK(
scale.stride(1) == 1,
"scale must be contiguous in the last dimension for arg ",
arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(0),
"scale must have the same batch dimension as mat for arg ",
arg_idx);
TORCH_CHECK(
scale.size(1) == mat.size(1 + dim),
"scale must have the same first dimension as mat for arg ",
arg_idx);
}
}
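In words, the rowwise-FP8 rule enforced above is: a 2d operand carries a contiguous 1D scale of length mat.size(dim) (times the group count in the 2d-2d case), while a 3d operand carries a 2D scale of shape (mat.size(0), mat.size(1 + dim)). A standalone sketch of that sizing rule, with dimensions chosen purely for illustration:
// Illustrative sketch only: mirrors the rowwise-FP8 scale sizing checked above.
#include <cassert>
#include <cstdint>
#include <vector>
// For a 2d mat, the scale is 1D with mat.size(dim) * scale_multiplier elements;
// for a 3d mat, the scale is 2D with shape (mat.size(0), mat.size(1 + dim)).
static int64_t expected_rowwise_scale_numel(const std::vector<int64_t>& mat_sizes,
                                            int dim, int scale_multiplier) {
  if (mat_sizes.size() == 2) {
    return mat_sizes[dim] * scale_multiplier;
  }
  return mat_sizes[0] * mat_sizes[1 + dim];
}
int main() {
  // 2d-2d grouped case with 4 groups: a (128, 64) operand needs a scale of 128 * 4 = 512 elements.
  assert(expected_rowwise_scale_numel({128, 64}, /*dim=*/0, /*scale_multiplier=*/4) == 512);
  // 3d case (G=4, M=128, K=64): the scale has shape (4, 128), i.e. 512 elements.
  assert(expected_rowwise_scale_numel({4, 128, 64}, /*dim=*/0, /*scale_multiplier=*/1) == 512);
  return 0;
}
// --- end of illustrative sketch ---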
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
// that are converted to blocked padded format individually,
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
TORCH_CHECK(
scale.dim() == mat.dim(),
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
// * weight is transposed prior to the call, scale stays non-transposed.
bool LHS = arg_idx == 0;
int scale_dim_to_check = 0;
int mat_dim_to_check = LHS ? 0 : 1;
TORCH_CHECK(
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
} else {
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
// so we can check the exact expected scale sizes here without a d2h sync.
auto round_up = [](auto x, auto y) {
return ((x + y - 1) / y) * y;
};
// TODO: this is for 3d tensor in 2d-3d case specifically.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
int64_t G = mat.size(0);
int64_t K = mat.size(1);
int64_t N = mat.size(2);
int64_t blocked_scale_K = round_up(K/32, 4);
int64_t blocked_scale_N = round_up(N, 128);
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
TORCH_CHECK(
scale.dim() == mat.dim() - 1,
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
);
TORCH_CHECK(
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ", ", blocked_scale_K * blocked_scale_N, ") for arg ", arg_idx
);
}
}
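For a concrete feel for the 3d branch, the blocked-scale sizing above rounds K/32 up to a multiple of 4 and N up to a multiple of 128. A host-side sketch with example dimensions (illustrative only, not part of the kernel):
// Illustrative sketch only: expected mxfp8 blocked-scale shape for a 3d (G, K, N) operand.
#include <cstdint>
#include <cstdio>
static int64_t round_up(int64_t x, int64_t y) { return ((x + y - 1) / y) * y; }
int main() {
  const int64_t G = 8, K = 4096, N = 3072;
  const int64_t blocked_scale_K = round_up(K / 32, 4);  // 4096/32 = 128, already a multiple of 4
  const int64_t blocked_scale_N = round_up(N, 128);     // 3072 is already a multiple of 128
  // fbgemm expects the scale flattened to shape (G, blocked_scale_K * blocked_scale_N).
  std::printf("expected scale shape: (%lld, %lld)\n",
              static_cast<long long>(G),
              static_cast<long long>(blocked_scale_K * blocked_scale_N));
  return 0;
}
// --- end of illustrative sketch ---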
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
if (using_fp8_rowwise) {
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
} else if (using_mxfp8) {
_check_scales_mxfp8(mat, scale, dim, arg_idx);
} else {
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
}
}
} // namespace
Tensor
_scaled_grouped_mm_cuda(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
TORCH_CHECK_VALUE(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes(),
").");
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
"Expected mat_b shape to be divisible by 16 ",
"but got mat_b shape: (",
mat_b.sizes(),
").");
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet");
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
// routines
if (offs.has_value()) {
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
}
// FP8 per-tensor and per-row scaling expect fp32 scales.
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(
(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
"For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
check_scale(mat_a, scale_a, 0, 0, scale_multiplier);
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
// MXFP8 grouped GEMM dispatching
bool is_mx8mx8bf16 = (
mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
);
#else
bool is_mx8mx8bf16 = false;
#endif
if (is_mx8mx8bf16) {
// Note: Passing implied SwizzleType here, correctness of scale previously checked
// in `check_scale` call
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a,
SwizzleType::SWIZZLE_32_4_4,
scale_b,
SwizzleType::SWIZZLE_32_4_4,
offs.value(),
out);
}
// If we're not MXFP8, then we're row-wise scaling.
return _f8_f8_bf16_rowwise_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
}
namespace {
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
} // anonymous namespace
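The v2 entry point below resolves the implementation by scanning this table and taking the first recipe whose acceptance function returns true. A minimal standalone sketch of that dispatch pattern (the names and the predicate here are illustrative, not the ATen API):
// Illustrative sketch only: first-match dispatch over (name, acceptance_fn, impl) tuples.
#include <array>
#include <cstdio>
#include <functional>
#include <string>
#include <tuple>
enum class Impl { NONE, ROWWISE_ROWWISE, MXFP8_MXFP8 };
using Accept = std::function<bool(const std::string& /*scale_dtype*/)>;
int main() {
  const std::array<std::tuple<std::string, Accept, Impl>, 2> table = {{
      {"rowwise_rowwise", [](const std::string& s) { return s == "float32"; }, Impl::ROWWISE_ROWWISE},
      {"mxfp8_mxfp8", [](const std::string& s) { return s == "float8_e8m0fnu"; }, Impl::MXFP8_MXFP8}}};
  Impl chosen = Impl::NONE;
  for (const auto& [name, accept, impl] : table) {
    if (accept("float8_e8m0fnu")) {
      chosen = impl;
      std::printf("picked %s\n", name.c_str());
      break;
    }
  }
  return chosen == Impl::NONE ? 1 : 0;
}
// --- end of illustrative sketch ---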
Tensor
_scaled_grouped_mm_cuda_v2(
const Tensor& mat_a, const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum) {
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
if (contraction_dim.size() > 0) {
const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
mat_b.size(dim_b));
// Note: only (-1, -2) is currently supported
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
} else {
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
}
TORCH_CHECK_VALUE(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes(),
").");
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
"Expected mat_b shape to be divisible by 16 ",
"but got mat_b shape: (",
mat_b.sizes(),
").");
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
// routines
if (offs.has_value()) {
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
}
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
// Conversion of implicitly-defined enums to explicit
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
// at this point we can start working out what we want to be doing
// Try to do as few steps as possible.
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
bool ok = accept_fn(mat_a.scalar_type(),
scale_recipe_a_enum,
scale_a,
mat_b.scalar_type(),
scale_recipe_b_enum,
scale_b);
if (ok) {
gemm_impl = scaled_gemm_impl;
break;
}
}
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
"No gemm implementation was found");
switch (gemm_impl) {
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
return _f8_f8_bf16_rowwise_grouped_mm(
mat_a,
mat_b,
scale_a[0],
scale_b[0],
offs,
bias,
use_fast_accum,
out);
}
case ScaledGemmImplementation::MXFP8_MXFP8: {
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0],
swizzle_a_enum[0],
scale_b[0],
swizzle_b_enum[0],
offs.value(),
out);
}
default:
TORCH_CHECK_NOT_IMPLEMENTED(false,
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
}
}
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
std::optional<c10::ScalarType> out_dtype) {
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
bool a_b_and_out_are_bf16 = (
mat_a.dtype() == at::kBFloat16 &&
mat_b.dtype() == at::kBFloat16 &&
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
);
#ifndef USE_ROCM
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
#else
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}
return out;
}
} // namespace at::native

View File

@ -6,7 +6,7 @@
#endif
// ROCm 6.3 is planned to have these functions, but until then here they are.
#if defined(USE_ROCM) && ROCM_VERSION >= 60201
#if defined(USE_ROCM)
#include <device_functions.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
@ -115,9 +115,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
index_t index,
const index_t numel,
scalar_t value) {
#if ( \
(defined(USE_ROCM) && ROCM_VERSION < 60201) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))
gpuAtomicAddNoReturn(
reinterpret_cast<at::Half*>(tensor) + index,
static_cast<at::Half>(value));
@ -160,9 +158,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
index_t index,
const index_t numel,
scalar_t value) {
#if ( \
(defined(USE_ROCM) && ROCM_VERSION < 60201) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
gpuAtomicAddNoReturn(
reinterpret_cast<at::BFloat16*>(tensor) + index,
static_cast<at::BFloat16>(value));

View File

@ -154,7 +154,7 @@ REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel)
// Old style dispatches
// torch_cuda_linalg dynamic library should have a global constructor
// that calls regiserLinaglDispatch so in order ot lazy bind
// that calls registerLinalgDispatch so in order ot lazy bind
// old style dispatch all one have to do is to load library and call disp.func_name
// Protect from infinite recursion by initializing dispatch to self and checking
// that values are different after linalg library were loaded

View File

@ -121,7 +121,7 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
#pragma unroll
for (int u = 0; u < UNRL; u++)
tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
tmp[u] = op(batch, plane, std::min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
#pragma unroll
for (int u = 0; u < UNRL; u++)
if (x+u*blockDim.x < tensor.size(2))
@ -306,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel(
stat_accscalar_t var_n = 0;
int n = 0;
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
#if defined(USE_ROCM)
constexpr int UNRL = 4;
stat_accscalar_t v_[UNRL];
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
for (int u = 0; u < UNRL; u++)
v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
for (int u = 0; u < UNRL; u++) {
if (x+u*blockDim.x < input.size(2)) {
stat_accscalar_t d1 = v_[u] - avg;
n++;
avg += d1 / n;
var_n += d1 * (v_[u] - avg);
}
}
}
#else
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
stat_accscalar_t v = input[batch][plane][x];
stat_accscalar_t d1 = v - avg;
@ -313,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel(
avg += d1 / n;
var_n += d1 * (v - avg);
}
#endif
}
// first warpSum to get one value per thread to

View File

@ -43,6 +43,12 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
TORCH_CHECK(k >= 1 && k <= slicesize,
"kthvalue(): selected number k out of range for dimension ", dim);
TORCH_CHECK(
slicesize <= std::numeric_limits<int32_t>::max(),
"kthvalue(): dimension ", dim, " is too large (", slicesize,
"). The current CUDA implementation supports dimension sizes up to ",
std::numeric_limits<int32_t>::max());
at::assert_no_overlap(self, values);
_reduction_with_indices_allocate_or_resize_output(
@ -163,10 +169,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
bool keepdim,
Tensor& values,
Tensor& indices) {
// See note [Writing Nondeterministic Operations]
// If there are duplicate elements of the kth value, the procedure for choosing which
// of the duplicates to use for the indices output is nondeterministic.
at::globalContext().alertNotDeterministic("kthvalue CUDA");
auto result = [&]() {
NoNamesGuard guard;
// `kthvalue_out_impl_cuda` expects contiguous in input `self`.

View File

@ -65,25 +65,34 @@ __global__ void gatherKthValue(
&kValue);
// Find the index of the k-th highest element
index_t kValueIndex = 0;
bool foundKValue = false;
__shared__ int32_t minIndexFound;
if (threadIdx.x == 0) {
minIndexFound = static_cast<int32_t>(inputSliceSize);
}
__syncthreads();
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
bool inRange = (i < inputSliceSize);
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
: static_cast<scalar_t>(0);
bool isKValue = inRange &&
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
if (isKValue) {
kValueIndex = i;
foundKValue = true;
break;
}
// Early exit based on best-so-far
if (i >= minIndexFound) {
break;
}
scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
bool isKValue =
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
if (isKValue) {
atomicMin(&minIndexFound, static_cast<int32_t>(i));
break;
}
}
if (foundKValue) {
kthValueSliceStart[0] = kValue;
indicesSliceStart[0] = kValueIndex;
__syncthreads();
if (threadIdx.x == 0) {
indicesSliceStart[0] = static_cast<index_t>(minIndexFound);
kthValueSliceStart[0] = kValue;
}
}
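The rewrite above makes the choice among duplicate k-th values deterministic: every thread that finds a match proposes its index through atomicMin on a shared counter, so the smallest index wins regardless of warp scheduling, and thread 0 writes the result once. A sequential host-side sketch of the same reduction (illustrative only; NaN handling omitted):
// Illustrative sketch only: sequential equivalent of the atomicMin index reduction.
#include <algorithm>
#include <cstdint>
#include <vector>
static int64_t min_matching_index(const std::vector<float>& slice, float kValue) {
  int64_t found = static_cast<int64_t>(slice.size());  // sentinel, like minIndexFound above
  for (int64_t i = 0; i < static_cast<int64_t>(slice.size()); ++i) {
    if (slice[i] == kValue) {
      found = std::min(found, i);  // plays the role of atomicMin(&minIndexFound, i)
    }
  }
  return found;
}
int main() {
  const std::vector<float> slice = {3.f, 7.f, 7.f, 1.f};
  return min_matching_index(slice, 7.f) == 1 ? 0 : 1;  // first of the duplicate 7s wins
}
// --- end of illustrative sketch ---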

View File

@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
}
}
#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
int input_pos,
accscalar_t scale,
int output_size,
bool align_corners,
int& min_output,
int& max_output) {
accscalar_t lo, hi;
if (align_corners) {
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
} else {
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
}
min_output = max(0, static_cast<int>(std::ceil(lo)));
max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
}
#endif
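For intuition, plugging example numbers into the align_corners == false branch above: with scale = 0.5 and input row 3, the contributing output rows are ceil(4.5) = 5 through floor(8.5) = 8. A host-side sketch of the same arithmetic (the values are assumptions for illustration):
// Illustrative sketch only: host-side version of compute_output_range (align_corners = false).
#include <algorithm>
#include <cmath>
#include <cstdio>
int main() {
  const float scale = 0.5f;          // analogue of rheight / rwidth
  const int input_pos = 3, output_size = 16;
  const float lo = (input_pos - 0.5f) / scale - 0.5f;  // 4.5
  const float hi = (input_pos + 1.5f) / scale - 0.5f;  // 8.5
  const int min_output = std::max(0, static_cast<int>(std::ceil(lo)));                  // 5
  const int max_output = std::min(output_size - 1, static_cast<int>(std::floor(hi)));   // 8
  std::printf("output rows [%d, %d] can read input row %d\n", min_output, max_output, input_pos);
  return 0;
}
// --- end of illustrative sketch ---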
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
const bool align_corners,
scalar_t* __restrict__ idata,
const scalar_t* __restrict__ odata) {
const size_t o_numel = nc * width2 * height2;
// Total number of input elements; the ROCm path below iterates over the input (gather) rather than the output (scatter).
const size_t i_numel = nc * width1 * height1;
#ifdef USE_ROCM
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
index += blockDim.x * gridDim.x) {
// Decode input pixel coordinates
size_t index_temp = index;
const int w1 = index_temp % width1;
index_temp /= width1;
const int h1 = index_temp % height1;
const size_t nc_idx = index_temp / height1;
accscalar_t grad_sum = 0;
// Find range of output pixels that could interpolate from this input pixel
int h2_min, h2_max, w2_min, w2_max;
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
// Iterate over potential output pixels
for (int h2 = h2_min; h2 <= h2_max; h2++) {
for (int w2 = w2_min; w2 <= w2_max; w2++) {
// Compute source coordinates for this output pixel
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
rheight, h2, align_corners, /*cubic=*/false);
const int h1_base = (int)h1r;
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
const accscalar_t h1lambda = h1r - h1_base;
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
rwidth, w2, align_corners, /*cubic=*/false);
const int w1_base = (int)w1r;
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
const accscalar_t w1lambda = w1r - w1_base;
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
// Check if our input pixel participates in this interpolation and accumulate all weights
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
// to the same pixel, so we need to accumulate weights from all matching positions
accscalar_t weight = 0;
// Check all four interpolation positions and accumulate weights
if (h1 == h1_base && w1 == w1_base) {
weight += h0lambda * w0lambda; // top-left
}
if (h1 == h1_base && w1 == w1_base + w1p) {
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base) {
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
}
if (weight > 0) {
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
}
}
}
// Write accumulated gradient (no atomics needed)
idata[index] = static_cast<scalar_t>(grad_sum);
}
#else
const size_t o_numel = nc * width2 * height2;
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
index += blockDim.x * gridDim.x) {
size_t index_temp = index;
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
true);
}
#endif
}
template <typename scalar_t, typename accscalar_t>
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
// threads are not covering the whole input tensor.
grad_input.zero_();
const size_t num_kernels = nbatch * channels * output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
return;
}
#ifdef USE_ROCM
constexpr bool use_input = true;
#else
constexpr bool use_input = false;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * output_height * output_width;
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
input_height,
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
num_threads,

View File

@ -1,4 +1,4 @@
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
@ -133,7 +133,7 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
#define CDNA2_OR_LATER 0
#endif
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if defined(USE_ROCM)
// TODO: Support RDNA
@ -1161,7 +1161,7 @@ at::Tensor _weight_int4pack_mm_cuda(
auto C_final = at::empty(
{m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device()));
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
auto stream = at::cuda::getCurrentCUDAStream();
#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
do { \
@ -1327,7 +1327,7 @@ at::Tensor _convert_weight_to_int4pack_cuda(
{nTilesTensor, kSuperTiles, 32, innerKTiles / 2},
at::TensorOptions().dtype(at::kInt).device(in.device()));
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(kSuperTiles, nTiles);

View File

@ -1532,7 +1532,7 @@ NvrtcFunction jit_pwise_function(
std::string file_path;
if (cache_dir.has_value()) {
// Attemps to read from the cache.
// Attempts to read from the cache.
// Cubin name is <kernel name>_arch<major>.<minor>_nvrtc<major>.<minor>_<ptx or sass>_<program length>_<string hash>
// Note that the SHA1 hash used in the file name is NOT the SHA1 hash of the file's contents,
// because we hash on the CUDA code, but we save the compiled ptx or sass

View File

@ -1346,7 +1346,7 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info)
});
if (input.dim() > 2) {
// if upper=true we need to tranpose and conjugate the result tensor
// if upper=true we need to transpose and conjugate the result tensor
// because the cholesky decomposition is stored in the lower triangular part
if (upper) {
input.copy_(result.mH());
@ -1857,7 +1857,7 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) {
auto preferred_backend = at::globalContext().linalgPreferredBackend();
switch (preferred_backend) {
// TODO Investigate whether the following magma bug is still occuring.
// TODO Investigate whether the following magma bug is still occurring.
// It may be the case that geqrf followed by orgqr is wrong for the magma backend
// geqrf_magma currently uses geqrf2_gpu
//

View File

@ -82,7 +82,7 @@ void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const T
#if defined(BUILD_LAZY_CUDA_LINALG)
namespace cuda { namespace detail {
// This is only used for an old-style dispatches
// Please do not add any new entires to it
// Please do not add any new entries to it
struct LinalgDispatch {
Tensor (*cholesky_solve_helper)(const Tensor& self, const Tensor& A, bool upper);
};

View File

@ -147,7 +147,7 @@ static void check_shape_forward(const Tensor& input,
// blocked format will propagate between layers. Input, output will be in blocked format.
//
// For inference case, weight can be prepacked into blocked format by
// (so as to save weight reoder overhead):
// (so as to save weight reorder overhead):
// model = torch.utils.mkldnn.to_mkldnn(model)
//
// For training case, grad_output can be CPU tensor or MKLDNN tensor,
@ -723,7 +723,7 @@ Tensor _mkldnn_convolution_transpose(
ideep::tensor w = itensor_from_tensor(weight, /*from_const_data_ptr*/true);
if (!weight.is_mkldnn()) {
// mkldnn transposed convolution has weight in logical order of OIHW or OIDHW,
// while PyTorch has IOHW or IODHW, `._tranpose()` switches strides (no memory copy).
// while PyTorch has IOHW or IODHW, `._transpose()` switches strides (no memory copy).
w.transpose_(0, 1);
}

View File

@ -540,7 +540,7 @@ static void _mkldnn_matmul_i8i8i32_with_primitive(
args.insert({DNNL_ARG_WEIGHTS, expected_weight});
args.insert({DNNL_ARG_DST, dst});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
// Create primitve and execute
// Create primitive and execute
auto primitive = dnnl::matmul(prim_desc);
primitive.execute(ideep::stream::default_stream(), args);
}

View File

@ -439,7 +439,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la
// I. Memory Formats
// a. mkldnn will use plain formats for input, hx/cx, output, hy/cy
// and possibly use blocked formats for weights depending shape info.
// b. All mkldnn memorys are created (in plain format) as views on ATen tensor,
// b. All mkldnn memories are created (in plain format) as views on ATen tensor,
// the weight reorder(if any) is handed automatically inside ideep (mkldnn bridge)
//
// II. MKLDNN Primitive Mapping

View File

@ -39,7 +39,7 @@ void check_mkldnn_binary_fusion_inputs(
inline std::vector<int64_t> padding_r(
IntArrayRef padding, IntArrayRef output_padding)
{
// ConvTranpose padding adjustment
// ConvTranspose padding adjustment
//
// PyTorch uses padding/output_padding:
// osize = (isize - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

View File

@ -75,7 +75,7 @@ bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
}
bool can_use_flash_attention(sdp::sdp_params const& params, bool debug) {
// Currently, XPU fallbacks flash attention to overrideable
// Currently, XPU fallbacks flash attention to overridable
return can_use_overrideable_attention(params, debug);
}
@ -115,7 +115,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
// 1. Flash Attention
// 2. Math fallback
auto& ctx = at::globalContext();
// use overrideable linked to onednn as overrideable implementation
// use overridable linked to onednn as overridable implementation
if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP() &&
!ctx.userEnabledFlashSDP()) {
return sdp::SDPBackend::error;
@ -165,7 +165,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
}
}
// If we have gotten to this point then two things have happened:
// 1. can_use_overrideable_attention did not satisfy the constraints to be ran
// 1. can_use_overridable_attention did not satisfy the constraints to be ran
// 2. The user has explicitly disabled the math kernel
// We then re-run the kernel checks with debug enabled to print out the
// reason why the kernel was not selected

View File

@ -2,6 +2,7 @@
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -50,9 +51,13 @@ Tensor& addmm_out(
mat1.dtype(),
" != ",
mat2.dtype())
// complex case
TORCH_CHECK(
!mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
if (self.is_complex()) {
at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
return result;
}
std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
result.resize_(result_shape);
@ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
return result;
}
TORCH_CHECK(
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
if (self.is_complex()) {
at::native::mm_complex_out_xpu(self, mat2, result);
return result;
}
onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
return result;
@ -208,9 +216,12 @@ Tensor& baddbmm_out(
input.sizes());
// complex case
TORCH_CHECK(
!batch1.is_complex(),
"Complex datatype matmul is not supported in oneDNN");
if (input.is_complex()) {
at::native::baddbmm_complex_out_xpu(
input, batch1, batch2, beta, alpha, result);
return result;
}
// general case
onednn::Attr attr;
@ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
return result;
}
TORCH_CHECK(
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
// complex case
if (self.is_complex()) {
at::native::bmm_complex_out_xpu(self, batch2, result);
return result;
}
onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
return result;
}

View File

@ -215,7 +215,7 @@ partition create_sdpa_graph_partition(
// For optional additive mask
std::optional<op> mask_add;
// For optional implicite causal mask
// For optional implicit causal mask
std::optional<op> mask_gen_idx_row;
std::optional<logical_tensor> mask_row_idx;
std::optional<op> mask_gen_idx_col;
@ -556,7 +556,7 @@ partition create_sdpa_backward_graph_partition(
// For optional additive mask
std::optional<op> mask_add;
// For optional implicite causal mask
// For optional implicit causal mask
std::optional<op> mask_gen_idx_row;
std::optional<logical_tensor> mask_row_idx;
std::optional<op> mask_gen_idx_col;

View File

@ -345,7 +345,7 @@ class Attr {
dnnl::memory binary_m;
auto binary = ops_params_[i].binary_;
auto md = ops_params_[i].meta_;
// qeury expected_md to achieve peak performance
// query expected_md to achieve peak performance
auto expected_md = pd.query_md(
dnnl::query::exec_arg_md,
DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1);

View File

@ -301,7 +301,7 @@ bool is_onednn_matmul_strides(const at::Tensor& tensor) {
return false;
}
// the overlaped cases are not supported
// the overlapped cases are not supported
dnnl::memory::dims strides = get_onednn_strides(tensor);
int64_t storage_size = 1;
for (size_t dim = 0; dim < tensor_dim; ++dim)

View File

@ -29,7 +29,7 @@
secondaryTensor:(MPSGraphTensor*)secondaryTensor
name:(NSString*)name {
// As of MacOS-15.1 m..imumWithNanPropagation is only defined for floating types and calling it with integral
// agruments results in
// arguments results in
// /AppleInternal/Library/BuildRoots/c7c74b64-74b4-11ef-aeda-9635a580fe0d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSKernelDAG.mm:805:
// failed assertion `Error getting visible function: (null) Function isNaN_u8_i8 was not found in the library'
if (([primaryTensor dataType] & MPSDataTypeFloatBit) == 0) {
@ -42,7 +42,7 @@
secondaryTensor:(MPSGraphTensor*)secondaryTensor
name:(NSString*)name {
// As of MacOS-15.1 m..imumWithNanPropagation is only defined for floating types and calling it with integral
// agruments results in
// arguments results in
// /AppleInternal/Library/BuildRoots/c7c74b64-74b4-11ef-aeda-9635a580fe0d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSKernelDAG.mm:805:
// failed assertion `Error getting visible function: (null) Function isNaN_u8_i8 was not found in the library'
if (([primaryTensor dataType] & MPSDataTypeFloatBit) == 0) {
@ -539,7 +539,7 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
// Use gather kernel to solve strides for macOS < 15.0
// Starting with macOS 15.0, MPS supports native strides direclty in the kernels
// Starting with macOS 15.0, MPS supports native strides directly in the kernels
if (!is_macOS_15_0_or_newer || !useMPSStridedAPI) {
if ((!src.is_contiguous() || src.storage_offset()) && gatherTensorData) {
Tensor emptyShell = Tensor();

View File

@ -222,6 +222,13 @@ struct nextafter_functor {
}
};
struct hypot_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b));
}
};
// Complex binary functors
struct polar_functor {
template <typename U>
@ -362,6 +369,7 @@ struct igammac_functor {
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
REGISTER_FLOAT_BINARY_OP(hypot);
REGISTER_FLOAT_BINARY_OP(copysign);
REGISTER_INT2FLOAT_BINARY_OP(copysign);
REGISTER_FLOAT_BINARY_OP(fmax);

View File

@ -0,0 +1,16 @@
#pragma once
#include <c10/metal/common.h>
template <unsigned N = c10::metal::max_ndim>
struct OrgqrParams {
int32_t num_batch_dims;
uint32_t m;
uint32_t n;
uint32_t k;
::c10::metal::array<uint32_t, N> A_strides;
::c10::metal::array<uint32_t, N> tau_strides;
::c10::metal::array<uint32_t, N> H_strides;
::c10::metal::array<uint32_t, N> H_sizes;
};

View File

@ -1,3 +1,4 @@
#include <ATen/native/mps/kernels/LinearAlgebra.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_simdgroup>
@ -640,6 +641,164 @@ kernel void applyPivots(
}
}
template <typename T>
static T bool_to_float(bool b) {
return static_cast<T>(b);
}
template <>
half2 bool_to_float(bool b) {
return half2(b ? 1 : 0, 0);
}
template <>
float2 bool_to_float(bool b) {
return float2(b ? 1 : 0, 0);
}
template <typename T>
static T calc_H_irc(
device T* A,
uint32_t A_stride_r,
uint32_t A_stride_c,
constant T* tau,
uint32_t tau_stride,
uint32_t r,
uint32_t c,
uint32_t i) {
T I_val = bool_to_float<T>(r == c);
T tau_val = tau[i * tau_stride];
T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]);
T A_ri = A[r * A_stride_r + i * A_stride_c];
T c_eq_i = bool_to_float<T>(c == i);
T r_eq_i = bool_to_float<T>(r == i);
T A_ci_ = (c > i) ? A_ci : c_eq_i;
T A_ri_ = (r > i) ? A_ri : r_eq_i;
return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_));
}
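In math form, calc_H_irc evaluates one element of the i-th Householder reflector implied by the compact QR factors; a sketch of what the code above computes:
H_i[r, c] = \delta_{rc} - \tau_i \, v_i[r] \, \overline{v_i[c]}, \qquad
v_i[j] = \begin{cases} 0 & j < i \\ 1 & j = i \\ A[j, i] & j > i \end{cases}
where \delta_{rc} is the Kronecker delta and \tau_i = tau[i * tau_stride].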
// Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the
// result of matrix multiplying A and B together. A and B must be size m-by-m
// and have the same strides. The formula for this operation, written in Python
// syntax, is:
// (A @ B)[r, c] = A[r, :].dot(B[:, c])
template <typename T>
static T calc_matmul_rc(
device T* A,
device T* B,
uint32_t stride_r,
uint32_t stride_c,
uint32_t m,
uint32_t r,
uint32_t c) {
T AB_rc = 0;
auto A_row_offset = r * stride_r;
auto B_col_offset = c * stride_c;
uint32_t A_col_offset = 0;
uint32_t B_row_offset = 0;
for (uint32_t j = 0; j < m;
j++, A_col_offset += stride_c, B_row_offset += stride_r) {
AB_rc += c10::metal::mul(
A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]);
}
return AB_rc;
}
template <typename T>
kernel void orgqr(
device T* A [[buffer(0)]],
constant T* tau [[buffer(1)]],
device T* H [[buffer(2)]],
device T* H_prod [[buffer(3)]],
constant OrgqrParams<>& params [[buffer(4)]],
uint tid [[thread_position_in_grid]]) {
constant auto& A_strides = params.A_strides;
constant auto& tau_strides = params.tau_strides;
constant auto& H_strides = params.H_strides;
constant auto& H_sizes = params.H_sizes;
auto num_batch_dims = params.num_batch_dims;
auto m = params.m;
auto n = params.n;
auto k = params.k;
auto m2 = m * m;
auto batch_idx = tid / m2;
// Find the matrices for this thread's batch index
uint32_t A_offset = 0;
uint32_t tau_offset = 0;
uint32_t H_offset = 0;
for (auto dim = num_batch_dims - 1; dim >= 0; dim--) {
auto dim_size = H_sizes[dim];
auto dim_idx = batch_idx % dim_size;
A_offset += dim_idx * A_strides[dim];
tau_offset += dim_idx * tau_strides[dim];
H_offset += dim_idx * H_strides[dim];
batch_idx /= dim_size;
}
A += A_offset;
tau += tau_offset;
H += H_offset;
H_prod += H_offset;
auto matrix_idx = tid % m2;
auto r = matrix_idx / m;
auto c = matrix_idx % m;
auto A_stride_r = A_strides[num_batch_dims];
auto A_stride_c = A_strides[num_batch_dims + 1];
auto tau_stride = tau_strides[num_batch_dims];
auto H_stride_r = H_strides[num_batch_dims];
auto H_stride_c = H_strides[num_batch_dims + 1];
// Find the element of H and H_prod that this thread will calculate
device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c);
device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c);
for (uint32_t i = 0; i < k; i++) {
// Calculate and write H_i
T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i);
// Calculate element [r, c] of prod(H_0, ..., H_i)
if (i == 0) {
*H_prod_elem_ptr = H_irc;
} else {
*H_elem_ptr = H_irc;
// Need this sync because the below matmul requires all threads to finish
// writing their entries to `H_prod` and `H`.
threadgroup_barrier(mem_flags::mem_threadgroup);
T H_prod_0_to_i_rc =
calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);
// Need this sync because the above matmul uses the current values in
// `H_prod`, and we don't want to overwrite those until all threads are
// finished using them.
threadgroup_barrier(mem_flags::mem_threadgroup);
*H_prod_elem_ptr = H_prod_0_to_i_rc;
}
}
device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c);
if (c < n) {
*A_elem_ptr = *H_prod_elem_ptr;
}
}
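To make the per-thread logic easier to follow, here is a single-threaded C++ sketch of the accumulation the kernel performs for one batch entry: build H_i, fold it into the running product H_prod, then copy the first n columns of the product back into A. Names and the real-float restriction are illustrative, not part of the PR; k >= 1 is assumed (the empty-tau case is handled on the host side).

#include <cstdint>
#include <utility>
#include <vector>

// Builds Q = H_0 * H_1 * ... * H_{k-1} from the reflectors stored in A (m-by-n,
// row-major, as produced by geqrf) and tau, then writes Q's first n columns
// back into A's storage.
void orgqr_reference(std::vector<float>& A, const std::vector<float>& tau,
                     uint32_t m, uint32_t n, uint32_t k) {
  std::vector<float> H(m * m), H_prod(m * m), next(m * m);
  for (uint32_t i = 0; i < k; ++i) {
    // H_i = I - tau[i] * v v^T, with v = [0, ..., 0, 1, A[i+1:, i]].
    for (uint32_t r = 0; r < m; ++r) {
      for (uint32_t c = 0; c < m; ++c) {
        float v_r = (r == i) ? 1.0f : (r > i ? A[r * n + i] : 0.0f);
        float v_c = (c == i) ? 1.0f : (c > i ? A[c * n + i] : 0.0f);
        H[r * m + c] = (r == c ? 1.0f : 0.0f) - tau[i] * v_r * v_c;
      }
    }
    if (i == 0) {
      H_prod = H;
    } else {
      // H_prod <- H_prod @ H_i, computed into a scratch buffer first.
      for (uint32_t r = 0; r < m; ++r) {
        for (uint32_t c = 0; c < m; ++c) {
          float acc = 0.0f;
          for (uint32_t j = 0; j < m; ++j) {
            acc += H_prod[r * m + j] * H[j * m + c];
          }
          next[r * m + c] = acc;
        }
      }
      std::swap(H_prod, next);
    }
  }
  // Q is m-by-m; orgqr returns only its first n columns.
  for (uint32_t r = 0; r < m; ++r) {
    for (uint32_t c = 0; c < n; ++c) {
      A[r * n + c] = H_prod[r * m + c];
    }
  }
}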
#define INSTANTIATE_MM_OPS(DTYPE) \
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
constant DTYPE * mat1Data [[buffer(0)]], \
@@ -679,3 +838,19 @@ INSTANTIATE_MM_OPS(int);
INSTANTIATE_MM_OPS(short);
INSTANTIATE_MM_OPS(char);
INSTANTIATE_MM_OPS(uchar);
#define REGISTER_ORGQR(T) \
template [[host_name("orgqr_" #T)]] \
kernel void orgqr<T>( \
device T * A [[buffer(0)]], \
constant T * tau [[buffer(1)]], \
device T * H [[buffer(2)]], \
device T * H_prod [[buffer(3)]], \
constant OrgqrParams<> & params [[buffer(4)]], \
uint tid [[thread_position_in_grid]]);
REGISTER_ORGQR(float);
REGISTER_ORGQR(half);
REGISTER_ORGQR(bfloat);
REGISTER_ORGQR(float2);
REGISTER_ORGQR(half2);

View File

@@ -5,6 +5,21 @@
using namespace metal;
using namespace c10::metal;
struct angle_functor {
template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
inline T operator()(const T x) {
return T(atan2(x.y, x.x), 0);
}
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
inline T operator()(const T x) {
return T(isnan(x) ? x : x < 0 ? M_PI_F : 0.0);
}
template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
inline float operator()(const T x) {
return x < 0 ? M_PI_F : 0.0;
}
};
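For comparison, the three cases above map onto the C++ standard library as follows (a sketch; M_PI from <cmath> is assumed to be available):

#include <cmath>
#include <complex>

// Complex inputs: angle(z) = atan2(imag, real), which is exactly std::arg.
float angle_complex(std::complex<float> z) {
  return std::arg(z);
}

// Real inputs: pi for negative values, 0 otherwise, with NaN propagated.
float angle_real(float x) {
  if (std::isnan(x)) {
    return x;
  }
  return x < 0.0f ? static_cast<float>(M_PI) : 0.0f;
}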
// Implement exp wrapper for both real and complex types
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
inline T exp_(const T x) {
@@ -545,6 +560,7 @@ REGISTER_UNARY_OP(abs, float, float);
REGISTER_UNARY_OP(abs, half, half);
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
REGISTER_UNARY_OP(angle, DTYPE1, DTYPE0); \
REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0); \
REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0); \
REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0); \
@@ -583,6 +599,7 @@ INSTANTIATE_UNARY_KERNELS2(float, int);
INSTANTIATE_UNARY_KERNELS2(float, long);
#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE) \
REGISTER_UNARY_OP(angle, DTYPE##2, DTYPE##2); \
REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2); \
REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2); \
REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2); \

View File

@@ -202,6 +202,10 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "igammac");
}
static void hypot_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "hypot");
}
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
@@ -229,4 +233,5 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel)
REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel)
REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel)
} // namespace at::native

View File

@@ -16,7 +16,6 @@
#include <ATen/ops/eq_native.h>
#include <ATen/ops/ge_native.h>
#include <ATen/ops/gt_native.h>
#include <ATen/ops/hypot_native.h>
#include <ATen/ops/le_native.h>
#include <ATen/ops/logaddexp2_native.h>
#include <ATen/ops/logaddexp_native.h>
@@ -278,22 +277,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
}
}
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
secondaryTensor:twoTensor
name:nil]
secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor
secondaryTensor:twoTensor
name:nil]
name:nil];
return [mpsGraph squareRootWithTensor:sumTensor name:nil];
};
mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block);
}
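The MPSGraph-based hypot above is removed in favor of the Metal binary kernel registered through hypot_stub. As a point of reference only (the new kernel's exact formulation is not shown in this hunk), the removed graph computed sqrt(x^2 + y^2) directly, which differs from std::hypot for large inputs; a small standalone C++ check:

#include <cmath>
#include <cstdio>

int main() {
  float x = 3.0f, y = 4.0f;
  // Naive formulation used by the removed graph block vs. the guarded library routine.
  std::printf("%g %g\n", std::sqrt(x * x + y * y), std::hypot(x, y));  // 5 5
  // big * big overflows to inf in float, so the naive form returns inf,
  // while std::hypot(big, big) stays finite.
  float big = 2e38f;
  std::printf("%g %g\n", std::sqrt(big * big + big * big), std::hypot(big, big));
  return 0;
}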
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();

View File

@@ -8,6 +8,9 @@
#include <ATen/native/Resize.h>
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/LinearAlgebra.h>
#include <fmt/format.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -28,6 +31,7 @@
#include <ATen/ops/linalg_solve_triangular_native.h>
#include <ATen/ops/lu_unpack_native.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/orgqr_native.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/stack.h>
#include <ATen/ops/triangular_solve_native.h>
@@ -1235,6 +1239,69 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper
}
}
static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) {
if (self.numel() == 0) {
return self;
}
auto m = self.size(-2);
auto n = self.size(-1);
auto k = tau.size(-1);
if (tau.numel() == 0) {
auto I = eye(m, self.scalar_type(), std::nullopt, self.device());
return self.copy_(I.slice(-1, 0, n));
}
auto num_batch_dims = self.dim() - 2;
auto batch_sizes = self.sizes().slice(0, num_batch_dims);
std::vector<int64_t> H_sizes(num_batch_dims + 2);
for (auto dim : c10::irange(num_batch_dims)) {
H_sizes[dim] = self.size(dim);
}
H_sizes[num_batch_dims] = m;
H_sizes[num_batch_dims + 1] = m;
auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous));
auto H_prod = at::empty_like(H);
OrgqrParams params;
params.num_batch_dims = num_batch_dims;
params.m = m;
params.n = n;
params.k = k;
for (const auto dim : c10::irange(self.dim())) {
params.A_strides[dim] = self.stride(dim);
if (dim < tau.dim()) {
params.tau_strides[dim] = tau.stride(dim);
}
params.H_strides[dim] = H.stride(dim);
params.H_sizes[dim] = H.size(dim);
}
auto num_threads = H.numel();
MPSStream* stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self)));
getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau});
[compute_encoder setComputePipelineState:pipeline_state];
mtl_setArgs(compute_encoder, self, tau, H, H_prod, params);
mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads);
getMPSProfiler().endProfileKernel(pipeline_state);
}
});
return self;
}
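A minimal sketch of exercising this new path from the ATen C++ API, assuming an MPS-enabled build (the helper name is made up for illustration; geqrf produces the reflectors and tau that householder_product consumes):

#include <ATen/ATen.h>

// Reconstruct Q from the Householder reflectors returned by geqrf; on an MPS
// tensor this should route through the orgqr_stub registration added below.
void run_householder_product_on_mps() {
  auto x = at::rand({5, 3}, at::TensorOptions().device(at::kMPS));
  auto [reflectors, tau] = at::geqrf(x);
  auto q = at::linalg_householder_product(reflectors, tau);
  (void)q;  // q has shape {5, 3}
}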
} // namespace mps
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
@@ -1471,4 +1538,6 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
}
REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl);
} // namespace at::native

View File

@@ -158,7 +158,7 @@ static void reduction_out_mps(const Tensor& input_t,
IntArrayRef dim = opt_dim.value();
for (const auto dim_val : dim) {
auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
// canSqueeze logic is broken when dim is negative, it introduces off-by-one-erros or crashes
// canSqueeze logic is broken when dim is negative, it introduces off-by-one-errors or crashes
// See https://github.com/pytorch/pytorch/issues/136132#issuecomment-2354482608
if (wrap_dim >= 4 || dim_val < 0) {
canSqueezeLastDim = false;
@@ -1282,7 +1282,7 @@ static void all_any_common_impl_mps(const Tensor& input_t,
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
// reductionOrWithTensor:axis: will throw an internal assert if number of dimentions is more than 4
// reductionOrWithTensor:axis: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
MPSGraphTensor* outputTensor = nil;
if (input_t.ndimension() > 4) {
@@ -1352,7 +1352,7 @@ TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
// reductionOrWithTensor:axes: will throw an internal assert if number of dimentions is more than 4
// reductionOrWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
if (input_t.dim() > 4) {
castInputTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil];
@@ -1400,7 +1400,7 @@ TORCH_IMPL_FUNC(all_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
// reductionAndWithTensor:axes: will throw an internal assert if number of dimentions is more than 4
// reductionAndWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
if (input_t.ndimension() > 4) {
castInputTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil];

View File

@@ -34,6 +34,7 @@ REGISTER_UNARY_TI_DISPATCH(sinc);
REGISTER_UNARY_TI_DISPATCH(sinh);
REGISTER_UNARY_TI_DISPATCH(cosh);
REGISTER_UNARY_TI_DISPATCH(tanh);
REGISTER_UNARY_TI_DISPATCH(angle);
REGISTER_UNARY_TI_DISPATCH(abs);
REGISTER_UNARY_TI_DISPATCH(sin);
REGISTER_UNARY_TI_DISPATCH(cos);

View File

@@ -12,7 +12,6 @@
#include <ATen/ops/_copy_from_and_resize.h>
#include <ATen/ops/acos_native.h>
#include <ATen/ops/acosh_native.h>
#include <ATen/ops/angle_native.h>
#include <ATen/ops/asin_native.h>
#include <ATen/ops/asinh_native.h>
#include <ATen/ops/atan_native.h>
@@ -204,23 +203,6 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
return output;
}
Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
});
return output;
}
Tensor angle_mps(const Tensor& self) {
const auto float_type = c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)
? c10::typeMetaToScalarType(c10::get_default_dtype())
: c10::toRealValueType(self.scalar_type());
Tensor result = at::empty({0}, self.options().dtype(float_type));
return angle_out_mps(self, result);
}
TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) {
TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types");
mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {

View File

@@ -19,7 +19,7 @@ namespace at::native::mps {
// For both scatter and gather kernels, there are 4 specialized ones (for 1D to 4D tensor)
// and one generic, for 5+D ones. Assumption (to be tested) about specialized kernels
// is that reduction of n-dimentional vector, where n is 2, should be slower
// is that reduction of n-dimensional vector, where n is 2, should be slower
// than reduction of 2D one, as n is not known at compiler time, therefore compiler
// could not do loop unrolls, that is
// float sum(float* v, int n) {

Some files were not shown because too many files have changed in this diff.