Compare commits

...

168 Commits

Author SHA1 Message Date
50848db65d Fix pyrefly ignore syntax in _inductor 2025-10-25 22:35:21 -07:00
84b14f3a10 Fix error suppression syntax in utils and nn (#166242)
Fixes syntax for pyrefly : ignores so they only ignore a specific category. No functional changes

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166242
Approved by: https://github.com/oulgen, https://github.com/cyyever
2025-10-26 05:21:07 +00:00
5121499f6b Fix pyrefly ignore syntax in /tools/... (#166240)
Second PR for this - only adjusts the syntax used for the ignores so the suppressions hide only one category of pyrefly errors.

test:
pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166240
Approved by: https://github.com/oulgen
2025-10-26 04:20:16 +00:00
8f80892359 Use correct pyrefly syntax in suppressions distributed/... (#166241)
Updates the pyrefy-ignores in the torch/distributed directory to use the correct syntax. No functional changes.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166241
Approved by: https://github.com/oulgen
2025-10-26 04:16:41 +00:00
cdb60e44eb [Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
2025-10-26 02:36:15 +00:00
25909d2629 Simplify SingletonOrSharedTypePtr (#166183)
@neildhar pointed out at PTC yesterday that the assumption SingletonOrSharedTypePtr makes about shared_ptr's pointers being either both null or both non-null is incorrect because of the aliasing constructor, and furthermore that SingletonOrSharedTypePtr needn't be as fancy as it is because said constructor exists. (See also https://github.com/pytorch/pytorch/issues/166152 .)

Differential Revision: [D85458769](https://our.internmc.facebook.com/intern/diff/D85458769/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166183
Approved by: https://github.com/Skylion007, https://github.com/cyyever
2025-10-26 01:25:24 +00:00
c7eee49525 Fix pyrefly ignores 1/n (#166239)
First diff adjusting the syntax for pyrefly: ignore suppressions so they only hide one class of type error.

Test:
lintrunner
pyrefly check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166239
Approved by: https://github.com/oulgen
2025-10-26 00:44:10 +00:00
eqy
621ba05107 [cuDNN][SDPA] Handle c10:Error when checking device capability for prefer-cuDNN SDPA check (#166201)
Fake device test can execute this function when the number of visible CUDA devices is 0, fix to unblock #165922

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166201
Approved by: https://github.com/Skylion007
2025-10-25 23:00:06 +00:00
39a70cead1 [feat]: add optimized exp_u20 implementation from Arm Optimized Routi… (#161049)
This patch adds an optimized exp_u20() implementation, based on Arm Optimized Routines (AOR). The legacy svexp_f32_z function is removed, and internal uses (such as in tanh) now leverage the new exp_u20() logic. Unit tests have been updated to cover all scenarios.
The implementation ensures correct handling of edge cases by falling back to exp() for extreme inputs (|x| ≥ 0x1.5d5e2ap+6f or |x| ≤ -0x1.5d5e2ap+6f).

Performance:
<html>
<body>
<!--StartFragment--><h3 data-start="70" data-end="182"><strong data-start="77" data-end="182">Performance Improvements for <code data-start="108" data-end="144">aten::scaled_dot_product_attention</code> (Neoverse-V2, <code data-start="159" data-end="179">OMP_NUM_THREADS=16</code>)</strong></h3>
<div class="_tableContainer_1rjym_1"><div tabindex="-1" class="group _tableWrapper_1rjym_13 flex w-fit flex-col-reverse">
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/anaabu01/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/anaabu01/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">

<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:"Aptos Narrow", sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl63
	{font-weight:700;
	text-align:center;
	vertical-align:middle;
	border:.5pt solid windowtext;
	white-space:normal;}
.xl64
	{text-align:center;
	vertical-align:middle;
	border:.5pt solid windowtext;
	white-space:normal;}
-->

</head>

<body link="#467886" vlink="#96607D">

Configuration | Current | With Changes   (F32) | Speedup
-- | -- | -- | --
Batch 1 · 16 Heads   · Seq 512 · Q 128 | 654.102 µs | 551.031 µs | 1.19× faster (≈ 19%)
Batch 8 · 64 Heads   · Seq 2048 · Q 128 | 30.308 ms | 17.142 ms | 1.77× faster (≈ 43%)

</body>

</html>

</div></div><!--EndFragment-->
</body>
</html>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161049
Approved by: https://github.com/fadara01, https://github.com/jgong5

Co-authored-by: Fadi Arafeh <Fadi.Arafeh@arm.com>
2025-10-25 20:44:11 +00:00
d97f6550a2 [Intel GPU] Xpu matmul implementation for complex dtype (#160867)
Enabling complex datatype support for 4 ops: `mm`, `bmm`, `addmm`, `baddbmm` for XPU. From now implementation will call functions created in: https://github.com/intel/torch-xpu-ops/pull/1992.

Additionally added complex datatype tests for matmul operators. More detailed tests are going to be enabled in: https://github.com/intel/torch-xpu-ops/pull/1993

CI runs have found that `test_comprehensive_linalg_eig_xpu` tests were calling internally matmul with complex datatype. With this PR test starts to pass so linalg.eig was removed from `inductor_expected_failures_single_sample["xpu"]` as otherwise it was failing with: `Unexpected success` message.

Part of: https://github.com/intel/torch-xpu-ops/issues/1853

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160867
Approved by: https://github.com/guangyey, https://github.com/ZhiweiYan-96, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/Silv3S, https://github.com/CuiYifeng, https://github.com/jansel
2025-10-25 17:13:13 +00:00
516e58965a Revert "Export flex attention with kwargs and DTensor (#166045)"
This reverts commit de7fdfe41ad12aec719e3662be58ce9e9bf255a8.

Reverted https://github.com/pytorch/pytorch/pull/166045 on behalf of https://github.com/malfet due to Broke distributed tests, see b55b779ad3/1 ([comment](https://github.com/pytorch/pytorch/pull/166045#issuecomment-3446850955))
2025-10-25 15:47:32 +00:00
b55b779ad3 Add file size limits to linters and refactor grep_linter (#166202)
- Add 1GB file size limits to grep_linter, newlines_linter, codespell_linter
- Refactor grep_linter
  - process files once instead of per-line
  - Extract allowlist check to separate function
  - Add 512KB limit for computing replacements, 100 match limit per file
  - Detect duplicate arguments
- Fix .lintrunner.toml: RAWCUDADEVICE used --pattern twice
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166202
Approved by: https://github.com/Skylion007
2025-10-25 14:57:19 +00:00
74e53d0761 [TorchScript] clearer debug for ConcreteModuleType::findSubmoduleConcreteType (#166192)
Summary:
right now the log is just
```
RuntimeError: it != data_.modules_.end() INTERNAL ASSERT FAILED at "fbcode/caffe2/torch/csrc/jit/frontend/concrete_module_type.cpp":207, please report a bug to PyTorch.
```
we have no clue where the error happens
https://fb.workplace.com/groups/gpuinference/posts/789257990578348/?comment_id=789284783909002&reply_comment_id=789415260562621

Test Plan: UT

Reviewed By: jcwchen

Differential Revision: D80020093

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166192
Approved by: https://github.com/gmagogsfm
2025-10-25 14:07:54 +00:00
798a6d2be1 [Inductor][Autotune] Gracefully restart the autotune process after ULF failure (#166073)
This PR partially fixes https://github.com/pytorch/torchtitan/issues/1791, as it will work with `TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1` setting only.

The core of the problem: In `max-autotune` mode Inductor runs multiple benchmarks to determine the best config. If one of these benchmarks fails with `cudaErrorLaunchFailure`, all other CUDA calls within the same process will fail including the rest of the benchmarks.

The solution: Restart the child process gracefully and continue benchmarking.

Unfortunately, if autotuning is done in the main process, the whole program falls into unrecoverable state. In this case, the only way of successful execution would be just preventing the ULF.

Here is some info from [CUDA documentation](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html):
>cudaErrorLaunchFailure = 719
An exception occurred on the device while executing a kernel. ... . This leaves the process in an inconsistent state and any further CUDA work will return the same error. To continue using CUDA, the process must be terminated and relaunched.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166073
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
2025-10-25 10:40:59 +00:00
b0e9c86971 [MPS] Move hypot to Metal (#166216)
Which also prevents crashes, when invoked for integer types, for example, before this change following crashes
```
python -c "import torch; print(torch.hypot(torch.randint(0, 10, (3,), device='mps'), torch.randint(0, 10, (3,), device='mps')))"
*** Terminating app due to uncaught exception 'NSInvalidArgumentException', reason: '*** -[__NSDictionaryM setObject:forKey:]: object cannot be nil (key: squareRoot_i64)'
*** First throw call stack:
(
	0   CoreFoundation                      0x0000000194d33ae0 __exceptionPreprocess + 176
	1   libobjc.A.dylib                     0x00000001947f6b90 objc_exception_throw + 88
	2   CoreFoundation                      0x0000000194c7d884 -[__NSDictionaryM setObject:forKey:] + 1288
	3   MPSCore                             0x00000001a1187d0c _ZN12MPSKernelDAG15duodenaryCoreOpEP10BaseTensorS1_S1_S1_S1_S1_S1_S1_S1_S1_S1_S1_RKNSt3__16vectorIlNS2_9allocatorIlEEEE11MPSDataTypePKc + 37044
	4   MPSCore                             0x00000001a113fab0 _ZN12MPSKernelDAGD0Ev + 4256
	5   MPSCore                             0x00000001a1139f6c _ZN12MPSKernelDAG13getDAGAndHashEPU21objcproto10MTLLibrary11objc_objectP14MPSDAGKernelOpP19NSMutableDictionaryIP8NSStringPU22objcproto11MTLFunction11objc_objectEP14NSMutableArrayIS6_ERDv4_yPb + 8	6   MPSCore                             0x00000001a113c7a4 _ZN12MPSKernelDAG13getDAGAndHashEPU21objcproto10MTLLibrary11objc_objectP14MPSDAGKernelOpP19NSMutableDictionaryIP8NSStringPU22objcproto11MTLFunction11objc_objectEP14NSMutableArrayIS6_ERDv4_yPb + 1	7   MPSCore                             0x00000001a11c03c8 _ZN10MPSLibrary19CreateUberShaderKeyEP8NSStringRK23MPSFunctionConstantListyPFPU22objcproto11MTLFunction11objc_objectPU21objcproto10MTLLibrary11objc_objectPK13MPSKernelInfoS4_RK33MPSFunctionConstr	8   MPSNDArray                          0x00000001a27b546c MPSSetResourcesOnCommandEncoder + 154176
	9   MPSNDArray                          0x00000001a27967d8 MPSSetResourcesOnCommandEncoder + 28076
	10  MPSNDArray                          0x00000001a2798ec8 MPSSetResourcesOnCommandEncoder + 38044
	11  MetalPerformanceShadersGraph        0x00000001f97689ac _ZN3GPU17IdentityOpHandler15encodeNDArrayOpEPNS_16EncodeDescriptorEP7NSArray + 436
	12  MetalPerformanceShadersGraph        0x00000001f977f93c _ZN3GPU17StitchedOpHandler8encodeOpEPNS_16EncodeDescriptorE + 924
	13  MetalPerformanceShadersGraph        0x00000001f9544898 _ZN16GPURegionRuntime5runOpIN3GPU23AbsoluteSquareOpHandlerEEEvPN4mlir9OperationEPNS1_16EncodeDescriptorE + 120
	14  MetalPerformanceShadersGraph        0x00000001f9543894 _ZN16GPURegionRuntime8encodeOpEPN4mlir9OperationEPN3GPU16EncodeDescriptorE + 4700
	15  MetalPerformanceShadersGraph        0x00000001f954251c _ZN16GPURegionRuntime29encodeOpWithCommitAndContinueEPN4mlir9OperationEPN3GPU16EncodeDescriptorE + 92
	16  MetalPerformanceShadersGraph        0x00000001f954189c _ZN16GPURegionRuntime11evaluateOpsEPN3GPU16EncodeDescriptorEP7NSArrayIP18MPSGraphTensorDataES7_ + 3572
	17  MetalPerformanceShadersGraph        0x00000001f953f7b4 _ZN10MPSRuntime11evaluateOpsEN4mlir4func6FuncOpEP21RuntimeSpecializationP7NSArrayIP18MPSGraphTensorDataES9_P37MPSGraphExecutableExecutionDescriptorP16MPSCommandBufferbbbPb + 824
	18  MetalPerformanceShadersGraph        0x00000001f988dd38 -[MPSGraphExecutable runInternalWithDevice:commandBuffer:feeds:results:executableExecutionDescriptor:mpsGraphOwnedCommandBuffer:] + 3848
	19  MetalPerformanceShadersGraph        0x00000001f988ca04 -[MPSGraphExecutable runInternalWithDevice:commandBuffer:feedsDictionary:resultsDictionary:executableExecutionDescriptor:mpsGraphOwnedCommandBuffer:] + 608
	20  MetalPerformanceShadersGraph        0x00000001f9728aa0 -[MPSGraph runInternalWithMPSCommandBuffer:feeds:targetTensors:targetOperations:resultsDictionary:executionDescriptor:mpsGraphOwnedCommandBuffer:] + 320
	21  MetalPerformanceShadersGraph        0x00000001f9727b58 -[MPSGraph encodeToCommandBuffer:feeds:targetOperations:resultsDictionary:executionDescriptor:] + 188
	22  libtorch_cpu.dylib                  0x00000001556c9478 ___ZN2at3mps9MPSStream15executeMPSGraphEP8MPSGraphP12NSDictionaryS5_NS0_8SyncTypeE_block_invoke + 128
	23  libdispatch.dylib                   0x0000000194a3985c _dispatch_client_callout + 16
	24  libdispatch.dylib                   0x0000000194a2f7a8 _dispatch_lane_barrier_sync_invoke_and_complete + 56
	25  libtorch_cpu.dylib                  0x00000001556c93e0 _ZN2at3mps9MPSStream15executeMPSGraphEP8MPSGraphP12NSDictionaryS5_NS0_8SyncTypeE + 160
	26  libtorch_cpu.dylib                  0x00000001556fd0f4 _ZN2at6native3mpsL14binaryOpTensorERKNS_6TensorES4_S4_NSt3__112basic_stringIcNS5_11char_traitsIcEENS5_9allocatorIcEEEEU13block_pointerFP14MPSGraphTensorPNS1_19BinaryOpCachedGraphESD_SD_E + 3040
	27  libtorch_cpu.dylib                  0x00000001556ff680 _ZN2at6native24structured_hypot_out_mps4implERKNS_6TensorES4_S4_ + 84
	28  libtorch_cpu.dylib                  0x00000001522682e4 _ZN2at12_GLOBAL__N_117wrapper_MPS_hypotERKNS_6TensorES3_ + 216
	29  libtorch_cpu.dylib                  0x0000000153a1378c _ZN3c104impl28wrap_kernel_functor_unboxed_INS0_6detail24WrapFunctionIntoFunctor_INS_26CompileTimeFunctionPointerIFN2at6TensorENS_14DispatchKeySetERKS6_S9_EXadL_ZN5torch8autograd12VariableType12_G	30  libtorch_cpu.dylib                  0x0000000151241714 _ZN2at4_ops5hypot4callERKNS_6TensorES4_ + 304
	31  libtorch_python.dylib               0x0000000105d9a848 _ZN5torch8autogradL17THPVariable_hypotEP7_objectS2_S2_ + 752
	32  Python                              0x00000001036afa7c cfunction_call + 72
	33  Python                              0x000000010365db08 _PyObject_MakeTpCall + 124
	34  Python                              0x0000000103750f40 _PyEval_EvalFrameDefault + 23304
	35  Python                              0x000000010374b1c8 PyEval_EvalCode + 184
	36  Python                              0x00000001037ab8bc run_eval_code_obj + 88
	37  Python                              0x00000001037a9994 run_mod + 132
	38  Python                              0x00000001037a8fdc PyRun_StringFlags + 124
	39  Python                              0x00000001037a8f08 PyRun_SimpleStringFlags + 64
	40  Python                              0x00000001037cd464 Py_RunMain + 716
	41  Python                              0x00000001037cd950 pymain_main + 304
	42  Python                              0x00000001037cd9f0 Py_BytesMain + 40
	43  dyld                                0x0000000194836b98 start + 6076
)
libc++abi: terminating due to uncaught exception of type NSException
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166216
Approved by: https://github.com/Skylion007
ghstack dependencies: #166210
2025-10-25 08:51:38 +00:00
661a56002f [AI Codemod][DevmateFBSourceTestFailureBot] Fix for T241916639 ("Your diff, D84932408, broke one test") (#166168)
Reviewed By: XilunWu

Differential Revision: D84983164

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166168
Approved by: https://github.com/H-Huang, https://github.com/fduwjj
2025-10-25 06:46:23 +00:00
c9bc00f016 Split grouped_mm methods into their own file (#166140)
Summary:

`Blas.cpp` was getting a little full and hard to work with, split out
the `*_grouped_mm` methods into their own file

Test Plan:

```
pytest -svv -k group test/test_matmul_cuda.py
pytest -svv -k group test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166140
Approved by: https://github.com/drisspg
ghstack dependencies: #166139
2025-10-25 05:40:31 +00:00
ec51b139e1 Factor out shared scaled mm routines (#166139)
Summary:

In preparation for splitting out scaled grouped mm functions, factor out
scaled-specific routines into their own file(s)

Test Plan:

```
pytest -svv test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166139
Approved by: https://github.com/drisspg
2025-10-25 05:40:31 +00:00
eb83c3ca23 Clean up unused Pyrefly suppressions (#166178)
Cleaning up ignores that are no longer needed in the repo and adding select suppressions so the main branch is clean.

test plan:
`lintrunner -a`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166178
Approved by: https://github.com/oulgen
2025-10-25 05:32:21 +00:00
7924e3aacf Remove likely unnecessary _EXPAND trick for non-windows in HIDDEN_NAMESPACE_BEGIN (#166203)
I've learned that the EXPAND trick is needed mostly for an MSVC quirk to properly expand arguments. I tested on Linux locally and suspect that we don't need the _EXPAND for non-windows.  This PR is BE to minimalize what we need and remove what we don't, but I'm also okay not landing this if @malfet tells me that this quirk goes beyond MSVC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166203
Approved by: https://github.com/malfet
ghstack dependencies: #166076, #166077, #166078, #166079
2025-10-25 04:44:07 +00:00
78bcfcf870 [fx] Optimize torch.fx.Node.replace_all_uses_with (#165889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165889
Approved by: https://github.com/aorenste
2025-10-25 03:44:41 +00:00
1e2e7cb18b Add doc for Symmetric Memory (#166148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166148
Approved by: https://github.com/fduwjj
2025-10-25 03:41:15 +00:00
003601a70d Set prefer_deferred_runtime_asserts_over_guards to True (#165820)
Set prefer_deferred_runtime_asserts_over_guards to True and allow a flag to control the behavior, just in case.

This option has enable the gemma3 model export with transformers==4.57. I am not sure how best to test it though.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165820
Approved by: https://github.com/titaiwangms
2025-10-25 03:38:19 +00:00
1d58d5fe25 [hops] fix unbacked runtime asserts for cond higher order op (#165893)
At a high level after this fix we get the following nice tlparse https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/bobren/54a57665-7dcc-41e0-8ca7-df01393cd4aa/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

As seen in this doc, previously we were simply dropping assert post
dynamo: https://docs.google.com/document/d/1nRQwvw_gWL0_9T3VKb5Ly3_tNI1fgqG9WtryeD6qaZI/edit?tab=t.0

The fixes are a couple things:

1) Actually run the runtime assertion fx graph pass on subgraphs
2) Reset fake mode unbacked memo across speculate subgraph invocations
   since the memos actually break the runtime assertion insertions since
   calls like nonzero end up not allocating new unbacked symints and
   hence not populating pending_unbacked which then results in incorrect
   unbacked_bindings on fx_nodes in subgraphs.

This is a first step in hardening runtime asserts across all phases of
the compiler (eager, aot_eager, inductor, etc.). I will continue kicking
tires and fixing bugs until we get runtime assert generations in a good
place. One obvious next step is the added test case in this PR fails
when compiled with inductor with the following error (NB: it fails before this PR as well):

```
  File "/data/users/bobren/a/pytorch/torch/_inductor/ir.py", line 659, in get_dtype
    return self.dtype
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
  target: cond
  args[0]: Eq(Mod(s77, 4), 0)
  args[1]: Subgraph(name='true_graph_0', graph_module=<lambda>(), graph=<torch._inductor.graph.SubgraphLowering object at 0x7fbcbb11e110>)
  args[2]: Subgraph(name='false_graph_0', graph_module=<lambda>(), graph=<torch._inductor.graph.SubgraphLowering object at 0x7fbcbb21cf70>)
  args[3]: (s77, TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('cuda:0', torch.float32, size=[s77, s77], stride=[s77, 1]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float32, inner_fn=<function make_pointwise.<locals>.inner.<locals>.inner_fn at 0x7fbcbb2f37f0>, ranges=[s77, s77]))
  )))
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165893
Approved by: https://github.com/zou3519
2025-10-25 03:25:36 +00:00
de7fdfe41a Export flex attention with kwargs and DTensor (#166045)
Fixes #165948

Adding registration of the MaskBlock makes flex attention with kwargs exportable.

Also modified unittests to accept kwargs

```
python test/distributed/tensor/test_dtensor_export.py -k test_flex_attention_dtensor_export

python test/inductor/test_flex_attention.py -k test_pytree_
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166045
Approved by: https://github.com/drisspg
2025-10-25 03:17:22 +00:00
b31bad1b8f [Pytorch] Enable autovec on aarch64 for type conversion (#166049)
Summary:
Implementing autovec template for type conversions on aarch64-NEON

Generated code can be seen here: https://godbolt.org/z/1K6T1d9TE

We've seen significant performance improvements for converting to and from bytes, compiling using clang with -march=armv9-a+sve2:

Before
float->uint8->float ===> 683.212us
float->int8->float ===> 687.846us
int32->uint8->int32 ===> 497.121us
int32->int8->int32 ===> 481.889us

After:
float->uint8->float ===> 198.204us  ----> 245% higher throughput
float->int8->float ===> 200.241us ----> 244% higher throughput
int32->uint8->int32 ===> 197.970us ----> 151% higher throughput
int32->int8->int32 ===> 198.206us ----> 143% higher throughput

Test Plan:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Differential Revision: D85213420

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166049
Approved by: https://github.com/ezyang, https://github.com/mcfi, https://github.com/aditew01
2025-10-25 02:55:50 +00:00
2efcf3ca98 Reverts #163712 and forces allgather/scatter inputs/outputs to be contiguous (#166181)
Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166181
Approved by: https://github.com/kwen2501
2025-10-25 02:43:10 +00:00
761f946043 [ROCm] new implementation of upsample_bilinear2d_backward (#164572)
Changed the implementation from an output-based approach to an input-based one to remove `atomicAdd` operations, and it appears to deliver at least a 20× speedup.

The changes are from Yu-Yun <YuYun.Chang@amd.com>.

# Summary: Refactor of the implementation of the `upsample_bilinear2d_backward` opertion on MI300X/MI325X
- The original "scatter-add" approach
  - Each thread, representing an output pixel, scattered gradient contributions to four input pixels, using costly atomic operations on MI300X/MI325X GPUs.
- The new "gather-sum" approach
  - Each thread is responsible for a single input pixel and gathers all relevant gradient contributions from a small, calculated region of the output tensor (done by the `compute_output_range` device function).
# Breakdown of the code changes
- Inversion of the parallelization strategy of the kernel function `upsample_bilinear2d_backward_out_frame`
  - Originally, the main kernel loop was parallelized over the number of elements in the output gradient tensor (`const size_t o_numel = nc * width2 * height2;`).
    - Each thread processed one output pixel.
  - The new loop is parallelized over the number of elements in the input gradient tensor (`const size_t i_numel = nc * height1 * width1;`).
    - Each thread is responsible for calculating the final gradient for a single input pixel.
  - The kernel launch changes accordingly in the function `upsample_bilinear2d_backward_out_cuda_template`.
- Added a device function for calculating the range of output pixels that could have possibly used that the input pixel (`input_pos`) during the forward pass interpolation
  - This is essentially the mathematical inverse of the forward pass.
  - This function tries to prune a thread's search space so that it only needs to inspect a small, local window of the output tensor.
- Gradient calculation approach switching from "scatter-add" to "gather-sum"
  - Scatter-add
    - For each output pixel, the thread calculated 4 gradient contributions and use `fastAtomicAdd` 4 times to add these values to 4 different (and potentially highly contended) memory locations in the input gradient tensor.
  - Gather-sum
    - A thread responsible for one input pixel calls `compute_output_range` to determine the small rectangular region of output pixels that influence the input's final gradient value.
    - The thread iterates through this region, and for each output pixel in the regionre, it re-calculates the interpolation weights to determine the exact contribution to its specific input pixel.
    - All these contributions are accumulated into a private, per-thread register variable (`accscalar_t grad_sum = 0;`).
      - W/o any gloabl memory access, this accumulation is extremely fast.
    - When the loops are done, the thread performs a single, direct write (non-atomic) of the final summed gradient to its designated location in global memory (`idata[index] = static_cast<scalar_t>(grad_sum);`).
# Why performance gets boosted
- Analysis of the root cause of performance drop
  - Ref. (internal only) - https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1140493327/PyTorch__upsample_bilinear2d_backward
- First and foremost, elimination of the contention of atomic operations
  - Many parallel threads called `atomicAdd` frequently attempting to update the exact same memory location in the input gradient tensor at the same time.
    - The GPU's memory controler has to serialize these operations, effectively nullifying the benefit of parallel capability at those contention points.
  - MI300X/MI325X chiplet-based CDNA 3 architeture amplified the issue.
    - When contending threads reside on different XCDs, resolving the atomic operation requires high-latency coherence traffic across the Infinity Fabric interconnect.
  - The implementation change eliminates hardware-level serialization and cross-chiplet coherence traffic caused by many `atomicAdd`.
- Improved memory access pattern and locality
  - Write coalescing
    - The regular sum writes `idata[index] = static_cast<scalar_t>(grad_sum);` can be perfectly coalesced by GPUs.
  - Read locality
    - Even though there are many (potentially repeated) reads from the output tensor (`static_cast<accscalar_t>(odata[output_idx])`), these are highly cache-friendly, meaning the data for one thread is likely to be in the L1 or L2 cache already due to an access from a neighboring thread.
- Trade-off: computation for memory synchronization
  - The recalculation of interpolation weights fits well on high-computational-throughput modern GPUs like MI300X/MI325X.
  - Removal of atomic operations avoids expensive memory synchronization.

---

Optimizations of `grid_sampler_2d_backward` will be addressed in a separate PR.
Doc for reference: (internal only) https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1162750701/PyTorch__grid_sampler_2d_backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164572
Approved by: https://github.com/jeffdaily

Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com>
2025-10-25 02:39:24 +00:00
8aa465f18e [MPS] Migrate angle to Metal ops (#166210)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166210
Approved by: https://github.com/Skylion007
2025-10-25 01:52:33 +00:00
0a5d68d92d [dynamo] Remove unnecessary NAME_MATCH guard (#166112)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166112
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166155
2025-10-25 01:27:42 +00:00
42bd210fff [dynamo] Avoid ID_MATCH on methods - use CLOSURE_MATCH on functions (#166155)
id on methods can change from invocation to invocation. Here we guard on
__code__ objects which does not change

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166155
Approved by: https://github.com/jansel
2025-10-25 01:27:42 +00:00
1d13c314b3 [OpenReg] Remove the Unnecessary Fallback Implementation for AutogradPrivate1 (#165316)
As the title stated.

The fallback for AutogradPrivateUse1 is builtin in PyTorch, so it is no need to register general implementation for out of tree backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165316
Approved by: https://github.com/ezyang, https://github.com/albanD
ghstack dependencies: #165315
2025-10-25 01:27:27 +00:00
0c9763a5a0 [Autograd] Add Default Autograd Fallback for PrivateUse1 in PyTorch (#165315)
Please refer to this [link](https://github.com/pytorch/pytorch/issues/163979) for more background.

- Allow register fallback for AutogradPrivateUse1 multiple.
- Add Autograd fallback implemetation for AutogradPrivateUse1

PyTorch can privide a common implementation for AutogradPrivateUse1, and the user can override it based on the need of specififc accelerator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165315
Approved by: https://github.com/albanD
2025-10-25 01:27:27 +00:00
79a4a9c02e Fix race condition and make CUDA kthvalue deterministic (#165762)
The gatherKthValue kernel had a race condition where multiple threads could write to the same output location without synchronization when duplicate k-th values exist, resulting in non-deterministic output.

Changes:
- aten/src/ATen/native/cuda/Sorting.cu: Use atomicMin with shared memory to deterministically find minimum index. Add early termination and remove redundant inRange checks. (We have to cast the index to `int32_t`, but this is already assumed to fit earlier in the kernel.)
- aten/src/ATen/native/cuda/Sorting.cpp: Remove non-deterministic alert since kthvalue is now deterministic on CUDA.
- torch/__init__.py: Remove kthvalue from non-deterministic operations list and remove kthvalue example from use_deterministic_algorithms() docstring.
- test/test_torch.py: Remove test_nondeterministic_alert_kthvalue since kthvalue no longer raises alerts on CUDA.

Benefits:
- Deterministic: always returns minimum index when duplicates exist
- Potential performance improvement on large arrays with repetitions

Test Results:
- All existing PyTorch tests pass (test_kthvalue)
- Custom determinism tests confirm consistent results
- Custom CUDA vs CPU correctness validated across 50+ scenarios
- Custom performance benchmarks show improvements with no visible regressions

Addresses #165227

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165762
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-25 00:45:57 +00:00
9d0b77f4cd [10/N] Apply ruff UP035 rule (#165709)
This is a follow-up of #165515. ruff `UP035` rules are applied to  dynamo code to use Py 3.10+ typing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165709
Approved by: https://github.com/ezyang
2025-10-25 00:20:13 +00:00
d486eee234 Hide APIs in torch::headeronly (#166079)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166079
Approved by: https://github.com/malfet, https://github.com/cyyever
ghstack dependencies: #166076, #166077, #166078
2025-10-25 00:18:26 +00:00
cddd5f74ab Hide stable Library structs instead of using anon namespace (#166078)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166078
Approved by: https://github.com/malfet
ghstack dependencies: #166076, #166077
2025-10-25 00:18:26 +00:00
dfdb68e51f Hide all APIs in torch::stable (#166077)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166077
Approved by: https://github.com/malfet
ghstack dependencies: #166076
2025-10-25 00:18:26 +00:00
98c818320a Add HIDDEN_NAMESPACE_BEGIN and END macros for hiding header APIs (#166076)
Spurred by the conversation started in https://github.com/pytorch/pytorch/issues/163343.

Context:
* Header implementations may be inlined _but_ are not necessarily inlined, even when using the `inline` keyword.
* When someone wants to use multiple extensions in the same runtime, e,g., with FA3 and AO, then 2 `.so`s are loaded that may have been built with different libtorch versions. Thus, if an API is not inlined and are differently implemented, one implementation will be arbitrarily picked up and used across the runtime, depending on link order. This is bad!
* Consequently, we need to be very good at guaranteeing that we don't modify header implementations within a namespace. This is easy to mess up by accident, which would be a dire mistake.

Solution:
In essence, we want APIs in torch::headeronly and torch::stable to be visible in each individual extension only, and nowhere else. We want to hide these symbols! Thankfully, pybind already solved this problem (thanks @malfet for bringing that to my attention). This PR is heavily inspired by the code in pybind here: e6984c805e/include/pybind11/detail/pybind11_namespace_macros.h (L73-L82).

In this PR, we introduce the macros for defining hidden namespaces in PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166076
Approved by: https://github.com/malfet
2025-10-25 00:18:26 +00:00
cc20b7ad72 [FlexFlash] update names (#166193)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166193
Approved by: https://github.com/BoyuanFeng
2025-10-25 00:07:11 +00:00
bc11a42b3f [inductor][ez] fix score fusion memory typo (#166029)
Fix https://github.com/pytorch/pytorch/issues/165724 .
The typo does not affect the compilation result. It just affect compilation time a little bit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166029
Approved by: https://github.com/eellison
2025-10-24 23:48:05 +00:00
4fc06f2e0a Use std::min for #165927 (#166199)
Summary: Like D85463674 (pr https://github.com/pytorch/pytorch/pull/166195) but for D85357351 (https://github.com/pytorch/pytorch/pull/165927)

Differential Revision: D85464917

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166199
Approved by: https://github.com/Camyll, https://github.com/malfet, https://github.com/Skylion007
2025-10-24 23:19:00 +00:00
82473c3d59 [torch.export] Add original module type to UnflattenedModule class (#166145)
Summary: Currently all sub modules of UnflattenedModule have orginal type name. This diff will orginal type for UnflattenedModule.

Test Plan:
```
buck test mode/opt caffe2/test:test_export
```
https://www.internalfb.com/intern/testinfra/testrun/17732923654320197

Differential Revision: D85373454

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166145
Approved by: https://github.com/angelayi
2025-10-24 22:47:29 +00:00
b6a4236e5d [label_to_label] minor updates (#166172)
vllm-compile implies "module: vllm" and "oncall: pt2".
The volume of issues in Flex -> HigherOrderOperators is too noisy,
plus we have a different set of folks looking at each, so I'm going to
make that not automatic anymore. We can still manually label flex issues
as higher order operator issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166172
Approved by: https://github.com/angelayi
2025-10-24 22:47:23 +00:00
b04173be9b [ONNX] Add a test to backed_size_oblivious patch in onnx (#166196)
Follow-up https://github.com/pytorch/pytorch/pull/166151

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166196
Approved by: https://github.com/justinchuby
2025-10-24 22:47:10 +00:00
32ac38f85d [lint] workflow consistency linter to look at all files instead of just changed files (#165171)
As in title

If you change only one workflow file, lintrunner (default arg, also the one in CI since it only inputs changed files) won't look at other files in the repo, but the sync-tag might come from those other files

This makes it so that it looks at all workflow files so it will catch those failures

Also change output line so it prints which file + which job it is different from

Pros:
catches errors

Cons:
unusual behavior (getting around what lintrunner says the linter should run on)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165171
Approved by: https://github.com/malfet, https://github.com/izaitsevfb, https://github.com/atalman
2025-10-24 21:43:18 +00:00
c9b49e506e [MPS] Add linalg.householder_product for MPS (#166090)
Fixes #166089
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166090
Approved by: https://github.com/malfet
2025-10-24 21:13:56 +00:00
6038e476e8 [Dynamo][Logging]Fix regression on stack adding to latest bytecode by… (#165946)
… adding verbose check (#165926)

[ghstack-poisoned]

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165946
Approved by: https://github.com/williamwen42
2025-10-24 20:36:50 +00:00
2c851c16e5 [FX][ez] fix the split_module tutorial code (#166154)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166154
Approved by: https://github.com/BoyuanFeng
2025-10-24 20:16:04 +00:00
31584f2d91 Add a Claude skill for writing docstrings. (#166175)
Generated with prompt:

> torch/_tensor_docs.py and torch/nn/functional.py contain the "gold standard" for docstrings in the PyTorch project. Write a skill describing how to write a docstring for a function/method in the PyTorch project. Note that add_docstring is specifically for C binded functions; a native Python function can just be a direct docstring. Sphinx is used to generate docs.

Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166175
Approved by: https://github.com/Skylion007
2025-10-24 20:05:44 +00:00
0442125362 [Inductor] Restore original dtype for rank-0 CPU tensors (#166118)
# Problem
Inductor implicitly upcasts certain rank-0 kernel arguments from float16 to float32. Currently, this happens only on the `"cpu"` device, which appears to be related to float16 support in CPU Triton. However, it can also affect the behavior of GPU kernels, when a model contains tensors from multiple devices. Upcasting may be undesirable on some platforms, so users can typically disable it with the `config.triton.codegen_upcast_to_fp32` flag. However, this flag was not respected by the rank-0 kernel argument codepath.

Through an improbable series of events, float32 upcasting caused an internal model to fail compilation on MTIA. (Internal reviewers see T242444110.)

# Fix
If `config.triton.codegen_upcast_to_fp32` evaluates to `False`, cast the kernel argument to the original dtype.

# Test plan
Added a new CI test checking for the downcast iff the config flag is false. The test mixes GPU and CPU tensors to generate a GPU kernel with the implicit float32 upcast and explicit float16 downcast.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166118
Approved by: https://github.com/jfix71, https://github.com/jansel, https://github.com/kundaMwiza
2025-10-24 19:59:25 +00:00
fdcf402d82 vllm test build (#166146)
FIx the vllm test build it's broken due to the flashinfer dependency
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166146
Approved by: https://github.com/huydhn
2025-10-24 19:18:10 +00:00
13cda9b89e Allow BlockDescriptorOptions classes to be overridden In TritonKernel (#165899)
By allowing the options classes (`BlockPtrOptions`/`TensorDescriptorOptions`) to be overridden in `TritonKernel`, subclasses with custom behaviour can be used in place of them, which provides greater flexibility.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165899
Approved by: https://github.com/jansel
2025-10-24 18:59:59 +00:00
fa6d911dda [MPS] Sparse mul enable tests and fix on MPS (#166164)
Apparently mul tests in test_sparse were disabled. The dense representation i.e. when nnz is not a scalar was broken on MPS. This PR fixes it and enables the tests in test_sparse.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166164
Approved by: https://github.com/malfet
2025-10-24 18:30:30 +00:00
0db6bcc015 Fix accuracy for layernorm/rmsnorm benchmarking (#166005)
Example command:
    python benchmarks/dynamo/genai_layers/benchmark.py --exit-on-accuracy-failure --tolerance=1e-2 rmsnorm_backward

Fix the accuracy problem for layernorm/rmsnorm fwd/bwd.
Also fix some quack calls (maybe due to quack API change)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166005
Approved by: https://github.com/BoyuanFeng
2025-10-24 18:14:51 +00:00
eqy
60ac039998 [CUDA][Grouped Gemm] remove xFail on Group GEMM tests after fallback was added (#165378)
https://github.com/pytorch/pytorch/pull/162059 means we get unexpected successes now on e.g., SM 12.0

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165378
Approved by: https://github.com/Skylion007
2025-10-24 17:42:40 +00:00
380d440d1c Revert "inductor: avoid unrolling argmin/argmax reductions to preserve index … (#164040)"
This reverts commit 9038a30cee56e0d577a666fffa32e990732572d4.

Reverted https://github.com/pytorch/pytorch/pull/164040 on behalf of https://github.com/karthickai due to Kindly add the test case mentioned in the issue ([comment](https://github.com/pytorch/pytorch/pull/164040#issuecomment-3444137989))
2025-10-24 17:14:45 +00:00
9038a30cee inductor: avoid unrolling argmin/argmax reductions to preserve index … (#164040)
…semantics on views; add regression test for transposed mutation (fixes #163929)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164040
Approved by: https://github.com/ngimel, https://github.com/jansel
2025-10-24 16:37:43 +00:00
690c8c13b9 Revert "Export should use aot_export_joint_with_descriptors (#165931)"
This reverts commit 882b834082719afd8ee41769c2cb224bc9031632.

Reverted https://github.com/pytorch/pytorch/pull/165931 on behalf of https://github.com/clee2000 due to breaking internal tests D85084301 for test_auto_functionalize?  I checked that they did run on OSS CI so I'm not entirely sure whats going on, I assume its the IS_FBCODE stuff ([comment](https://github.com/pytorch/pytorch/pull/165931#issuecomment-3443887361))
2025-10-24 16:02:20 +00:00
28ee6b62ed Revert "[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case (#163358)"
This reverts commit 5a4997dcae47acf69c929ac5b081321143bfbf11.

Reverted https://github.com/pytorch/pytorch/pull/163358 on behalf of https://github.com/clee2000 due to probably need to revert this one  too, its stacked with https://github.com/pytorch/pytorch/pull/166003#issuecomment-3443668389 ([comment](https://github.com/pytorch/pytorch/pull/163358#issuecomment-3443874910))
2025-10-24 15:58:54 +00:00
81577bdb3f Revert "[DeviceMesh] Use _flatten_rank_map to replace _flatten_mesh_list so that we don't need to compare root mesh (#166003)"
This reverts commit 8625ffbd45884464f736cfc61300c14f47633641.

Reverted https://github.com/pytorch/pytorch/pull/166003 on behalf of https://github.com/clee2000 due to failing internal tests D85405179 I believe there are uses of _flatten_mesh_list internally that need to be updated ([comment](https://github.com/pytorch/pytorch/pull/166003#issuecomment-3443668389))
2025-10-24 15:14:23 +00:00
e67e3d95f3 Simplify the CUPTI CMake check for kineto (#161370)
Simplify the CUPTI check because kineto has used `CUDA::cupti`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161370
Approved by: https://github.com/Skylion007
2025-10-24 08:13:17 +00:00
27af8480ea Refactor api and configs of overlapping (#166130)
- pass important configs values directly into the class
- migrate those configs from `test_configs` to another class
- add an (off by default) config to enable inside inductor, instead of requiring a custom post pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166130
Approved by: https://github.com/bdhirsh
2025-10-24 07:03:54 +00:00
6494cdc40c [DebugMode] add nn.Module tracking (#165498)
Uses ModTracker to record nn.Module entries, much like CommDebugMode.

Can be switched on with `DebugMode(record_nn_module=True)`:
```
    [nn.Mod] Bar
      [nn.Mod] Bar.abc
        [nn.Mod] Bar.abc.l1
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
        [nn.Mod] Bar.abc.l2
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
      [nn.Mod] Bar.xyz
        aten::t(t: f32[4, 4])
        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])"""
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165498
Approved by: https://github.com/SherlockNoMad
2025-10-24 05:08:33 +00:00
ac7074efa2 [CUDA][cuBLAS] Fix a compilation issue in #163955 when CUDA_VERSION < 12010 (#166137)
Summary:
This PR fixes a compilation issue when `CUDA_VERSION < 12010`. Even if we might drop old CUDA support, let's correct the code itself.

## Issue
When `CUDA_VERSION` is `12010`, the following does not compile.
```
      mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
      mat2_sizes[0] > 1 && mat2_sizes[1] > 1
      #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
      // Here not "&&"
      mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
```
This patch adds "&&"

Test Plan: CI

Differential Revision: D85356831

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166137
Approved by: https://github.com/ngimel, https://github.com/cyyever
2025-10-24 04:06:03 +00:00
263901cec4 [pytorch/kineto] Update Kineto Submodule (#166150)
Summary: Update to include some race condition fixes.

Test Plan: n/a

Differential Revision: D85390799

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166150
Approved by: https://github.com/sraikund16, https://github.com/cyyever
2025-10-24 04:03:13 +00:00
c12293dcbe [ONNX] Cover all FX passes into backed size oblivious (#166151)
Found a bug that after `run_decomposition()`, the shape could be fixed to 1. It's caused by the fact that all FX graph (related to shape inference) surgery should happen inside backed size oblivious patch.

```python
import torch
from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm

# Previous to this PR, this will generate a fixed batch size
op = torch.onnx.export(
    Phi3RMSNorm(256).eval(),
    args=(),
    kwargs={"hidden_states": torch.rand((1, 32, 256))},
    dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
)

# It is dynamic when it's only in torch.export
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    ep = torch.onnx.export(
    Phi3RMSNorm(256).eval(),
    args=(),
    kwargs={"hidden_states": torch.rand((1, 32, 256))},
    dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
)
# But when run_decomposition is called outside of the patch, it is static.
# ep = ep.run_decompositions()
print(ep)

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166151
Approved by: https://github.com/justinchuby
2025-10-24 03:25:16 +00:00
5a4997dcae [DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case (#163358)
Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users.

One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163358
Approved by: https://github.com/fegin
ghstack dependencies: #166003
2025-10-23 23:31:17 +00:00
47f638eae7 [ROCm] deserialize loads in planer sum portion of stats() of norm (#166021)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166021
Approved by: https://github.com/jeffdaily
2025-10-23 22:47:42 +00:00
882b834082 Export should use aot_export_joint_with_descriptors (#165931)
This diff moves export run_decompositions to use aot_export_joint_with_descriptors instead of aot_export_module. Doing so, i ran into 2 main bugs:
1) aot_export_joint_with_descriptors don't correctly pass in record_nn_module_stack flag that is needed to populate nn_module_stack by switching the internal tracer.
2) When creating symint with negative inputs, we need to pass in positive=False. This didn't matter before because aot_autograd directly returns integer inputs instead of creating symint.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165931
Approved by: https://github.com/zhxchen17
2025-10-23 22:42:11 +00:00
b146ea411e Save GitHub env variables on ROCm (#165821)
As `.github/actions/setup-rocm/action.yml` is now used on `linux_job_v2` to setup ROCm, we need to have this step here to save the list of GitHub env variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165821
Approved by: https://github.com/atalman
2025-10-23 22:13:37 +00:00
8625ffbd45 [DeviceMesh] Use _flatten_rank_map to replace _flatten_mesh_list so that we don't need to compare root mesh (#166003)
Since we are already share a flattened tensor `_rank_map` across all meshes from a same root mesh, we can just use a flattened list of it to replace the comparison of root_mesh and flattened_mesh_list (because with same _rank_map and layout, the mesh tensor is guaranteed to be the same). This way we can also give back the CPU overhead added in https://github.com/pytorch/pytorch/pull/164510 and further simply the code.

We do have a more ambitious universe-based change here: https://github.com/pytorch/pytorch/pull/165680 but it needs more discussions and would lead to BC breaking. We might eventually merge that PR but probably not now and this is a change which is not BC breaking and will help concatenate and 2D integration with concatenate.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166003
Approved by: https://github.com/Skylion007, https://github.com/fegin
2025-10-23 20:49:59 +00:00
0977cc4474 [lint] Extend workflowsync linter to more files (#166082)
And fix the lint issues found
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166082
Approved by: https://github.com/izaitsevfb, https://github.com/atalman
2025-10-23 20:29:29 +00:00
d9a55faccc [Pytorch] Add NEON Vectorized<double> translation layers (#166092)
Summary:
Adding NEON specializations of Vectorized<double>

Correcness has been checked using test_ops.py and running torch test

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

Added torch.float64 as data type to test within binary_test.py

Reviewed By: mcfi

Differential Revision: D84924406

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166092
Approved by: https://github.com/malfet
2025-10-23 20:20:48 +00:00
75b8295868 Revert "Warn if AccumulateGrad stream does not match producer node stream (#165065)"
This reverts commit 12f742941d6aecb72c18d8e602f90ac9b4f00af0.

Reverted https://github.com/pytorch/pytorch/pull/165065 on behalf of https://github.com/clee2000 due to broke internal builds D85273204 usages of TORCH_API void add need to be updated? ([comment](https://github.com/pytorch/pytorch/pull/165065#issuecomment-3438061854))
2025-10-23 17:02:49 +00:00
defb6a80d8 Enable torch.Generator to support pytorch/xla generator implementation (#161369)
Currently, the implementation of `torch.Generator` only support "cpu" and "cuda" device type.  https://github.com/pytorch/pytorch/blob/main/torch/csrc/Generator.cpp#L55-L61

This change enables `torch.Generator` to support more device type by allowing any device backend to register their own generator factory through a Generator Registry. This is similar to what "DeviceGuardImpl registry" does today.

# Key Changes:

## New registry API:

* Added GeneratorRegistry.h and GeneratorRegistry.cpp in c10/core/impl.
* API supports registerGenerator(DeviceType, GeneratorFactory), unregisterGenerator(DeviceType), and getGeneratorFactory(DeviceType).
* Uses c10::DeviceType as the key and stores a factory function returning c10::intrusive_ptr<c10::GeneratorImpl>.

## Python/C++ integration:

* The registry is consulted in the torch.Generator constructor path for non-CPU/CUDA devices.
* If a factory is registered for the requested device, it constructs the appropriate generator; otherwise, raises an error.

## Backend extensibility:

* Out-of-tree backends (e.g., torch_xla, torch-directml, torch_npu) can now register their custom generator implementation at module load via a static registrar object.
Example usage:
```
C++
namespace {
  struct Registrar {
    Registrar() {
      at::detail::registerGenerator(c10::DeviceType::XLA, &CreateXlaGenerator);
    }
  } registrar_instance;
}
```

This allows torch.Generator(device='xla') to return an XlaGeneratorImpl when the torch_xla extension is imported.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161369
Approved by: https://github.com/FFFrog, https://github.com/qihqi, https://github.com/albanD
2025-10-23 16:49:28 +00:00
f8fccb1e48 [Code Clean] Clean asserts in torch/optim. (#165629)
Replaces 50 assert statements across 15 files in torch.optim with explicit  if-checks raising AssertionError to prevent assertions from being disabled with Python -O flag.

fix partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165629
Approved by: https://github.com/albanD
2025-10-23 15:56:29 +00:00
5aac4cfce4 Use is rather than == to work around slow enum comparion in _ops.py (#165936)
This shows up (under _are_we_tracing) in DTensor dispatch. I have some work in flight to speed up enum comparison in pybind11, but `is` is just much faster and easy to use.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165936
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2025-10-23 15:01:55 +00:00
baf91bbbfc Revert "[inductor][choices] lookup table choices 1/3 (#164978)"
This reverts commit ab9e466928e7a37844c4f2a8bf90c76d16ac3c34.

Reverted https://github.com/pytorch/pytorch/pull/164978 on behalf of https://github.com/malfet due to Looks like it broke slow tests, see cbcb4f7768/1 ([comment](https://github.com/pytorch/pytorch/pull/164978#issuecomment-3437424559))
2025-10-23 14:47:07 +00:00
cbcb4f7768 [pytorch][torchelastic] Duplicate stdout and stderr and apply custom filter in torchrun (#160712)
Summary:
Part of an effort to extract some important error logs (e.g. [#157996](https://github.com/pytorch/pytorch/pull/157996)) that was `tee`'ed to `stdout` and `stderr`.

The general idea is to:

- Duplicate the `tee`s on `stdout` and `stderr` to a separate file, `filtered_stdout.log` and `filtered_stderr.log`, respectively.
- In these files, as its name suggests, only log lines matching a customizable filter.
- Later on in another PR, append the contents of these files to the reply file.

Outline of changes in this PR:

- Enhance `TailLog` to be able to 1) stream to a file, and 2) only write when the line matches the passed filter.
- Add `filtered_stdout` and `filtered_stderr` to `LogsDest` and have `LogsSpecs` `reify` them.
- In `start_processes()` and `PContext`, add params `duplicate_stdout_filters` and `duplicate_stderr_filters` to filter and write the duplicated stream to the files above. When no filters are passed in, no duplicated streams are created.

Test Plan:
```
$ buck test 'fbcode//mode/opt' caffe2/test/distributed/elastic/multiprocessing:api_test
```
```
Buck UI: https://www.internalfb.com/buck2/f5c6b7da-217d-4a0b-872a-c7cd3d05587f
Test UI: https://www.internalfb.com/intern/testinfra/testrun/4222124951617688
Network: Up: 398B  Down: 44MiB  (reSessionID-a489a961-b602-45be-b851-3490ebb7a26a)
Analyzing targets. Remaining     0/200
Executing actions. Remaining     0/12856                                                                                                                                        0.1s exec time total
Command: test.     Finished 1 local
Time elapsed: 17:37.9s
Tests finished: Pass 52. Fail 0. Fatal 0. Skip 0. Build failure 0
```
```
$ buck test 'fbcode//mode/opt' caffe2/test/distributed/elastic/multiprocessing:tail_log_test
```
```
Buck UI: https://www.internalfb.com/buck2/d6d5c1c1-db98-4d9c-b608-7ba6fbb5e3ee
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13510798985149262
Network: Up: 94KiB  Down: 417MiB  (reSessionID-27b46fba-d31c-4c04-8ede-a506454e6922)
Analyzing targets. Remaining     0/3                                                                                                                                            536 actions, 555 artifacts declared
Executing actions. Remaining     0/186                                                                                                                                          1:05.5s exec time total
Command: test.     Finished 7 local, 1 remote, 115 cache (93% hit)                                                                                                              37.0s exec time cached (56%)
Time elapsed: 1:11.5s
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Rollback Plan:

Differential Revision: D80188995

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160712
Approved by: https://github.com/fduwjj
2025-10-23 14:22:21 +00:00
2b93d5b450 [FlexAttention][CUDA] Add flex configs for Blackwell (#165760)
This PR fixes ULFs on `max_autotune` mode for high head-dim sizes on B200. Closes https://github.com/pytorch/torchtitan/issues/1791

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165760
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
2025-10-23 10:22:06 +00:00
6b7cd48e7e [ROCm] Deserialize loads in planer sum portion of reduce() of norm. (#165927)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165927
Approved by: https://github.com/jeffdaily
2025-10-23 09:45:01 +00:00
bf5aa9e42e [dynamo] Remove ID guard on method object (#166096)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166096
Approved by: https://github.com/tugsbayasgalan
2025-10-23 06:22:49 +00:00
b1eb6dede5 [vision hash update] update the pinned vision hash (#166046)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vision hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166046
Approved by: https://github.com/pytorchbot
2025-10-23 04:27:44 +00:00
673060beae [inductor] turn Inductor deterministic mode on with torch.use_deterministic_algorithms (#165950)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165950
Approved by: https://github.com/v0i0, https://github.com/eellison
2025-10-23 02:48:42 +00:00
2e8e9a59a8 Revert "[dynamo][easy] Support torch.accelerator.current_accelerator (#165734)" (#166094)
This reverts commit c18ddfc5721dd91bf29c769e850a99c4fdb6f380.

Discovers some latent issues causing internal failures. Will fix those issues first and resend the PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166094
Approved by: https://github.com/bdhirsh
2025-10-23 01:24:46 +00:00
fb277a5916 Enable new tracer by default (#165332)
Differential Revision: [D84516080](https://our.internmc.facebook.com/intern/diff/D84516080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165332
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #165582, #163580
2025-10-23 00:40:29 +00:00
73fa0d0c63 test for #165446 (#165853)
Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165853
Approved by: https://github.com/drisspg
2025-10-23 00:08:18 +00:00
36c21cc84e state dict staging fixes (#166025)
Summary:
This PR contains three changes -
1. We are losing non-blocking flag value and defaulting to False during the deep_copy. This is introducing a cuda synchronize after each tensor. This is slowing the staging.
2. Adding the capability to skip pinning for scalar tensors to reduce initial staging buffer creation cost. Setting it by default to 65 to avoid pinning small tensors.
3. Tensor share storage but each storage needs to be processed only once in the deep_copy with offloading logic. so, use the memoization table to cache storage ids.

Test Plan:
1. Verified non-blocking copies via kineto profile.
2. ran A/B jobs old and new staging with fixes such that it crashes after ever 2 checkpoints and restarts for several hours and compared loss curves and they are exactly identical.
3. tests

Differential Revision: D85180484

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166025
Approved by: https://github.com/pradeepfn
2025-10-22 23:32:41 +00:00
0b68814b44 Forward fix to D80948073 (#166023)
Summary:
realize tensor before accessing layout.

Differential Revision: D85172267

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166023
Approved by: https://github.com/laithsakka
2025-10-22 22:00:53 +00:00
e64a814ae7 [CUDA] Add experimental green context support for SM carveout (#159104)
Low-level PyTorch APIs should be usable/stable enough at this point but we might move the underlying driver API usage a bit from here...

Built on top of @drisspg 's branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159104
Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/kwen2501

Co-authored-by: drisspg <drisspguessous@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-10-22 21:38:52 +00:00
0b58d87aec [Submodule] Bump FBGEMM to latest (#165544)
Summary:

* FBGEMM submodule updated to main
* CMake updated to reflect necessary changes
* Notably pulls in NVFP4 grouped gemm kernels

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165544
Approved by: https://github.com/cyyever, https://github.com/jeffdaily
2025-10-22 20:57:15 +00:00
757975ad50 [export] Unified graph capture with fullgraph_capture. (#165562)
Summary:
_dynamo_graph_capture_for_export in the current form has the compability issue
with the main torch.compile() path despite we reuse fullgraph_capture as the
bytecode tracer. The reason is that we flip on many export specific flags
and even trace with a wrapped function which will cause divergence with
torch.compile() again.

This PR instead creates a new implementation of dynamo_graph_capture_for_export
which 100% relies on fullgraph capture and post-processing on CaptureOutput so
that we can avoid the inversion of phases in PT2 compiler stack.

This also benefits precompile workflow since we want to have a feature that
only accepts pytree inputs and ship portable python wrappers in package. In
other words, I think the code here is sharable between export and precompile
for exporting portable graph.

Test Plan:
===================================================================== test session starts =====================================================================
platform linux -- Python 3.12.11, pytest-7.3.2, pluggy-1.6.0
rootdir: /data/users/zhxchen17/pytorch
configfile: pytest.ini
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, subtests-0.13.1, rerunfailures-14.0, flakefinder-1.1.0, cpp-2.3.0, anyio-4.10.0
collected 9 items
Running 9 items in this shard

test/distributed/tensor/test_dtensor_export.py ........x                                                                                                [100%]

================================================================ 8 passed, 1 xfailed in 11.42s ================================================================

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165562
Approved by: https://github.com/tugsbayasgalan
2025-10-22 20:44:55 +00:00
291712026b [dynamo][user_defined] Replace UserFunctionVariable with VariableTracker build (#165706)
Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165706
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
2025-10-22 19:28:27 +00:00
3e77a2b478 [PyTorch] Improve aarch64 performance of bfloat16 ops (#166028)
Summary:
PR allows compiler to better optimize some bfloat16-based operations, when ran on NEON

Benchmarks show measurable improvements:

Before:
bfloat16 add: 250.503us
bfloat16 sub: 245.674us
bfloat16 neg: 113.945us

After:
bfloat16 add: 203.862us ---> 23% higher throughput
bfloat16 sub: 201.526us ---> 22% higher throughput
bfloat16 neg: 74.986us ---> 52% higher throughput

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

 binary_test.py has been updated, to run bfloat16 benchmarks using basic arithmetic functions

Differential Revision: D85186786

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166028
Approved by: https://github.com/Skylion007
2025-10-22 19:25:33 +00:00
82ef1b5db3 [DebugMode] refactor logs into _DebugCalls (#165376)
Refactors `DebugMode.operators` to be more structured `_DebugCall` objects, instead of (op, args, kwargs, call_depth) tuples. Useful going forward for attaching more information (e.g. output info, call metadata).

Is BC-breaking, but attaches an `__iter__` method for `_OpCall` and `_RedistributeCall` so previous tuple usage is accessible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165376
Approved by: https://github.com/yushangdi
2025-10-22 19:01:56 +00:00
5f370f5c42 inductor_provenance: Correctly handle null provenance (#166019)
Summary:
If the provenance is null, we're getting crashes of the form
```
[trainers0]:E1021 10:51:31.990525  2752 PythonApi.h:87] Exception caught in
GeneratedDynamoCompileLoggerConfig: <class
'dsi.logger.py3.GeneratedDynamoCompile.LogEntry.thrift_types.GeneratedDynamoCompileLogEntryThriftBase'>:
error initializing Thrift struct field 'inductor_provenance_thrift_safe':
Cannot create internal string data representation. Expected type <class 'str'>,
got: <class 'NoneType'>.
```

Also fixed a type signature that wasn't being enforced. (It's still not
enforced, but it's accurate).

Test Plan:
Added a new test which reproduces the logging issue

Differential Revision: D85173596

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166019
Approved by: https://github.com/ppanchalia, https://github.com/yushangdi
2025-10-22 18:21:57 +00:00
05b2e02cb4 Revert "[lint] workflow consistency linter to look at all files instead of just changed files (#165171)"
This reverts commit c746feb86a1459db5f6294730d1d72ed15f16dd3.

Reverted https://github.com/pytorch/pytorch/pull/165171 on behalf of https://github.com/clee2000 due to broke lint [GH job link](https://github.com/pytorch/pytorch/actions/runs/18723760085/job/53402955955) [HUD commit link](c746feb86a) ([comment](https://github.com/pytorch/pytorch/pull/165171#issuecomment-3433501457))
2025-10-22 17:47:29 +00:00
12f742941d Warn if AccumulateGrad stream does not match producer node stream (#165065)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165065
Approved by: https://github.com/ngimel
2025-10-22 17:33:27 +00:00
35180fafee Allow GraphPickler to pickle graph modules containing AOTCompiled subgraphs (#165844)
This PR allows GraphPickler to pickle aot_eager graph modules that have regional inductor bits in them, with a few exceptions:
- FlexAttentionBackward isn't marked cacheable, so those tests don't work immediately since we're not sure how to serialize it. But it's safe to serialize/cache, so the next PR fixes those unit tests.
- It seems that when reloading a GraphPickled object, we don't recompile subgraphs. Will investigate this in a future PR

All unit tests in test_regional_inductor are parameterized so that we try serializing and deserializing the returned graph module before returning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165844
Approved by: https://github.com/oulgen
ghstack dependencies: #165843
2025-10-22 17:03:49 +00:00
c746feb86a [lint] workflow consistency linter to look at all files instead of just changed files (#165171)
As in title

If you change only one workflow file, lintrunner (default arg, also the one in CI since it only inputs changed files) won't look at other files in the repo, but the sync-tag might come from those other files

This makes it so that it looks at all workflow files so it will catch those failures

Pros:
catches errors

Cons:
unusual behavior (getting around what lintrunner says the linter should run on)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165171
Approved by: https://github.com/malfet
2025-10-22 16:57:59 +00:00
c5f26db5bf fix #166057: add tmp ptr to avoid gcc internal compiler error (#165717)
Fixes #166057

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165717
Approved by: https://github.com/malfet
2025-10-22 16:38:26 +00:00
18e99b6d45 [dirsync] Switch to top-level xplat/third-party/pthreadpool (#165995)
Summary: `fbcode//xplat/third-party/pthreadpool:` just redirects to the xplat version. Switch to the real location

Test Plan: This should be a no-op, so CI?

Differential Revision: D83999534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165995
Approved by: https://github.com/bigfootjon, https://github.com/Skylion007
2025-10-22 16:18:23 +00:00
ab9e466928 [inductor][choices] lookup table choices 1/3 (#164978)
\# why

- enable users to control which choices get used on which inputs
- reduce lowering time, and pin kernel selection, by selecting
  them for the inputs

\# what

- a new InductorChoices subclass that implements a lookup table
- a README explaining the usage
- corresponding testing

- currently only supports templates that go through
  `V.choices.get_template_configs`

\# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -v
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164978
Approved by: https://github.com/PaulZhang12, https://github.com/eellison
2025-10-22 16:11:31 +00:00
af4ba78543 [scan x vmap] support scan in vmap (#165580)
This is required by the chunked_with_scan work where two nested vmap(vmap) with chunk sizes > 1 are invoked, which produces a scan-> vmap -> scan -> vmap chain and we need to handle the case of vmap(scan) and scan(vmap).

The way we handle vmap(scan) is to turn it into scan(vmap(combine_fn)). The idea being that the combine_fn no longer do the combine_fn for a single slice, it vmaps over the combine_fn and do multiple combine_fns in one step. We need to need know how combine_fn propagates the batched tensor and what are the batched dims of the output. For this purpose, we use restore_vmap to give us the out_dims information.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165580
Approved by: https://github.com/zou3519
ghstack dependencies: #165675
2025-10-22 09:46:00 +00:00
282f39a4bc [vmap][dynamo] use create_proxy instead of create_node in vmap increate nesting ctx manager (#165675)
create_node won't do the auto closure lifting, this cause problems when the context manager is used in a hop region. Switch to create_proxy instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165675
Approved by: https://github.com/zou3519, https://github.com/guilhermeleobas
2025-10-22 09:46:00 +00:00
a479769488 [dynamo] Clean up assert in dynamo [2/N] (#165745)
Extend from #165430
* #165903(Clean up for graph break)
* ->#165745
* #165430

One main refractor from the previous PR:
* For assertions like checking `len(args)` or `len(kwargs)`, using `raise_args_mismatch` instead of `raise_type_error_exc`

I am also considering moving `raise_type_error_exc` into `utils.py` for consistency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165745
Approved by: https://github.com/Lucaskabela
2025-10-22 07:12:37 +00:00
26c7375477 Remove the branch of IS_CUSPARSE11_AVAILABLE is False (#166048)
This PR removes the branch when `IS_CUSPARSE11_AVAILABLE` is 0. Note that the condition `ROCM_VERSION >= 60300` holds currently as the minimum supported ROCm is 6.3 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166048
Approved by: https://github.com/Skylion007
2025-10-22 07:10:11 +00:00
d01f15152c Move toUnderlying to headeronly (#165694)
As in the title. Required in upper PRs of this ghstack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165694
Approved by: https://github.com/janeyx99
2025-10-22 05:31:16 +00:00
4fae6968b1 Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405) (#166018)
This PR is created to replace the reverted PR https://github.com/pytorch/pytorch/pull/164405
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166018
Approved by: https://github.com/janeyx99
2025-10-22 05:16:58 +00:00
f9953e0f61 Enable PLC0414 on ruff (#165828)
This PR enables `PLC0414` that fixes redundant import aliases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165828
Approved by: https://github.com/albanD
2025-10-22 04:56:52 +00:00
34ed7a8f0d [ROCm] Skip test_blockwise_nvfp4_with_global_scale (#165968)
Disable the fp4 global_scale test till the feature is enabled on ROCm.

Fixes #166027.
Not really, but we're trading an issue for a test skip decorator since the test is parameterized.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165968
Approved by: https://github.com/jeffdaily, https://github.com/drisspg
2025-10-22 04:23:05 +00:00
2fde10d914 [ROCm] fix test_allocator_backend (#166035)
Fixes #165872.

Forward fix PR #165298. hipify was causing some symbols to be replaced.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166035
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-22 03:46:23 +00:00
0a93295da0 Update doc (#166024)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166024
Approved by: https://github.com/yiming0416
2025-10-22 03:41:31 +00:00
4b898b51b9 [12/n][take2] : Remove fbandroid_compiler_flags platform args (#165916)
Summary: This diff removes the `fbandroid_compiler_flags` and merges its content with `compiler_flags` and wraps it in a android select. My first attempt at this got reverted - D84626885.

Test Plan:
CI and failing builds are now passing
```
buck2 build --target-universe fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore_redex_postprocessed_repack_resign @//fbandroid/mode/nosan @//fbandroid/mode/opt @//fbandroid/mode/milan_build_rdk @//fbandroid/mode/relr-relocations fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore_redex_postprocessed_repack_resign fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore_redex_genrule fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore-mobileconfig-definition-resource-gen fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore
File changed: fbsource//tools/build_defs/fb_xplat_cxx_library.bzl
Buck UI: https://www.internalfb.com/buck2/509c0b7b-ada3-421a-8c32-2f1d3a7babdd
Network: Up: 1.3MiB  Down: 293MiB  (reSessionID-17f73b81-3c34-4c01-9f6c-2b4f3c8332e3)
Loading targets.   Remaining     0/1311                                                                                                                                                                                                292986 targets declared
Analyzing targets. Remaining     0/13515                                                                                                                                                                                               216715 actions, 359204 artifacts declared
Executing actions. Remaining     0/40415                                                                                                                                                                                               6:33.3s exec time total
Command: build.    Finished 40 local, 790 remote
Time elapsed: 32.0s
BUILD SUCCEEDED
```

Reviewed By: jaejunku

Differential Revision: D84868234

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165916
Approved by: https://github.com/malfet
2025-10-22 03:01:55 +00:00
550e3e6efb [dynamo] Fix MATCH_KEYS for dict pattern matching (#165956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165956
Approved by: https://github.com/guilhermeleobas, https://github.com/cyyever
2025-10-22 02:52:07 +00:00
715449ca76 [MPS] Fix parity between CPU and MPS on singular matrices in linalg.lu_factor (#165871)
Fixes #165870. Follow up from #165254.

This PR [a] removes the MPS specific version of `lu_factor` in favor of the version in BatchedLinearAlgebra.cpp which uses `lu_factor_ex`, and [b] updates `lu_factor_ex` error codes to match expectations.

When `lu_factor` was first implemented for MPS (#99269), it bypassed the implementation in BatchedLinearAlgebra.cpp since we did not have `lu_factor_ex`. Since #144651 implements `lu_factor_ex`, we can now remove the MPS specific wrapper.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165871
Approved by: https://github.com/kulinseth, https://github.com/albanD
2025-10-22 02:48:40 +00:00
84d8d06fc3 Fixes floating point exception in torch.nn.PixelShuffle (#163154)
Fixes #162251

**Previous Output:**
`Floating point exception (core dumped)`

**Now Output:**
`RuntimeError: upscale factor is too large, (upscale_factor}^2 overflowed: upscale_factor=545460846592`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163154
Approved by: https://github.com/cyyever, https://github.com/albanD
2025-10-22 02:22:16 +00:00
60992d98b2 [dynamo][remaining] Replace UserFunctionVariable with VariableTracker build (#165896)
Audit: To prevent future issues with functools.partial or callable objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165896
Approved by: https://github.com/Lucaskabela
2025-10-22 02:13:00 +00:00
59e015e3a1 Remove outdated CUB macros (#164656)
This PR removes `CUB_SUPPORTS_NV_BFLOAT16` and `CUB_SUPPORTS_FUTURE_VALUE` because they are always true on CUDA >=12 installations with its CUB version. Their branches are also removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164656
Approved by: https://github.com/albanD, https://github.com/eqy, https://github.com/jeffdaily
2025-10-22 02:02:50 +00:00
8904a5a7c9 Move allocation size config to AllocatorConfig for cross-allocator sharing (#159553)
# Motivation
Make CUDA and XPU share the same config and code. And allow the other backends to reuse them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159553
Approved by: https://github.com/albanD
ghstack dependencies: #160067
2025-10-22 01:48:56 +00:00
f5df9ca03a Fix creation of BINARY_SUBSCR in Python 3.14+ (#165864)
Python 3.14 replaced `BINARY_SUBSCR` by `BINARY_OP(opcode=BN_SUBSCR)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165864
Approved by: https://github.com/williamwen42
2025-10-22 01:43:03 +00:00
2998abd777 [Code Clean] Better error handling in torch/csrc/distributed (#165053)
Replace the runtime_error of the vallina C++ exceptions with TORCH_CEHCK
Including:

torch/csrc/distributed/*

fix partialy #148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165053
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-22 01:40:36 +00:00
e13580e41c [AMD] Run int4_mm tests only for compatible arch (#165630)
Such tests should be skipped for rest including gfx1100(Navi3x)

Fixes for CI HUD for gfx1100

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165630
Approved by: https://github.com/jeffdaily

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
2025-10-22 01:38:55 +00:00
f3b8e15f20 [AMD][gfx1100] test_decompose_mem_bound_mm.py tolerance increase (#165625)
test_decompose_mem_bound_mm.py tolerance increase for navi3x(gfx11x)

(cherry picked from commit 03c7da05f61890bbf5ae41e23c8df6d5f6805bac) from

Fixes for CI HUD for gfx1100

Signed-off-by: Artem Kuzmitckii <artem.kuzmitckii@amd.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165625
Approved by: https://github.com/jeffdaily

Co-authored-by: iupaikov-amd <Iurii.Paikov@amd.com>
Co-authored-by: Dmitry Nikolaev <139769634+dnikolaev-amd@users.noreply.github.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-22 01:38:48 +00:00
5211f4c108 [MPS] Fix SDPA fp16 overflow (#165961)
Do not cast intermediate result back to lower precision data data until
softmax is finished, otherwise it might produce NaN

Adjust the test to use 256 as filler value rather than 64

Fixes https://github.com/pytorch/pytorch/issues/160841
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165961
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #165960
2025-10-22 01:29:42 +00:00
ad9027b80d [BE] Remove unused 'rows' parameter from spmm_bmm_coo_rows_grouped (#166041)
To fix following compilation warning
```
Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/sparse/mps/kernels/Mul.metal:76:14: warning: unused variable 'B' [-Wunused-variable]
  const uint B = dims.x;
             ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/sparse/mps/kernels/Mul.metal:65:26: warning: unused parameter 'rows' [-Wunused-parameter]
    device const long*   rows      [[buffer(0)]],
                         ^
2 warnings generated.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166041
Approved by: https://github.com/Skylion007
2025-10-22 00:59:41 +00:00
a1005427bf [xpu] Support high stream for ProcessGroupXCCL (#163049)
Add high priority stream support for ProcessGroupXCCL. Just like CUDA, XPU streams also support execution with higher priority compared to other streams. Implementation in https://github.com/intel/torch-xpu-ops/pull/1715, add register here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163049
Approved by: https://github.com/guangyey, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
2025-10-22 00:54:25 +00:00
35153d0846 Simplify c10::guts::apply (#164566)
There is only one call site of `c10::guts::apply` that can be replaced by `:std::apply` except for ROCm. This PR therefore simplifies the implementation of `c10::guts::apply`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164566
Approved by: https://github.com/Aidyn-A, https://github.com/albanD
2025-10-22 00:47:43 +00:00
7773a22cdb Revert "[AMP][Refactor] Autocast dtype handling to simplify device-specific c… (#165221)"
This reverts commit 4be1e3bf926b8e798fede3be6a3051560e9e00c5.

Reverted https://github.com/pytorch/pytorch/pull/165221 on behalf of https://github.com/clee2000 due to I think this broke test_openreg [GH job link](https://github.com/pytorch/pytorch/actions/runs/18698271058/job/53322459496) [HUD commit link](4be1e3bf92) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/165221#issuecomment-3430012693))
2025-10-22 00:26:57 +00:00
7cb467a169 [CI] Update ONNX CI packages to latest (#165883)
This PR updates ONNX related packages to their latest versions used in CI environments.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165883
Approved by: https://github.com/justinchuby, https://github.com/albanD
2025-10-22 00:25:35 +00:00
12aac12b8d [Code Clean] Replace std::runtime_error with TORCH_CHECK (#165209)
Including:
1. `aten/src/ATen/core`
2. `c10/core`

Fixes part of #148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165209
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-22 00:05:22 +00:00
2b748d0a56 Add operator name to output json (#164583)
The benchmarks, model_name on dashboard needs to be grouped with operator_name. This PR passed an additional argument operator_name to the json for grouping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164583
Approved by: https://github.com/yangw-dev
2025-10-21 23:58:39 +00:00
16745a882a [aoti][win] add support for a list of shim libraries (#165914)
As title, support passing in a list of shim libraries when cross compiling artifacts

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165914
Approved by: https://github.com/desertfire
2025-10-21 22:55:17 +00:00
8daef35cf1 Revert "[Code Clean] Clean asserts in torch/ao/quantization (root, quantizer, backend_config) (#165433)"
This reverts commit df64c0c4649984093bd1a46f1e9c658c72018200.

Reverted https://github.com/pytorch/pytorch/pull/165433 on behalf of https://github.com/clee2000 due to I think this broke some quantization tests ([comment](https://github.com/pytorch/pytorch/pull/165433#issuecomment-3429741770))
2025-10-21 22:10:19 +00:00
51319ca090 [Pytorch] Add NEON Vectorized<uint> family of translation layers (#165690)
Summary:
Adding NEON specializations of Vectorized<T> for uint8, uint16, uint32 and uint64.

Correcness has been checked using test_ops.py

operator_benchmark_test.py, which uses the PyTorch API, shows significant enhancements in some operations:

Before:

uint8 mul: 1460.751us
uint8 add: 2359.565us
uint8 lsl: 2151.206us

After:

uint8 mul: 194.792us ---> 650% higher throughput
uint8 add: 195.609us ---> 1100% higher throughput
uint8 lsl: 186.249us ---> 1055% higher throughput

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Reviewed By: mcfi

Differential Revision: D84770153

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165690
Approved by: https://github.com/malfet
2025-10-21 21:46:55 +00:00
d311a3d1dc A temporary fix to autotune out of range and related IMA (#165943)
Summary:
Autotune issue during lowering w/ AOTI:
```
setStorage: sizes [1536, 32, 8192], strides [8192, 8192, 1], storage offset 0, and itemsize 2 requiring a storage size of 25673728 are out of bounds for storage of size 25362432
```
Need a hack to create new base tensor with sufficient storage

Test Plan: Finally be able to see the e2e test passes on CI. See the detailed Test Plan in D83520844

Differential Revision: D84872792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165943
Approved by: https://github.com/laithsakka
2025-10-21 21:40:20 +00:00
04adfe5ba9 Make Backend::setGroupUid virtual (#165957)
As titled, so that we may customize this function in custom backends

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165957
Approved by: https://github.com/d4l3k
2025-10-21 21:33:24 +00:00
4be1e3bf92 [AMP][Refactor] Autocast dtype handling to simplify device-specific c… (#165221)
This PR refactors the autocast context manager in autocast_mode.py to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated similar checks for multiple device types. Now, a single mapping device_supported_dtypes is used to associate device types with their supported dtypes, and the validation logic is unified.

**The former PR #163446 was merged but reverted due to failed CI test on `openreg` related tests.**

This RR additionally slightly modified some test assertions for passing the CI tests. CI failed due to assertion for the exactly same error message. For example:
```
File "/var/lib/jenkins/workspace/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py", line 9, in test_autocast_with_unsupported_type
    with self.assertWarnsRegex(
        AssertionError: "In openreg autocast, but the target dtype torch.float32 is not supported." does not match "In openreg autocast, but the target dtype is not supported. Disabling autocast."
```

Sorry for the inconvenience again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165221
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 21:32:12 +00:00
e7592f4005 [CI] Move the periodic debug tests to newer runner (#165158)
Previously g3 = NVIDIA Tesla M60
Now g6 = NVIDIA L4
Also change cuda arch list accordingly

Pros:
More memory, newer GPU

Cons:
That was one of the few remaining tests on g3 runners, so we probably lost coverage?

We can probably run more tests in parallel now but I'm not going to do that here

Disabled a bunch of sparse tests and nestedtensor tests that were previously skipped due to not having sufficient hardware?  They are now failing with
```
Traceback (most recent call last):
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3293, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3292, in wrapper
    with policy():
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2532, in __enter__
    self.beforeStreams[-1].synchronize()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/cuda/streams.py", line 105, in synchronize
    super().synchronize()
torch.AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from stream_synchronize at /var/lib/jenkins/workspace/c10/cuda/CUDAFunctions.h:120 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, unsigned int, bool) [clone .cold] from CUDAException.cpp:0
#7 THCPStream_synchronize(_object*, _object*) from Stream.cpp:0
#8 cfunction_vectorcall_NOARGS from /usr/local/src/conda/python-3.10.14/Objects/methodobject.c:489
#9 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:114
#10 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.14/Include/internal/pycore_ceval.h:46
#11 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:114
#12 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.14/Include/internal/pycore_ceval.h:46
```
when run with cuda launch blocking I got a ton of stuff like
```

/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [5,3,0], thread: [2,7,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [5,3,0], thread: [3,7,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [2,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [2,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,3,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,3,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,4,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,4,0] Assertion `value < upper_bound` failed.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165158
Approved by: https://github.com/seemethere
2025-10-21 21:28:12 +00:00
d334c3649d [CUDA] fix reflection padding for large batch size (#165942)
Fixes [#165861](https://github.com/pytorch/pytorch/issues/165861)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165942
Approved by: https://github.com/eqy
2025-10-21 21:07:38 +00:00
9f82535c5a [ROCm] [Normalization] Update block size (#165941)
* Seeing upto 6x improvement

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165941
Approved by: https://github.com/jeffdaily
2025-10-21 20:53:05 +00:00
5b35fc8777 Support multiple commits on push events in trunk tagging workflow (#165937)
Context:
* this workflow is used to create tags like `trunk/{sha}` for all `main` commits
* those tags are used by [autorevert](https://github.com/pytorch/test-infra/blob/main/aws/lambda/pytorch-auto-revert/README.md) to rerun selected workflows

Problem: currently the workflow creates only a single tag per push event, while ghstack pushes multiple commits per single push.

This PR supports tag creation for all commits in the push event.

Complimentary autorevert PR: https://github.com/pytorch/test-infra/pull/7291

---

### Testing

I created an identical copy of this workflow in my personal repo: https://github.com/izaitsevfb/pr-head-test/actions/workflows/trunk-tagging.yml

See action runs there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165937
Approved by: https://github.com/huydhn
2025-10-21 20:52:34 +00:00
2f38eece7c [CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)
As per title. Additionally, some Lt selection conditions are revisited, and some redundancy removed (especially in the ROCm vs non-ROCm paths).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163955
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-21 20:48:12 +00:00
830e789a55 [dynamo][annotate] Graph break cleanly on fx.traceback.annotate reconstruction (#166006)
This avoids generation of bad bytecode, leading to really confusing
error. I am not sure why we can't reconstruct cleanly, it has to do with
the input being a dict, while other supported ctx managers take bools.

Fixing that is for another day. Lets give a good error message for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166006
Approved by: https://github.com/yushangdi, https://github.com/SherlockNoMad
2025-10-21 20:48:04 +00:00
ad4dc52bf6 Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit 4e643422f63a3cdd71bd141615f98de6bb54d15f.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/albanD due to Breaks lint ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3429426503))
2025-10-21 20:24:14 +00:00
dac9ed9790 Bump uv from 0.8.6 to 0.9.5 in /.ci/lumen_cli (#166017)
Bumps [uv](https://github.com/astral-sh/uv) from 0.8.6 to 0.9.5.
- [Release notes](https://github.com/astral-sh/uv/releases)
- [Changelog](https://github.com/astral-sh/uv/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/uv/compare/0.8.6...0.9.5)

---
updated-dependencies:
- dependency-name: uv
  dependency-version: 0.9.5
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-21 13:16:30 -07:00
1c7fe8f861 [BugFix] chunk_size should always be int64_t (#165971)
aspired by https://github.com/pytorch/pytorch/pull/156872
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165971
Approved by: https://github.com/albanD
2025-10-21 19:52:47 +00:00
4e643422f6 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-21 19:47:33 +00:00
3c3b278872 [reland][fx] Move Node._prepend/Node._remove_from_list to C++ (#165882)
Relands #148261 that was reverted by #150542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165882
Approved by: https://github.com/ezyang
2025-10-21 19:43:55 +00:00
0bd12c1168 [CI] Extend test_transfomers to MPS (#165960)
Just skip grad_checks as they need float64
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165960
Approved by: https://github.com/Skylion007
2025-10-21 19:27:44 +00:00
ce8a7764e2 Revert "[dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)"
This reverts commit 1290b077f26543a34262587137ef64ca9ca5e17d.

Reverted https://github.com/pytorch/pytorch/pull/165707 on behalf of https://github.com/clee2000 due to failing internal tests D85160820 ([comment](https://github.com/pytorch/pytorch/pull/165707#issuecomment-3429084393))
2025-10-21 19:25:03 +00:00
d1269a0434 update fr trace analysis (#165994)
Summary:
- allow empty entries from ranks
- allow not all ranks to provide dump

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/165994).
* #165638
* #165640
* #165642
* __->__ #165994
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165994
Approved by: https://github.com/fduwjj
2025-10-21 19:14:33 +00:00
c87cf1be32 Update workaround to old CUDA bug (#164354) (#165984)
The workaround cannot be removed because of BC. Here we'll
update PyTorch code base to not use the workaround.

See https://github.com/pytorch/pytorch/pull/164354 for the BC breakage issue.

Resolves https://github.com/pytorch/pytorch/issues/164348.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165984
Approved by: https://github.com/janeyx99
2025-10-21 19:09:43 +00:00
2fc5e45a41 better error message when there is no pytree impl (#165955)
Differential Revision: [D85117597](https://our.internmc.facebook.com/intern/diff/D85117597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165955
Approved by: https://github.com/avikchaudhuri
2025-10-21 18:49:22 +00:00
f9022ba93b [PyTorch] Add user_metadata display to memory visualizer (#165939)
Summary: Enhanced the PyTorch CUDA memory visualizer to display user_metadata alongside stack frames when inspecting allocations. The user_metadata field is now shown in all views (Allocator State History, Active Memory Timeline, etc.) with consistent formatting. The implementation handles both string and object metadata types, displaying strings directly and objects as key-value pairs.

Test Plan:
1. Generate a memory snapshot with user_metadata
2. Open the memory visualizer in a browser
3. Load the snapshot file
4. Verify user_metadata appears
5. Test with both string metadata ("testing") and object metadata ({"key": "value"})
6. Verify formatting shows "User Metadata:\n  <value>" for strings

 {F1982860439}

Differential Revision: D85095152

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165939
Approved by: https://github.com/yushangdi
2025-10-21 18:48:33 +00:00
ff8be889ad Remove unused exception parameter from some files, to work with -Wunused-exception-parameter (#165770)
Summary: address compiler complains that were coming up to unblock the build

Test Plan:
before the change
```
aten/src/ATen/native/LinearAlgebra.cpp:3623:36: error: unused exception parameter 'e' [-Werror,-Wunused-exception-parameter]
 3623 |     } catch (const std::exception& e) {
      |
```

after: targets build with `-Wunused-exception-parameter`

Differential Revision: D84876246

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165770
Approved by: https://github.com/Skylion007, https://github.com/cyyever

Co-authored-by: Tony Targonski <tony.targonski@meta.com>
2025-10-21 18:30:29 +00:00
292454942e [CD] Introduce windows.12xlarge runners for CD Windows build (#165287)
Follows https://github.com/pytorch/test-infra/pull/7174. Windows CD build time cost comparison as below

|Runner|cpu|cuda|xpu|
|-|-|-|-|
|windows.4xlarge|1.5h| 4.0h| 5.5h|
|windows.12xlarge|0.5h|1.5h|2.5h|

Fixes #162962
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165287
Approved by: https://github.com/zxiiro, https://github.com/malfet, https://github.com/seemethere
2025-10-21 18:28:23 +00:00
6c4412f72b Revert "[Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)"
This reverts commit e9d89734274a4a2640fa77b898c800a87d1d874e.

Reverted https://github.com/pytorch/pytorch/pull/163316 on behalf of https://github.com/clee2000 due to seems to have broken some no_gpu tests? test/inductor/test_cpu_repro.py::CPUReproTests::test_double_reduction_vec [GH job link](https://github.com/pytorch/pytorch/actions/runs/18689033019/job/53290772740) [HUD commit link](e9d8973427) ([comment](https://github.com/pytorch/pytorch/pull/163316#issuecomment-3428210509))
2025-10-21 17:44:42 +00:00
78bf6186f2 Revert "[Inductor] support masked vectorization for the tail_loop for fp8 datatype (#163324)"
This reverts commit e8cb34dd52c063a130f3e659576c313bbe4b4981.

Reverted https://github.com/pytorch/pytorch/pull/163324 on behalf of https://github.com/clee2000 due to seems to have broken some no_gpu tests? test/inductor/test_cpu_repro.py::CPUReproTests::test_double_reduction_vec [GH job link](https://github.com/pytorch/pytorch/actions/runs/18689033019/job/53290772740) [HUD commit link](e9d8973427) ([comment](https://github.com/pytorch/pytorch/pull/163316#issuecomment-3428210509))
2025-10-21 17:44:42 +00:00
c40048472c Remove AOTI cross compilation time from internal CI (#165935)
Summary: as title

Test Plan: CI

Differential Revision: D85088451

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165935
Approved by: https://github.com/desertfire
2025-10-21 16:58:28 +00:00
3dfd0c7584 Improve PATH hints in FindvecLib.cmake (#165881)
Change  /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk to /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk in `cmake/Modules/FindvecLib.cmake` which is more general (and MacOSX10.9 is not supported now). Otherwise, vecLib can't be found on MacOS 26.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165881
Approved by: https://github.com/ezyang
2025-10-21 16:44:12 +00:00
e6ba4d0725 Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)" (#165910)
Summary:
Original commit changeset: d6d62d0c96dd

Original Phabricator Diff: D84468451 and D84613184

D84468451 caused CUDA OutOfMemoryError in model.

Test Plan:
D84468451 was found through bisect.  Also double checked on recent trunk 9866939225248c2adc307be7a804b26db0b9b555: f815887517

With this diff that backs out D84468451 and D84613184 : f816114560

Differential Revision: D85025378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165910
Approved by: https://github.com/clee2000
2025-10-21 16:36:38 +00:00
bdf7cb9d9c Revert "[torch/utils][Code Clean] Clean asserts in torch/utils/*.py (#165410)"
This reverts commit e20c9bf2889b9252ac45ae6af35c93c795eab701.

Reverted https://github.com/pytorch/pytorch/pull/165410 on behalf of https://github.com/clee2000 due to sorry I'm going to revert this since I want to try to back out some other things that are conflicting with this, there is nothing wrong with this PR, rebasing and resolving the merge conflicts should be enough, sorry for the churn ([comment](https://github.com/pytorch/pytorch/pull/165410#issuecomment-3427532373))
2025-10-21 16:27:54 +00:00
6aed378958 [export] Handle kwargs better in aot_export_joint_with_descriptors (#165334)
fx.Interpreter doesn't handle kwargs... not sure how this code worked previously

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165334
Approved by: https://github.com/tugsbayasgalan, https://github.com/ezyang
2025-10-21 15:53:05 +00:00
8b3dc0d1b0 Better error handling in torch/csrc/jit/runtime/* (#165118)
Refactor error handling by using TORCH_CHECK for improved clarity in constants and scope management in some files in torch/csrc/jit/runtime/*

Fixes some parts of ISSUE https://github.com/pytorch/pytorch/issues/148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165118
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 15:22:49 +00:00
06773663b5 Implement an AOT precompile mode for standalone_compile (#165843)
This PR introduces an `aot` flag to standalone_compile that uses BundledAOTAutogradCacheEntry, and then allows regional_inductor to use this so that we can start aot compiling regional compiler graphs. The diff above this will attempt to allow GraphPickler to fully serialize graphs that have regionally compiled subgraphs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165843
Approved by: https://github.com/oulgen
2025-10-21 15:02:45 +00:00
0bff65503c Move hardware_destructive_interference_size to c10/core/alignment.h (#160067)
# Motivation
Move `hardware_destructive_interference_size` to `c10/core/alignment.h`, which gives a chance to reuse it across different accelerators.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160067
Approved by: https://github.com/Skylion007, https://github.com/EikanWang
2025-10-21 14:39:46 +00:00
768 changed files with 12680 additions and 7155 deletions

View File

@ -19,7 +19,7 @@ pip_install \
transformers==4.36.2
pip_install coloredlogs packaging
pip_install onnxruntime==1.23.0
pip_install onnxruntime==1.23.1
pip_install onnxscript==0.5.4
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers

View File

@ -334,12 +334,12 @@ sympy==1.13.3
#Pinned versions:
#test that import:
onnx==1.18.0
onnx==1.19.1
#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:
onnxscript==0.5.3
onnxscript==0.5.4
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:

View File

@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.8.6"
"uv==0.9.5"
]
[tool.setuptools]

View File

@ -163,8 +163,13 @@ if [[ "$(uname)" != Darwin ]]; then
MEMORY_LIMIT_MAX_JOBS=12
NUM_CPUS=$(( $(nproc) - 2 ))
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
if [[ "$(uname)" == Linux ]]; then
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
else
# For other builds
export MAX_JOBS=${NUM_CPUS}
fi
cat >>"$envfile" <<EOL
export MAX_JOBS="${MAX_JOBS}"

View File

@ -0,0 +1,354 @@
# PyTorch Docstring Writing Guide
This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.
## General Principles
- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
- Follow **Sphinx/reStructuredText** (reST) format for documentation
- Be **concise but complete** - include all essential information
- Always include **examples** when possible
- Use **cross-references** to related functions/classes
## Docstring Structure
### 1. Function Signature (First Line)
Start with the function signature showing all parameters:
```python
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
```
**Notes:**
- Include the function name
- Show positional and keyword-only arguments (use `*` separator)
- Include default values
- Show return type annotation
- This line should NOT end with a period
### 2. Brief Description
Provide a one-line description of what the function does:
```python
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
Applies a 2D convolution over an input image composed of several input
planes.
```
### 3. Mathematical Formulas (if applicable)
Use Sphinx math directives for mathematical expressions:
```python
.. math::
\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
```
Or inline math: `:math:\`x^2\``
### 4. Cross-References
Link to related classes and functions using Sphinx roles:
- `:class:\`~torch.nn.ModuleName\`` - Link to a class
- `:func:\`torch.function_name\`` - Link to a function
- `:meth:\`~Tensor.method_name\`` - Link to a method
- `:attr:\`attribute_name\`` - Reference an attribute
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)
**Example:**
```python
See :class:`~torch.nn.Conv2d` for details and output shape.
```
### 5. Notes and Warnings
Use admonitions for important information:
```python
.. note::
This function doesn't work directly with NLLLoss,
which expects the Log to be computed between the Softmax and itself.
Use log_softmax instead (it's faster and has better numerical properties).
.. warning::
:func:`new_tensor` always copies :attr:`data`. If you have a Tensor
``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
or :func:`torch.Tensor.detach`.
```
### 6. Args Section
Document all parameters with type annotations and descriptions:
```python
Args:
input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)`
bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
tuple `(sH, sW)`. Default: 1
```
**Formatting rules:**
- Parameter name in **lowercase**
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
- Description follows the type
- For optional parameters, include "Default: ``value``" at the end
- Use double backticks for inline code: ``` ``None`` ```
- Indent continuation lines by 2 spaces
### 7. Keyword Args Section (if applicable)
Sometimes keyword arguments are documented separately:
```python
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
Default: if None, same :class:`torch.dtype` as this tensor.
device (:class:`torch.device`, optional): the desired device of returned tensor.
Default: if None, same :class:`torch.device` as this tensor.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
```
### 8. Returns Section (if needed)
Document the return value:
```python
Returns:
Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
If ``hard=True``, the returned samples will be one-hot, otherwise they will
be probability distributions that sum to 1 across `dim`.
```
Or simply include it in the function signature line if obvious from context.
### 9. Examples Section
Always include examples when possible:
```python
Examples::
>>> inputs = torch.randn(33, 16, 30)
>>> filters = torch.randn(20, 16, 5)
>>> F.conv1d(inputs, filters)
>>> # With square kernels and equal stride
>>> filters = torch.randn(8, 4, 3, 3)
>>> inputs = torch.randn(1, 4, 5, 5)
>>> F.conv2d(inputs, filters, padding=1)
```
**Formatting rules:**
- Use `Examples::` with double colon
- Use `>>>` prompt for Python code
- Include comments with `#` when helpful
- Show actual output when it helps understanding (indent without `>>>`)
### 10. External References
Link to papers or external documentation:
```python
.. _Link Name:
https://arxiv.org/abs/1611.00712
```
Reference them in text: ```See `Link Name`_```
## Method Types
### Native Python Functions
For regular Python functions, use a standard docstring:
```python
def relu(input: Tensor, inplace: bool = False) -> Tensor:
r"""relu(input, inplace=False) -> Tensor
Applies the rectified linear unit function element-wise. See
:class:`~torch.nn.ReLU` for more details.
"""
# implementation
```
### C-Bound Functions (using add_docstr)
For C-bound functions, use `_add_docstr`:
```python
conv1d = _add_docstr(
torch.conv1d,
r"""
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
Applies a 1D convolution over an input signal composed of several input
planes.
See :class:`~torch.nn.Conv1d` for details and output shape.
Args:
input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
weight: filters of shape :math:`(\text{out\_channels} , kW)`
...
""",
)
```
### In-Place Variants
For in-place operations (ending with `_`), reference the original:
```python
add_docstr_all(
"abs_",
r"""
abs_() -> Tensor
In-place version of :meth:`~Tensor.abs`
""",
)
```
### Alias Functions
For aliases, simply reference the original:
```python
add_docstr_all(
"absolute",
r"""
absolute() -> Tensor
Alias for :func:`abs`
""",
)
```
## Common Patterns
### Shape Documentation
Use LaTeX math notation for tensor shapes:
```python
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
```
### Reusable Argument Definitions
For commonly used arguments, define them once and reuse:
```python
common_args = parse_kwargs(
"""
dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
Default: if None, same as this tensor.
"""
)
# Then use with .format():
r"""
...
Keyword args:
{dtype}
{device}
""".format(**common_args)
```
### Template Insertion
Insert reproducibility notes or other common text:
```python
r"""
{tf32_note}
{cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)
```
## Complete Example
Here's a complete example showing all elements:
```python
def gumbel_softmax(
logits: Tensor,
tau: float = 1,
hard: bool = False,
eps: float = 1e-10,
dim: int = -1,
) -> Tensor:
r"""
Sample from the Gumbel-Softmax distribution and optionally discretize.
Args:
logits (Tensor): `[..., num_features]` unnormalized log probabilities
tau (float): non-negative scalar temperature
hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
but will be differentiated as if it is the soft sample in autograd. Default: ``False``
dim (int): A dimension along which softmax will be computed. Default: -1
Returns:
Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
If ``hard=True``, the returned samples will be one-hot, otherwise they will
be probability distributions that sum to 1 across `dim`.
.. note::
This function is here for legacy reasons, may be removed from nn.Functional in the future.
Examples::
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)
.. _Link 1:
https://arxiv.org/abs/1611.00712
"""
# implementation
```
## Quick Checklist
When writing a PyTorch docstring, ensure:
- [ ] Use raw string (`r"""`)
- [ ] Include function signature on first line
- [ ] Provide brief description
- [ ] Document all parameters in Args section with types
- [ ] Include default values for optional parameters
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
- [ ] Add mathematical formulas if applicable
- [ ] Include at least one example in Examples section
- [ ] Add warnings/notes for important caveats
- [ ] Link to related module class with `:class:`
- [ ] Use proper math notation for tensor shapes
- [ ] Follow consistent formatting and indentation
## Common Sphinx Roles Reference
- `:class:\`~torch.nn.Module\`` - Class reference
- `:func:\`torch.function\`` - Function reference
- `:meth:\`~Tensor.method\`` - Method reference
- `:attr:\`attribute\`` - Attribute reference
- `:math:\`equation\`` - Inline math
- `:ref:\`label\`` - Internal reference
- ``` ``code`` ``` - Inline code (use double backticks)
## Additional Notes
- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
- **Line length**: Try to keep lines under 100 characters when possible
- **Periods**: End sentences with periods, but not the signature line
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.

View File

@ -124,3 +124,10 @@ runs:
id: login-ecr
continue-on-error: true
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Preserve github env variables for use in docker
shell: bash
run: |
env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"

View File

@ -1 +1 @@
faffd5cf673615583da6517275e361cb3dbc77e6
1752fe6809b74921644866275ab80244b96e80bc

View File

@ -283,6 +283,9 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
fi
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system --pre apache-tvm-ffi==0.1.0b15
# Install the vllm wheel from previous stage
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system /wheels/vllm/*.whl --verbose
@ -295,6 +298,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
ARG FLASHINFER_GIT_REF="v0.2.14.post1"

View File

@ -15,6 +15,11 @@
- "module: reinplacing"
then:
- "module: pt2-dispatcher"
- any:
- "vllm-compile"
then:
- "module: vllm"
- "oncall: pt2"
- any:
- "module: vmap"
then:
@ -27,10 +32,6 @@
- "module: pt2 optimizer"
then:
- "module: dynamo"
- any:
- "module: flex attention"
then:
- "module: higher order operators"
- any:
- "module: aotinductor"
then:

View File

@ -79,9 +79,9 @@ jobs:
runs-on: "windows-11-arm64-preview"
{%- else %}
{%- if branches == "nightly" %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
{%- else %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
{%- endif %}
{%- endif %}
timeout-minutes: !{{ common.timeout_minutes_windows_binary }}

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
wheel-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -279,7 +279,7 @@ jobs:
wheel-py3_10-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -517,7 +517,7 @@ jobs:
wheel-py3_10-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -755,7 +755,7 @@ jobs:
wheel-py3_10-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -993,7 +993,7 @@ jobs:
wheel-py3_10-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1229,7 +1229,7 @@ jobs:
wheel-py3_11-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1464,7 +1464,7 @@ jobs:
wheel-py3_11-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1702,7 +1702,7 @@ jobs:
wheel-py3_11-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1940,7 +1940,7 @@ jobs:
wheel-py3_11-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2178,7 +2178,7 @@ jobs:
wheel-py3_11-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2414,7 +2414,7 @@ jobs:
wheel-py3_12-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2649,7 +2649,7 @@ jobs:
wheel-py3_12-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2887,7 +2887,7 @@ jobs:
wheel-py3_12-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3125,7 +3125,7 @@ jobs:
wheel-py3_12-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3363,7 +3363,7 @@ jobs:
wheel-py3_12-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3599,7 +3599,7 @@ jobs:
wheel-py3_13-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3834,7 +3834,7 @@ jobs:
wheel-py3_13-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4072,7 +4072,7 @@ jobs:
wheel-py3_13-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4310,7 +4310,7 @@ jobs:
wheel-py3_13-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4548,7 +4548,7 @@ jobs:
wheel-py3_13-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4784,7 +4784,7 @@ jobs:
wheel-py3_13t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5019,7 +5019,7 @@ jobs:
wheel-py3_13t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5257,7 +5257,7 @@ jobs:
wheel-py3_13t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5495,7 +5495,7 @@ jobs:
wheel-py3_13t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5733,7 +5733,7 @@ jobs:
wheel-py3_13t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5969,7 +5969,7 @@ jobs:
wheel-py3_14-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6204,7 +6204,7 @@ jobs:
wheel-py3_14-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6442,7 +6442,7 @@ jobs:
wheel-py3_14-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6680,7 +6680,7 @@ jobs:
wheel-py3_14-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6918,7 +6918,7 @@ jobs:
wheel-py3_14-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7154,7 +7154,7 @@ jobs:
wheel-py3_14t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7389,7 +7389,7 @@ jobs:
wheel-py3_14t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7627,7 +7627,7 @@ jobs:
wheel-py3_14t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7865,7 +7865,7 @@ jobs:
wheel-py3_14t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -8103,7 +8103,7 @@ jobs:
wheel-py3_14t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -88,7 +88,6 @@ jobs:
with:
build-environment: linux-jammy-rocm-py3_10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },

View File

@ -147,15 +147,16 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
]}
secrets: inherit

View File

@ -347,7 +347,8 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
sync-tag: linux-xpu-n-build
# This should sync with the build in xpu.yml but xpu uses a larger runner
# sync-tag: linux-xpu-n-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3

View File

@ -45,7 +45,6 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-noble-rocm-py3.12-mi300
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },

View File

@ -42,7 +42,6 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-noble-rocm-py3.12-mi355
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },

View File

@ -26,11 +26,23 @@ jobs:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build

View File

@ -26,11 +26,23 @@ jobs:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build

View File

@ -58,8 +58,10 @@ jobs:
else
COMMIT_SHA="${{ github.sha }}"
fi
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
{
echo "sha=${COMMIT_SHA}"
echo "tag_name=trunk/${COMMIT_SHA}"
} >> "${GITHUB_OUTPUT}"
- name: Validate commit SHA
run: |
@ -87,7 +89,7 @@ jobs:
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
fi
- name: Create and push tag with retry
- name: Create and push tag(s) with retry
id: check_tag
env:
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@ -112,14 +114,23 @@ jobs:
return 1
}
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
echo "exists=true" >> "${GITHUB_OUTPUT}"
exit 0
fi
# Counters for summary reporting
created_count=0
skipped_count=0
failed_count=0
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
# Always write outputs once on exit
finish() {
set +e
if [ -n "${GITHUB_OUTPUT:-}" ]; then
{
echo "created_count=${created_count}"
echo "skipped_count=${skipped_count}"
echo "failed_count=${failed_count}"
} >> "${GITHUB_OUTPUT}"
fi
}
trap finish EXIT
# Retry configuration
MAX_RETRIES=5
@ -194,31 +205,111 @@ jobs:
}
}
# Execute with retry
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
echo "exists=false" >> "${GITHUB_OUTPUT}"
# New behavior for push events: enumerate commits in the push and tag each one.
# For workflow_dispatch, retain existing single-SHA behavior.
# Always fetch tags once up front to improve idempotency in loops
git fetch origin --tags --quiet || true
if [ "${{ github.event_name }}" = "push" ]; then
BEFORE_SHA="${{ github.event.before }}"
AFTER_SHA="${{ github.sha }}" # same as event.after
# List commits introduced by this push (old..new), oldest first for stable ordering
commits_file="$(mktemp)"
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
if [ ! -s "${commits_file}" ]; then
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
rm -f "${commits_file}"
exit 0
fi
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
echo "Found ${commit_count} commit(s) to tag for push:"
while IFS= read -r sha; do
printf ' %s\n' "${sha}"
done < "${commits_file}"
while IFS= read -r sha; do
TAG_NAME="trunk/${sha}"
COMMIT_SHA="${sha}"
# If tag already exists locally or remotely, skip (idempotent)
if check_tag_exists; then
echo "✅ Tag ${TAG_NAME} already exists - skipping"
skipped_count=$((skipped_count + 1))
continue
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=$((created_count + 1))
else
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
failed_count=$((failed_count + 1))
fi
done < "${commits_file}"
rm -f "${commits_file}"
if [ "${failed_count}" -gt 0 ]; then
exit 1
fi
exit 0
else
echo "Tag creation failed after all retry attempts"
exit 1
# workflow_dispatch path (single SHA tagging preserved)
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
skipped_count=1
exit 0
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=1
exit 0
else
echo "Tag creation failed after all retry attempts"
failed_count=1
exit 1
fi
fi
- name: Tag creation summary
if: always()
run: |
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
elif [ "${{ job.status }}" = "success" ]; then
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
if [ "${{ github.event_name }}" = "push" ]; then
echo "Trigger: push on main"
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
else
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
else
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
fi

View File

@ -833,8 +833,7 @@ exclude_patterns = [
command = [
'python3',
'tools/linter/adapters/grep_linter.py',
'--pattern=cudaSetDevice(',
'--pattern=cudaGetDevice(',
'--pattern=(cudaSetDevice|cudaGetDevice)\\(',
'--linter-name=RAWCUDADEVICE',
'--error-name=raw CUDA API usage',
"""--error-description=\
@ -1138,11 +1137,8 @@ command = [
[[linter]]
code = 'WORKFLOWSYNC'
include_patterns = [
'.github/workflows/pull.yml',
'.github/workflows/trunk.yml',
'.github/workflows/periodic.yml',
'.github/workflows/mac-mps.yml',
'.github/workflows/slow.yml',
'.github/workflows/*.yml',
'.github/workflows/*.yaml',
]
command = [
'python3',

View File

@ -289,14 +289,15 @@ IF(USE_FBGEMM_GENAI)
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
set(fbgemm_genai_mx8mx8bf16_grouped
set(fbgemm_genai_cuh
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/"
)
target_include_directories(fbgemm_genai PRIVATE
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${fbgemm_genai_mx8mx8bf16_grouped}
${fbgemm_genai_cuh}
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)

View File

@ -19,6 +19,7 @@
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/detail/XLAHooksInterface.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
@ -88,6 +89,8 @@ class TORCH_API Context {
return at::detail::getHIPHooks();
} else if (opt_device_type == at::kHPU) {
return at::detail::getHPUHooks();
} else if (opt_device_type == at::kXLA) {
return at::detail::getXLAHooks();
} else {
TORCH_CHECK(
false,
@ -196,7 +199,7 @@ class TORCH_API Context {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
}
static bool hasXLA() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
return detail::getXLAHooks().hasXLA();
}
static bool hasXPU() {
return detail::getXPUHooks().hasXPU();

View File

@ -39,7 +39,7 @@ struct HostBlock {
};
template <typename B>
struct alignas(64) FreeBlockList {
struct alignas(hardware_destructive_interference_size) FreeBlockList {
std::mutex mutex_;
std::deque<B*> list_;
};
@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(64) HostStatsStaged {
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}
alignas(64) std::mutex blocks_mutex_;
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
alignas(64) std::mutex events_mutex_;
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{true};
protected:
alignas(64) HostStatsStaged stats_;
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};
struct TORCH_API HostAllocator : public at::Allocator {

View File

@ -59,9 +59,7 @@ struct TORCH_API Generator {
explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
: impl_(std::move(gen_impl)) {
if (impl_.get() == nullptr) {
throw std::runtime_error("GeneratorImpl with nullptr is not supported");
}
TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
}
bool operator==(const Generator& rhs) const {

View File

@ -111,9 +111,7 @@ class TORCH_API TensorBase {
explicit TensorBase(
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
: impl_(std::move(tensor_impl)) {
if (impl_.get() == nullptr) {
throw std::runtime_error("TensorImpl with nullptr is not supported");
}
TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
}
TensorBase(const TensorBase&) = default;
TensorBase(TensorBase&&) noexcept = default;

View File

@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
m.fallback(AUTOGRAD_FALLBACK);
}
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
m.fallback(AUTOGRAD_FALLBACK);
}
#undef AUTOGRAD_FALLBACK
} // namespace

View File

@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
TORCH_CHECK(
!backendFallbackKernels_[idx].kernel.isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
backendFallbackKernels_[idx].debug, ", new registration ", debug
);
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
!backendFallbackKernels_[idx].kernel.isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ",
dispatchKey,
"; previous registration ",
backendFallbackKernels_[idx].debug,
", new registration ",
debug);
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
// cannot be unboxed
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));

View File

@ -68,11 +68,7 @@ Symbol InternedStrings::_symbol(const std::string& s) {
return it->second;
auto pos = s.find("::");
if (pos == std::string::npos) {
std::stringstream ss;
ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
throw std::runtime_error(ss.str());
}
TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
Symbol sym(sym_to_info_.size());
@ -121,12 +117,7 @@ std::string Symbol::domainString() const {
}
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
std::ostringstream ss;
ss << "Symbol: domain string is expected to be prefixed with '"
<< domain_prefix() << "', e.g. 'org.pytorch.aten'";
throw std::runtime_error(ss.str());
}
TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
return fromQualString(qualString);
}

View File

@ -7,6 +7,7 @@
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
@ -412,7 +413,7 @@ size_t IValue::hash(const IValue& v) {
case Tag::Enum:
case Tag::Stream:
case Tag::Uninitialized:
throw std::runtime_error(
TORCH_CHECK(false,
"unhashable type: '" + v.type()->repr_str() + "'");
}
// the above switch should be exhaustive

View File

@ -8,6 +8,7 @@
#include <ATen/core/type_factory.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/TypeList.h>
#include <c10/util/Exception.h>
#include <optional>
#include <c10/core/SymFloat.h>
#include <c10/core/SymBool.h>
@ -116,10 +117,8 @@ struct SingleElementType : public SharedType {
protected:
SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
if (!this->elem) {
throw std::runtime_error(c10::str(
TORCH_CHECK(this->elem, c10::str(
"Can not create ", typeKindToString(Kind), " with None type"));
}
}
private:
@ -416,16 +415,12 @@ struct TORCH_API SymbolicShape {
}
ShapeSymbol operator[](size_t i) const {
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
TORCH_CHECK(dims_, "Rank isn't fixed");
return (*dims_).at(i);
}
ShapeSymbol at(size_t i) const {
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
TORCH_CHECK(dims_, "Rank isn't fixed");
return (*dims_).at(i);
}
@ -520,9 +515,7 @@ struct VaryingShape {
}
const std::optional<T> &operator[](size_t i) const {
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
TORCH_CHECK(dims_, "Rank isn't fixed");
return (*dims_).at(i);
}
@ -957,9 +950,7 @@ struct TORCH_API DictType : public SharedType {
TypePtr createWithContained(
std::vector<TypePtr> contained_types) const override {
if (contained_types.size() != 2) {
throw std::runtime_error("Expected 2 contained types");
}
TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
}

View File

@ -185,11 +185,11 @@ struct TORCH_API Type {
: repr_(nullptr) {}
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
: repr_(p) {}
: repr_(makeSingletonSharedPtr(p.get())) {}
template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
: repr_(SingletonTypePtr<T>(p.get())) {}
: repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {}
// We need to support construction from T* for pybind. The problem
@ -202,8 +202,8 @@ struct TORCH_API Type {
// Case 2: if T is exactly Type, we need to do a dynamic_cast to
// check if it's a SharedType and do the right thing.
//
// Case 3: Otherwise, T is not a SharedType. (debug-check this
// assumption!) Use a singleton pointer.
// Case 3: Otherwise, T is not a SharedType. Use a singleton
// pointer.
template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
/* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
@ -211,15 +211,15 @@ struct TORCH_API Type {
template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
/* implicit */ SingletonOrSharedTypePtr(T* p) {
if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
repr_ = Repr(shared_p->shared_from_this());
repr_ = shared_p->shared_from_this();
} else {
repr_ = Repr(p);
repr_ = makeSingletonSharedPtr(p);
}
}
template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
/* implicit */ SingletonOrSharedTypePtr(T* p)
: repr_(p) {
: repr_(makeSingletonSharedPtr(p)) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
}
@ -230,19 +230,19 @@ struct TORCH_API Type {
~SingletonOrSharedTypePtr() = default;
T* get() const {
return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
return repr_.get();
}
operator bool() const {
return repr_.isNonNull();
return repr_ != nullptr;
}
bool operator==(std::nullptr_t) const {
return !repr_.isNonNull();
return repr_ == nullptr;
}
bool operator!=(std::nullptr_t) const {
return repr_.isNonNull();
return repr_ != nullptr;
}
template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
@ -255,138 +255,14 @@ struct TORCH_API Type {
}
private:
// NOTE: SharedPtrWrapper exists to work around a baffling bug in
// nvcc; see comment in destroy() below.
struct SharedPtrWrapper {
SharedPtrWrapper(std::shared_ptr<T> &&x)
: repr_(std::move(x)) {}
std::shared_ptr<T> repr_;
};
union Repr {
Repr() : Repr(nullptr) {}
// Use shared_ptr's aliasing constructor to create a non-owning pointer
// to a singleton. The lifetime is tied to the null shared_ptr, so there's
// no reference counting overhead for the singleton itself.
static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) {
return std::shared_ptr<T>(std::shared_ptr<T>(), ptr);
}
explicit Repr(std::shared_ptr<T> x)
: shared_(std::move(x)) {}
explicit Repr(std::nullptr_t)
: singletonRepr_(nullptr) {}
explicit Repr(SingletonTypePtr<T> p)
: singletonRepr_(p.get()) {}
~Repr() {
destroy();
}
// NOTE: the only non-UB way to access our null state is through
// rawRepr(), because our copy operation doesn't preserve which
// union member is active for null pointers.
Repr(const Repr& rhs) {
if (rhs.isSharedAndNonNull()) {
new (&shared_) SharedPtrWrapper(rhs.shared_);
} else {
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
singletonRepr_.unused_ = nullptr;
}
}
Repr(Repr&& rhs) noexcept {
if (rhs.isSharedAndNonNull()) {
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
} else {
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
singletonRepr_.unused_ = nullptr;
}
}
Repr& operator=(const Repr& rhs) {
if (&rhs == this) {
return *this;
}
if (rhs.isSharedAndNonNull()) {
if (isSharedAndNonNull()) {
shared_ = rhs.shared_;
} else {
new (&shared_) SharedPtrWrapper(rhs.shared_);
}
} else {
if (isSharedAndNonNull()) {
destroy();
}
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
singletonRepr_.unused_ = nullptr;
}
return *this;
}
Repr& operator=(Repr&& rhs) noexcept {
if (&rhs == this) {
return *this;
}
if (rhs.isSharedAndNonNull()) {
if (isSharedAndNonNull()) {
shared_ = std::move(rhs.shared_);
} else {
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
}
} else {
if (isSharedAndNonNull()) {
destroy();
}
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
singletonRepr_.unused_ = nullptr;
}
return *this;
}
SharedPtrWrapper shared_;
struct SingletonRepr {
explicit SingletonRepr(T* s) : singleton_(s) {}
T* singleton_;
void* unused_ = nullptr;
} singletonRepr_;
struct RawRepr {
void* first;
void* nullIfSingleton_;
};
// It is UB to read the singleton part of Repr if it was
// constructed as a shared_ptr and vice versa, but memcpying out
// the representation is always OK, so here's an accessor to obey
// the letter of the law.
RawRepr rawRepr() const {
RawRepr repr{};
memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
return repr;
}
bool isNonNull() const {
auto repr = rawRepr();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
return repr.first != nullptr;
}
bool isSharedAndNonNull() const {
return rawRepr().nullIfSingleton_ != nullptr;
}
private:
void destroy() {
if (isSharedAndNonNull()) {
// Without SharedPtrWrapper, this line would read
// `shared_.~shared_ptr()` and nvcc would complain with
// "error: expected primary-expression before '>' token"
// referring to the "t" in "shared_ptr". SharedPtrWrapper
// exists to work around this compiler bug.
shared_.~SharedPtrWrapper();
}
}
} repr_;
std::shared_ptr<T> repr_;
};
using TypePtr = SingletonOrSharedTypePtr<Type>;

View File

@ -8,6 +8,7 @@
#include <ATen/core/jit_type.h>
#include <c10/macros/Macros.h>
#include <c10/util/env.h>
#include <c10/util/Exception.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <array>
@ -826,9 +827,7 @@ TupleType::TupleType(
: NamedType(TypeKind::TupleType, std::move(name)),
elements_(std::move(elements)),
has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
if (!v) {
throw std::runtime_error("Can not create tuple with None type");
}
TORCH_CHECK(v, "Can not create tuple with None type");
return v->hasFreeVariables();
})), schema_(std::move(schema)) {

View File

@ -104,71 +104,6 @@ class Vectorized<float> {
}
return b;
}
// Implementation is picked from
// https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
const auto c1 =
svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
const auto c2 =
svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
const auto c3 =
svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
const auto c4 =
svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
const auto c5 =
svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
const auto shift = svreinterpret_f32_u32(
svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
const auto inv_ln2 = svreinterpret_f32_u32(
svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
const auto zero = svdup_n_f32(0.f);
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
// Range reduction:
// e^x = 2^n * e^r
// where:
// n = floor(x / ln(2))
// r = x - n * ln(2)
//
// By adding x / ln(2) with 2^23 + 127 (shift):
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
// forces decimal part
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e.
// n) + 127 will occupy the whole fraction part of z in FP32 format.
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
// of x / ln(2) (i.e. n) because the decimal part has been pushed out
// and lost.
// * The addition of 127 makes the FP32 fraction part of z ready to be
// used as the exponent
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
const auto n = svsub_f32_z(pg, z, shift);
const auto scale = svreinterpret_f32_u32(
svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in
// term of accuracy and performance.
const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
// Compute the truncated Taylor series of e^r.
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
const auto r2 = svmul_f32_z(pg, r, r);
const auto p1 = svmul_f32_z(pg, c1, r);
const auto p23 = svmla_f32_z(pg, c2, c3, r);
const auto p45 = svmla_f32_z(pg, c4, c5, r);
const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
auto poly = svmla_f32_z(pg, scale, p12345, scale);
// Handle underflow and overflow.
poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
return poly;
}
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
@ -313,11 +248,41 @@ class Vectorized<float> {
return USE_SLEEF(
Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
}
// Implementation copied from Arm Optimized Routines:
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
Vectorized<float> exp_u20() const {
return exp();
// special case to handle special inputs that are too large or too small
// i.e. where there's at least one element x, s.t. |x| >= 87.3...
svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f);
if (svptest_any(svptrue_b32(), is_special_case)) {
return exp();
}
const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
const svfloat32_t c1 = svdup_n_f32(0.5f);
const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
const float shift = 0x1.803f8p17f;
/* n = round(x/(ln2/N)). */
svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift);
svfloat32_t n = svsub_x(svptrue_b32(), z, shift);
/* r = x - n*ln2/N. */
svfloat32_t r = values;
r = svmls_x(svptrue_b32(), r, n, ln2_hi);
r = svmls_x(svptrue_b32(), r, n, ln2_lo);
/* scale = 2^(n/N). */
svfloat32_t scale = svexpa(svreinterpret_u32(z));
/* poly(r) = exp(r) - 1 ~= r + 0.5 r^2. */
svfloat32_t r2 = svmul_x(svptrue_b32(), r, r);
svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
return svmla_x(svptrue_b32(), scale, scale, poly);
}
Vectorized<float> fexp_u20() const {
return exp();
return exp_u20();
}
Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
@ -453,9 +418,11 @@ class Vectorized<float> {
ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
// Step 2: Calculate exp(2 * x), where x is the clamped value.
// svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
// the result.
svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
// svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of
// the result (via Vectorized<float>, then auto-converts back to
// svfloat32_t).
svfloat32_t exp2x =
Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20();
// Step 3: Calculate the numerator of the tanh function, which is exp(2x)
// - 1.

View File

@ -6,9 +6,11 @@
#ifdef __aarch64__
#if !defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
#endif
#include <ATen/cpu/vec/vec128/vec128_convert.h>

View File

@ -354,9 +354,47 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
Vectorized frac() const;
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
#ifdef __ARM_FEATURE_BF16
Vectorized<c10::BFloat16> neg() const {
return -values;
}
Vectorized<c10::BFloat16> reciprocal() const {
return 1.0f / values;
}
Vectorized<c10::BFloat16> operator==(
const Vectorized<c10::BFloat16>& other) const {
return values == other.values;
}
Vectorized<c10::BFloat16> operator!=(
const Vectorized<c10::BFloat16>& other) const {
return values != other.values;
}
Vectorized<c10::BFloat16> operator<(
const Vectorized<c10::BFloat16>& other) const {
return values < other.values;
}
Vectorized<c10::BFloat16> operator<=(
const Vectorized<c10::BFloat16>& other) const {
return values <= other.values;
}
Vectorized<c10::BFloat16> operator>(
const Vectorized<c10::BFloat16>& other) const {
return values > other.values;
}
Vectorized<c10::BFloat16> operator>=(
const Vectorized<c10::BFloat16>& other) const {
return values >= other.values;
}
#else
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
@ -364,6 +402,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
#endif
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
@ -412,28 +451,52 @@ template <>
Vectorized<c10::BFloat16> inline operator+(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x + y;
#else
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator-(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x - y;
#else
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator*(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x * y;
#else
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator/(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x / y;
#else
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
#endif
}
// frac. Implement this here so we can use subtraction
@ -544,12 +607,19 @@ Vectorized<c10::BFloat16> inline fmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y + z;
#else
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
// elements, not the bottom and top half, so they don't seem
// particularly useful here. Ideally we would include dot product in
// the Vectorized interface...
return a * b + c;
#endif
}
template <>
@ -557,8 +627,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y + z;
#else
// See NOTE [BF16 FMA] above.
return -a * b + c;
#endif
}
template <>
@ -566,8 +643,15 @@ Vectorized<c10::BFloat16> inline fmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y - z;
#else
// See NOTE [BF16 FMA] above.
return a * b - c;
#endif
}
template <>
@ -575,8 +659,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y - z;
#else
// See NOTE [BF16 FMA] above.
return -a * b - c;
#endif
}
#endif // !defined(C10_MOBILE) && defined(__aarch64__)

View File

@ -5,6 +5,114 @@
namespace at::vec {
inline namespace CPU_CAPABILITY {
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
// Enable auto-vectorization for GCC-13+ and clang-17+
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
template <typename from_type, typename to_type>
inline void convertImpl(
const from_type* __restrict src,
to_type* __restrict dst,
int64_t n) {
uint64_t len = static_cast<uint64_t>(n);
for (uint64_t i = 0; i < len; i++) {
dst[i] = static_cast<to_type>(src[i]);
}
}
#define CONVERT_TEMPLATE(from_type, to_type) \
template <> \
inline void convert(const from_type* src, to_type* dst, int64_t n) { \
return convertImpl<from_type, to_type>(src, dst, n); \
}
CONVERT_TEMPLATE(uint8_t, uint8_t)
CONVERT_TEMPLATE(uint8_t, int8_t)
CONVERT_TEMPLATE(uint8_t, int16_t)
CONVERT_TEMPLATE(uint8_t, int32_t)
CONVERT_TEMPLATE(uint8_t, int64_t)
CONVERT_TEMPLATE(uint8_t, float)
CONVERT_TEMPLATE(uint8_t, double)
CONVERT_TEMPLATE(int8_t, uint8_t)
CONVERT_TEMPLATE(int8_t, int8_t)
CONVERT_TEMPLATE(int8_t, int16_t)
CONVERT_TEMPLATE(int8_t, int32_t)
CONVERT_TEMPLATE(int8_t, int64_t)
CONVERT_TEMPLATE(int8_t, float)
CONVERT_TEMPLATE(int8_t, double)
CONVERT_TEMPLATE(int16_t, uint8_t)
CONVERT_TEMPLATE(int16_t, int8_t)
CONVERT_TEMPLATE(int16_t, int16_t)
CONVERT_TEMPLATE(int16_t, int32_t)
CONVERT_TEMPLATE(int16_t, int64_t)
CONVERT_TEMPLATE(int16_t, float)
CONVERT_TEMPLATE(int16_t, double)
CONVERT_TEMPLATE(int32_t, uint8_t)
CONVERT_TEMPLATE(int32_t, int8_t)
CONVERT_TEMPLATE(int32_t, int16_t)
CONVERT_TEMPLATE(int32_t, int32_t)
CONVERT_TEMPLATE(int32_t, int64_t)
CONVERT_TEMPLATE(int32_t, float)
CONVERT_TEMPLATE(int32_t, double)
CONVERT_TEMPLATE(int64_t, uint8_t)
CONVERT_TEMPLATE(int64_t, int8_t)
CONVERT_TEMPLATE(int64_t, int16_t)
CONVERT_TEMPLATE(int64_t, int32_t)
CONVERT_TEMPLATE(int64_t, int64_t)
CONVERT_TEMPLATE(int64_t, float)
CONVERT_TEMPLATE(int64_t, double)
CONVERT_TEMPLATE(float, uint8_t)
CONVERT_TEMPLATE(float, int8_t)
CONVERT_TEMPLATE(float, int16_t)
CONVERT_TEMPLATE(float, int32_t)
CONVERT_TEMPLATE(float, int64_t)
CONVERT_TEMPLATE(float, float)
CONVERT_TEMPLATE(float, double)
CONVERT_TEMPLATE(double, uint8_t)
CONVERT_TEMPLATE(double, int8_t)
CONVERT_TEMPLATE(double, int16_t)
CONVERT_TEMPLATE(double, int32_t)
CONVERT_TEMPLATE(double, int64_t)
CONVERT_TEMPLATE(double, float)
CONVERT_TEMPLATE(double, double)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
CONVERT_TEMPLATE(float16_t, uint8_t)
CONVERT_TEMPLATE(float16_t, int8_t)
CONVERT_TEMPLATE(float16_t, int16_t)
CONVERT_TEMPLATE(float16_t, int32_t)
CONVERT_TEMPLATE(float16_t, int64_t)
CONVERT_TEMPLATE(float16_t, float16_t)
CONVERT_TEMPLATE(float16_t, float)
CONVERT_TEMPLATE(float16_t, double)
CONVERT_TEMPLATE(uint8_t, float16_t)
CONVERT_TEMPLATE(int8_t, float16_t)
CONVERT_TEMPLATE(int16_t, float16_t)
CONVERT_TEMPLATE(int32_t, float16_t)
CONVERT_TEMPLATE(int64_t, float16_t)
CONVERT_TEMPLATE(float, float16_t)
CONVERT_TEMPLATE(double, float16_t)
#endif
#ifdef __ARM_FEATURE_BF16
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
CONVERT_TEMPLATE(bfloat16_t, int8_t)
CONVERT_TEMPLATE(bfloat16_t, int16_t)
CONVERT_TEMPLATE(bfloat16_t, int32_t)
CONVERT_TEMPLATE(bfloat16_t, int64_t)
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
CONVERT_TEMPLATE(bfloat16_t, float)
CONVERT_TEMPLATE(bfloat16_t, double)
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
CONVERT_TEMPLATE(int8_t, bfloat16_t)
CONVERT_TEMPLATE(int16_t, bfloat16_t)
CONVERT_TEMPLATE(int32_t, bfloat16_t)
CONVERT_TEMPLATE(int64_t, bfloat16_t)
CONVERT_TEMPLATE(float, bfloat16_t)
CONVERT_TEMPLATE(double, bfloat16_t)
#endif
#endif
template <typename src_t>
struct VecConvert<
float,

View File

@ -0,0 +1,586 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <cmath>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
template <>
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
template <>
class Vectorized<double> {
private:
float64x2_t values;
public:
using value_type = double;
using size_type = int;
static constexpr size_type size() {
return 2;
}
Vectorized() {
values = vdupq_n_f64(0.0);
}
Vectorized(float64x2_t v) : values(v) {}
Vectorized(double val) {
values = vdupq_n_f64(val);
}
template <
typename... Args,
typename = std::enable_if_t<(sizeof...(Args) == size())>>
Vectorized(Args... vals) {
__at_align__ double buffer[size()] = {vals...};
values = vld1q_f64(buffer);
}
operator float64x2_t() const {
return values;
}
template <int64_t mask>
static Vectorized<double> blend(
const Vectorized<double>& a,
const Vectorized<double>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding
// bit in 'mask' is set, 0 otherwise.
uint64x2_t maskArray = {
(mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0,
(mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_f64(maskArray, b.values, a.values);
}
static Vectorized<double> blendv(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& mask_) {
return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values);
}
template <typename step_t>
static Vectorized<double> arange(
double base = 0.,
step_t step = static_cast<step_t>(1)) {
return {base, base + static_cast<double>(step)};
}
static inline Vectorized<double> set(
const Vectorized<double>& a,
const Vectorized<double>& b,
int64_t count = size()) {
if (count == 0) {
return a;
} else if (count >= 2) {
return b;
} else {
float64x2_t c = {b.values[0], a.values[1]};
return c;
}
}
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
if (count == size()) {
return vld1q_f64(reinterpret_cast<const double*>(ptr));
} else if (count == 1) {
float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr));
float64x1_t z = {0.0};
return vcombine_f64(x, z);
} else {
return vdupq_n_f64(0.0);
}
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
vst1q_f64(reinterpret_cast<double*>(ptr), values);
} else if (count == 1) {
vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values));
}
}
const double& operator[](int idx) const = delete;
double& operator[](int idx) = delete;
int64_t zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit
// and others are translated to 0-bit
uint64x2_t cmpReg = vceqzq_f64(values);
uint64x2_t mask = {1, 2};
uint64x2_t res = vandq_u64(cmpReg, mask);
return res[0] | res[1];
}
Vectorized<double> isnan() const {
// NaN check
return vreinterpretq_f64_u32(
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values))));
}
bool has_inf_nan() const {
Vectorized<double> x = vsubq_f64(values, values);
float64x2_t r = x.isnan();
uint64x2_t u = vreinterpretq_u64_f64(r);
return u[0] | u[1];
}
Vectorized<double> map(double (*f)(double)) const {
float64x2_t result;
result[0] = f(values[0]);
result[1] = f(values[1]);
return result;
}
Vectorized<double> map2(
const Vectorized<double>& second,
double (*const f)(double, double)) const {
float64x2_t result;
result[0] = f(values[0], second.values[0]);
result[1] = f(values[1], second.values[1]);
return result;
}
Vectorized<double> abs() const {
return vabsq_f64(values);
}
Vectorized<double> angle() const {
auto zero = Vectorized<double>(0.0);
auto pi = Vectorized<double>(c10::pi<double>);
auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values)));
return blendv(tmp, *this, isnan());
}
Vectorized<double> real() const {
return *this;
}
Vectorized<double> imag() const {
return Vectorized<double>(0.0);
}
Vectorized<double> conj() const {
return *this;
}
Vectorized<double> acos() const {
return USE_SLEEF(
Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos));
}
Vectorized<double> acosh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh));
}
Vectorized<double> asin() const {
return USE_SLEEF(
Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin));
}
Vectorized<double> asinh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh));
}
Vectorized<double> atan() const {
return USE_SLEEF(
Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan));
}
Vectorized<double> atanh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh));
}
Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF(
{ return Vectorized<double>(Sleef_atan2d2_u10(values, b)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<double> copysign(const Vectorized<double>& sign) const {
USE_SLEEF(
{ return Vectorized<double>(Sleef_copysignd2(values, sign)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_sign[size()];
store(tmp);
sign.store(tmp_sign);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
}
return loadu(tmp);
})} Vectorized<double> erf() const {
return USE_SLEEF(
Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf));
}
Vectorized<double> erfc() const {
return USE_SLEEF(
Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc));
}
Vectorized<double> exp() const {
return USE_SLEEF(
Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp));
}
Vectorized<double> exp2() const {
return USE_SLEEF(
Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2));
}
Vectorized<double> expm1() const {
return USE_SLEEF(
Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1));
}
Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
{ return Vectorized<double>(Sleef_fmodd2(values, q)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_q[size()];
store(tmp);
q.store(tmp_q);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
}
return loadu(tmp);
})} Vectorized<double> hypot(const Vectorized<double>& b) const {
USE_SLEEF(
{ return Vectorized<double>(Sleef_hypotd2_u05(values, b)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<double> i0() const {
return map(calc_i0);
}
Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF(
{ return Vectorized<double>(Sleef_nextafterd2(values, b)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<double> log() const {
return USE_SLEEF(
Vectorized<double>(Sleef_logd2_u10(values)), map(std::log));
}
Vectorized<double> log2() const {
return USE_SLEEF(
Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2));
}
Vectorized<double> log10() const {
return USE_SLEEF(
Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10));
}
Vectorized<double> log1p() const {
return USE_SLEEF(
Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p));
}
Vectorized<double> frac() const;
Vectorized<double> sin() const {
return USE_SLEEF(
Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin));
}
Vectorized<double> sinh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh));
}
Vectorized<double> cos() const {
return USE_SLEEF(
Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos));
}
Vectorized<double> cosh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh));
}
Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF(
{ return Vectorized<double>(Sleef_powd2_u10(values, b)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::pow(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} // Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<double> tan() const {
return USE_SLEEF(
Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan));
}
Vectorized<double> tanh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh));
}
Vectorized<double> lgamma() const {
return USE_SLEEF(
Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma));
}
Vectorized<double> erfinv() const {
return map(calc_erfinv);
}
Vectorized<double> exp_u20() const {
return exp();
}
Vectorized<double> fexp_u20() const {
return exp();
}
Vectorized<double> i0e() const {
return map(calc_i0e);
}
Vectorized<double> digamma() const {
return map(calc_digamma);
}
Vectorized<double> igamma(const Vectorized<double>& x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> igammac(const Vectorized<double>& x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> ceil() const {
return vrndpq_f64(values);
}
Vectorized<double> floor() const {
return vrndmq_f64(values);
}
Vectorized<double> neg() const {
return vnegq_f64(values);
}
Vectorized<double> round() const {
return vrndiq_f64(values);
}
Vectorized<double> trunc() const {
return vrndq_f64(values);
}
Vectorized<double> sqrt() const {
return vsqrtq_f64(values);
}
Vectorized<double> reciprocal() const {
return vdivq_f64(vdupq_n_f64(1.0), values);
}
Vectorized<double> rsqrt() const {
return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values));
}
double reduce_add() const {
return vaddvq_f64(values);
}
double reduce_max() const {
return vmaxvq_f64(values);
}
Vectorized<double> operator==(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vceqq_f64(values, other.values)));
}
Vectorized<double> operator!=(const Vectorized<double>& other) const {
float64x2_t r0 = vreinterpretq_f64_u32(
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values))));
return Vectorized<double>(r0);
}
Vectorized<double> operator<(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vcltq_f64(values, other.values)));
}
Vectorized<double> operator<=(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vcleq_f64(values, other.values)));
}
Vectorized<double> operator>(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vcgtq_f64(values, other.values)));
}
Vectorized<double> operator>=(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vcgeq_f64(values, other.values)));
}
Vectorized<double> eq(const Vectorized<double>& other) const;
Vectorized<double> ne(const Vectorized<double>& other) const;
Vectorized<double> gt(const Vectorized<double>& other) const;
Vectorized<double> ge(const Vectorized<double>& other) const;
Vectorized<double> lt(const Vectorized<double>& other) const;
Vectorized<double> le(const Vectorized<double>& other) const;
};
template <>
Vectorized<double> inline operator+(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vaddq_f64(a, b);
}
template <>
Vectorized<double> inline operator-(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vsubq_f64(a, b);
}
template <>
Vectorized<double> inline operator*(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vmulq_f64(a, b);
}
template <>
Vectorized<double> inline operator/(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vdivq_f64(a, b);
}
// frac. Implement this here so we can use subtraction
Vectorized<double> inline Vectorized<double>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vmaxq_f64(a, b);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vminq_f64(a, b);
}
template <>
Vectorized<double> inline clamp(
const Vectorized<double>& a,
const Vectorized<double>& min,
const Vectorized<double>& max) {
return vminq_f64(max, vmaxq_f64(min, a));
}
template <>
Vectorized<double> inline clamp_max(
const Vectorized<double>& a,
const Vectorized<double>& max) {
return vminq_f64(max, a);
}
template <>
Vectorized<double> inline clamp_min(
const Vectorized<double>& a,
const Vectorized<double>& min) {
return vmaxq_f64(min, a);
}
template <>
Vectorized<double> inline operator&(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vreinterpretq_f64_u64(
vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
}
template <>
Vectorized<double> inline operator|(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vreinterpretq_f64_u64(
vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
}
template <>
Vectorized<double> inline operator^(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vreinterpretq_f64_u64(
veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
}
inline Vectorized<double> Vectorized<double>::eq(
const Vectorized<double>& other) const {
return (*this == other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::ne(
const Vectorized<double>& other) const {
return (*this != other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::gt(
const Vectorized<double>& other) const {
return (*this > other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::ge(
const Vectorized<double>& other) const {
return (*this >= other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::lt(
const Vectorized<double>& other) const {
return (*this < other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::le(
const Vectorized<double>& other) const {
return (*this <= other) & Vectorized<double>(1.0);
}
template <>
Vectorized<double> inline fmadd(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& c) {
return vfmaq_f64(c, a, b);
}
template <>
Vectorized<double> inline fnmadd(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& c) {
return vfmsq_f64(c, a, b);
}
template <>
Vectorized<double> inline fmsub(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& c) {
return vfmaq_f64(vnegq_f64(c), a, b);
}
template <>
Vectorized<double> inline fnmsub(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& c) {
return vfmsq_f64(vnegq_f64(c), a, b);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -307,11 +307,49 @@ class Vectorized<float> {
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
// Implementation copied from Arm Optimized Routine
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
Vectorized<float> exp_u20() const {
return exp();
// bail out to sleef if it's a special case:
// i.e. there's an input s.t. |input| > 87.3....
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
uint32x4_t cmp = vcagtq_f32(values, special_bound);
if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) {
return exp();
}
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
const float ln2_hi = 0x1.62e4p-1f;
const float ln2_lo = 0x1.7f7d1cp-20f;
const float c0 = 0x1.0e4020p-7f;
const float c2 = 0x1.555e66p-3f;
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
/* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
x = ln2*n + r, with r in [-ln2/2, ln2/2]. */
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0);
r = vfmsq_laneq_f32(r, n, ln2_c02, 1);
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias));
float32x4_t r2 = vmulq_f32(r, r);
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2);
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3);
q = vfmaq_f32(q, p, r2);
p = vmulq_f32(c4, r);
float32x4_t poly = vfmaq_f32(p, q, r2);
return vfmaq_f32(scale, poly, scale);
}
Vectorized<float> fexp_u20() const {
return exp();
return exp_u20();
}
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
fmod,
@ -540,42 +578,6 @@ inline Vectorized<float> Vectorized<float>::le(
return (*this <= other) & Vectorized<float>(1.0f);
}
template <>
inline void convert(const float* src, int32_t* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<float>::size());
i += Vectorized<float>::size()) {
vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<int32_t>(src[i]);
}
}
template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<float>::size());
i += Vectorized<float>::size()) {
vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<float>(src[i]);
}
}
template <>
Vectorized<float> inline fmadd(
const Vectorized<float>& a,

View File

@ -569,46 +569,6 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
return (*this <= other) & Vectorized<c10::Half>(1);
}
// These are global functions, so the defaults in vec_base.h should
// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available.
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
i += Vectorized<c10::Half>::size()) {
vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<int16_t>(src[i]);
}
}
template <>
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
i += Vectorized<c10::Half>::size()) {
vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<float16_t>(src[i]);
}
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
Vectorized<c10::Half> inline fmadd(
const Vectorized<c10::Half>& a,

View File

@ -0,0 +1,378 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
template <> \
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
\
template <> \
class Vectorized<uint##bit##_t> { \
using neon_type = uint##bit##x##vl##_t; \
\
private: \
neon_type values; \
\
public: \
using value_type = uint##bit##_t; \
using size_type = int; \
static constexpr size_type size() { \
return vl; \
} \
Vectorized() { \
values = vdupq_n_u##bit(0); \
} \
Vectorized(neon_type v) : values(v) {} \
Vectorized(uint##bit##_t val); \
template < \
typename... Args, \
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
Vectorized(Args... vals) { \
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
values = vld1q_u##bit(buffer); \
} \
operator neon_type() const { \
return values; \
} \
static Vectorized<uint##bit##_t> loadu( \
const void* ptr, \
uint64_t count = size()); \
void store(void* ptr, uint64_t count = size()) const; \
template <uint64_t mask> \
static Vectorized<uint##bit##_t> blend( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b); \
static Vectorized<uint##bit##_t> blendv( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
const Vectorized<uint##bit##_t>& mask_) { \
return vbslq_u##bit(mask_.values, b, a); \
} \
template <typename step_t> \
static Vectorized<uint##bit##_t> arange( \
value_type base = 0, \
step_t step = static_cast<step_t>(1)); \
static Vectorized<uint##bit##_t> set( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
uint64_t count = size()); \
const uint##bit##_t& operator[](uint idx) const = delete; \
uint##bit##_t& operator[](uint idx) = delete; \
Vectorized<uint##bit##_t> abs() const { \
return values; \
} \
Vectorized<uint##bit##_t> real() const { \
return values; \
} \
Vectorized<uint##bit##_t> imag() const { \
return vdupq_n_u##bit(0); \
} \
Vectorized<uint##bit##_t> conj() const { \
return values; \
} \
Vectorized<uint##bit##_t> neg() const { \
return vreinterpretq_u##bit##_s##bit( \
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
} \
uint##bit##_t reduce_add() const { \
return vaddvq_u##bit(values); \
} \
uint##bit##_t reduce_max() const; \
Vectorized<uint##bit##_t> operator==( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator!=( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> operator<( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator<=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> eq( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ne( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> gt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ge( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> lt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> le( \
const Vectorized<uint##bit##_t>& other) const; \
}; \
template <> \
Vectorized<uint##bit##_t> inline operator+( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vaddq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator-( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vsubq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator&( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vandq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator|( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vorrq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator^( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return veorq_u##bit(a, b); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this == other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this != other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this > other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this < other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
}
VEC_UINT_NEON_TEMPLATE(16, 8)
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
return vmaxvq_u8(values);
}
template <>
Vectorized<uint8_t> inline operator*(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmulq_u8(a, b);
}
template <>
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
return vmvnq_u8(a);
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
const Vectorized<uint8_t>& other) const {
return ~(*this == other);
}
template <>
Vectorized<uint8_t> inline minimum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vminq_u8(a, b);
}
template <>
Vectorized<uint8_t> inline maximum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmaxq_u8(a, b);
}
template <uint64_t mask>
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
(mask & 1LL) ? 0xFF : 0,
(mask & 2LL) ? 0xFF : 0,
(mask & 4LL) ? 0xFF : 0,
(mask & 8LL) ? 0xFF : 0,
(mask & 16LL) ? 0xFF : 0,
(mask & 32LL) ? 0xFF : 0,
(mask & 64LL) ? 0xFF : 0,
(mask & 128LL) ? 0xFF : 0,
(mask & 256LL) ? 0xFF : 0,
(mask & 512LL) ? 0xFF : 0,
(mask & 1024LL) ? 0xFF : 0,
(mask & 2048LL) ? 0xFF : 0,
(mask & 4096LL) ? 0xFF : 0,
(mask & 8192LL) ? 0xFF : 0,
(mask & 16384LL) ? 0xFF : 0,
(mask & 32768LL) ? 0xFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
#define VEC_UINT_NEON_OPS(vl, bit) \
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
values = vdupq_n_u##bit(val); \
} \
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
const void* ptr, uint64_t count) { \
if (count == size()) { \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
} else { \
__at_align__ uint##bit##_t tmp_values[size()]; \
for (const auto i : c10::irange(size())) { \
tmp_values[i] = 0; \
} \
std::memcpy( \
tmp_values, \
reinterpret_cast<const uint##bit##_t*>(ptr), \
count * sizeof(uint##bit##_t)); \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
} \
} \
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
const { \
if (count == size()) { \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
} else { \
uint##bit##_t tmp_values[size()]; \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
} \
}
VEC_UINT_NEON_OPS(16, 8)
template <typename step_t>
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
uint8_t base,
step_t step) {
const Vectorized<uint8_t> base_vec(base);
const Vectorized<uint8_t> step_vec(step);
const uint8x16_t step_sizes = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
return vmlaq_u8(base_vec, step_sizes, step_vec);
}
template <>
Vectorized<uint8_t> inline operator>>(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return x >> z;
}
template <>
Vectorized<uint8_t> inline operator<<(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return vshlq_u8(a, vreinterpretq_s8_u8(z));
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b,
uint64_t count) {
if (count == 0) {
return a;
} else if (count >= 16) {
return b;
} else {
// Build an array of flags: each bit of element is 1 if the corresponding
// bit in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
}
template <>
Vectorized<uint8_t> inline operator/(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t y = b;
return x / y;
}
template <>
Vectorized<uint8_t> inline clamp(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min,
const Vectorized<uint8_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<uint8_t> inline clamp_max(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<uint8_t> inline clamp_min(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min) {
return maximum(min, a);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u8x8 = vget_low_u8(src);
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
Vectorized<float> inline convert_int8_half_register_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u8x8 = vget_low_u8(src);
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));

View File

@ -0,0 +1,192 @@
#include <ATen/cuda/CUDAGreenContext.h>
namespace at::cuda {
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0);
}
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
std::unique_ptr<GreenContext> GreenContext::create(
uint32_t num_sms,
std::optional<uint32_t> device_id) {
#if CUDA_HAS_GREEN_CONTEXT
if (!device_id.has_value()) {
device_id = at::cuda::current_device();
}
return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Implement move operations
GreenContext::GreenContext(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
if (this != &other) {
// Clean up current resources
if (green_ctx_) {
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (current == context_) {
TORCH_CHECK(
false,
"attempting to overwrite current green ctx "
"when it is active!");
}
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
}
// Take ownership of other's resources
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
}
return *this;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
GreenContext::~GreenContext() noexcept{
#if CUDA_HAS_GREEN_CONTEXT
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying CUDA context
CUcontext GreenContext::getContext() const {
#if CUDA_HAS_GREEN_CONTEXT
return context_;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
return green_ctx_;
}
#endif
// Make this context current
void GreenContext::setContext() {
#if CUDA_HAS_GREEN_CONTEXT
auto current_stream = c10::cuda::getCurrentCUDAStream();
parent_stream_ = current_stream.stream();
at::cuda::CUDAEvent ev;
ev.record(current_stream);
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (!current) {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
} else {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
}
// currently hardcodes the new green context to use the default stream
// TODO(eqy): consider creating a new stream if e.g., it allows interop
// with CUDA Graph captures etc.
auto default_stream = c10::cuda::getDefaultCUDAStream();
ev.block(default_stream);
c10::cuda::setCurrentCUDAStream(default_stream);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
void GreenContext::popContext() {
#if CUDA_HAS_GREEN_CONTEXT
// see above note about stream being hardcoded to the default stream
at::cuda::CUDAEvent ev;
ev.record(c10::cuda::getCurrentCUDAStream());
CUcontext popped;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
TORCH_INTERNAL_ASSERT(
popped == context_, "expected popped context to be the current ctx");
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
} // namespace at::cuda

View File

@ -0,0 +1,53 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif
namespace at::cuda {
class TORCH_CUDA_CPP_API GreenContext {
public:
GreenContext(uint32_t device_id, uint32_t num_sms);
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
// Delete copy constructor and assignment
GreenContext(const GreenContext&) = delete;
GreenContext& operator=(const GreenContext&) = delete;
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
~GreenContext() noexcept;
// Get the underlying CUDA context
CUcontext getContext() const;
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx getGreenContext() const;
#endif
// Make this context current
void setContext();
void popContext();
private:
#if CUDA_HAS_GREEN_CONTEXT
int32_t device_id_ = -1;
CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
#endif
};
} // namespace at::cuda

View File

@ -0,0 +1,270 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace at::cuda::scaled {
/**
* Both inputs must be fp8,
* Each needs a single scale, {Tensorwise (float)}
*/
bool check_tensorwise_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {Tensorwise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::TensorWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::TensorWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Both inputs must be fp8,
* Each needs scales, {Rowwise (float)}
*/
bool check_rowwise_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {Tensorwise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {RowWise, dp32} for A & B
if (recipe_a[0] != ScalingType::RowWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::RowWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Two-level scaling, canonical NVFP4
* Both inputs must be fp4
* A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
*/
bool check_nvfp4_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 2 scales, 2 recipes for each input
if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
return false;
}
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Single-level scaling, what PyT currently understands
* Both inputs must be fp4
* A, B need 1 scale, {Blockwise_1x16 (e4m3)}
*/
bool check_nvfp4_recipe_single_scale
(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 2 scales, 2 recipes for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x16, e4m3 for scale[0], Tensorwise, fp32 for scale[1]}
if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
return true;
}
/**
* Both inputs must be fp8
* A, B must only have 1 scale each, A: {Blockwise_1x128 (float), B: {Blockwise_128x128 (float)
*/
bool check_deepseek_recipe(ScalingType expected_recipe_a,
ScalingType expected_recipe_b,
c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
return false;
}
// 1 scales, 1 recipes for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x128, float} for A, {Blockwise_128x128, float} for B
if (recipe_a[0] != expected_recipe_a) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != expected_recipe_b) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Both inputs must be fp8
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
*/
bool check_mxfp8_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
return false;
}
// 1 scales, 1 recipes for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
return true;
}
/**
* Both inputs must be fp4
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
*/
bool check_mxfp4_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 1 scales, 1 recipes for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
return true;
}
} // namespace at::native::cuda::blas::scaled

View File

@ -0,0 +1,174 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace at::cuda::scaled {
static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
"gfx942",
#if ROCM_VERSION >= 60300
"gfx1200", "gfx1201",
#endif
#if ROCM_VERSION >= 60500
"gfx950"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs);
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (sm90_only || sm100_only) {
return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
} else {
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
}
#endif
}
#ifdef USE_ROCM
static bool _scaled_mm_is_fnuz() {
return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
}
#endif
/**
* Track concrete implementations available
*/
enum class ScaledGemmImplementation {
NONE = 0,
TENSORWISE_TENSORWISE = 1,
ROWWISE_ROWWISE = 2,
BLOCK_128x128_1x128 = 3,
BLOCK_1x128_128x128 = 4,
BLOCK_1x128_1x128 = 5,
MXFP8_MXFP8 = 6,
NVFP4_NVFP4 = 7,
NVFP4_NVFP4_SINGLE_SCALE = 8,
MXFP4_MXFP4 = 9,
};
/**
* Convert passed int (enum) from python back into a
* strictly-typed enum
*/
template <class EnumType, class ArrayType>
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
std::vector<EnumType> converted;
converted.reserve(v.size());
for (auto vi : v) {
converted.push_back(static_cast<EnumType>(vi));
}
return converted;
}
bool check_tensorwise_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_rowwise_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_nvfp4_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_nvfp4_recipe_single_scale
(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_deepseek_recipe(ScalingType,
ScalingType,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_mxfp8_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_mxfp4_recipe(c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
} // namespace at::native::cuda::blas::scaled

View File

@ -70,11 +70,7 @@
#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
#endif
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
#if !defined(USE_ROCM)
namespace at_cuda_detail {
#endif
#if defined(USE_ROCM)
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
@ -96,10 +92,6 @@ template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
#if !defined(USE_ROCM)
} // namespace at_cuda_detail
#endif
#endif
#if !defined(USE_ROCM)
@ -121,7 +113,7 @@ struct cuda_type<c10::Half> {
using type = __half;
};
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
#if !defined(USE_ROCM)
template<>
struct cuda_type<c10::BFloat16> {
@ -203,36 +195,6 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera
*out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}
#if !CUB_SUPPORTS_FUTURE_VALUE()
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = ValueT;
using pointer = ValueT*;
using reference = ValueT&;
InputIteratorT iter;
ValueT *first;
difference_type offset = 0;
__device__ ValueT operator[](difference_type i) {
i += offset;
if (i == 0) {
return *first;
} else {
return ValueT(iter[i - 1]);
}
}
__device__ chained_iterator operator+(difference_type i) {
return chained_iterator{iter, first, i};
}
__device__ ValueT operator*() {
return (*this)[0];
}
};
#endif
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
@ -277,25 +239,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
using tuple = typename ArgIndexInputIterator::value_type;
auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
if (x.key == 0) {
return *first_elem_ptr;
} else {
return x.value;
}
};
auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)(
ArgIndexInputIterator(input + i), input_iter_transform);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i + 1,
output + i,
@ -303,7 +246,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}
@ -555,16 +497,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
input + i, first_elem_ptr};
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i,
output + i,
@ -572,7 +504,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}

View File

@ -10,14 +10,6 @@
#define CUB_VERSION 200001
#endif
// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
// https://github.com/NVIDIA/cub/pull/306
#if CUB_VERSION >= 101300
#define CUB_SUPPORTS_NV_BFLOAT16() true
#else
#define CUB_SUPPORTS_NV_BFLOAT16() false
#endif
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
// https://github.com/NVIDIA/cub/pull/326
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
@ -28,14 +20,6 @@
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif
// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_FUTURE_VALUE() true
#else
#define CUB_SUPPORTS_FUTURE_VALUE() false
#endif
// There were many bc-breaking changes in major version release of CCCL v3.0.0
// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
#if CUB_VERSION >= 200800

View File

@ -0,0 +1,23 @@
#include <ATen/detail/XLAHooksInterface.h>
namespace at {
namespace detail {
const XLAHooksInterface& getXLAHooks() {
auto create_impl = [] {
// Create XLA hooks using the registry
auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{});
if (hooks) {
return hooks;
}
// If hooks creation fails, fall back to default implementation
return std::make_unique<XLAHooksInterface>();
};
static auto hooks = create_impl();
return *hooks;
}
} // namespace detail
C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs)
} // namespace at

View File

@ -0,0 +1,79 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {
constexpr const char* XLA_HELP =
"This error has occurred because you are trying "
"to use some XLA functionality, but the XLA library has not been "
"loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
~XLAHooksInterface() override = default;
void init() const override {
TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
}
virtual bool hasXLA() const {
return false;
}
virtual std::string showConfig() const {
TORCH_CHECK(
false,
"Cannot query detailed XLA version without torch_xla library. ",
XLA_HELP);
}
const Generator& getDefaultGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(
false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
}
Generator getNewGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
}
virtual DeviceIndex getCurrentDevice() const override {
TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
}
Device getDeviceFromPtr(void* /*data*/) const override {
TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
}
Allocator* getPinnedMemoryAllocator() const override {
TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
}
bool isPinnedPtr(const void* data) const override {
return false;
}
bool hasPrimaryContext(DeviceIndex device_index) const override {
TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
}
};
struct TORCH_API XLAHooksArgs {};
TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
#define REGISTER_XLA_HOOKS(clsname) \
C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
namespace detail {
TORCH_API const XLAHooksInterface& getXLAHooks();
} // namespace detail
} // namespace at
C10_DIAGNOSTIC_POP()

View File

@ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result)
try {
mkldnn_matmul_i8i8i32(self, mat2, result);
dispatched = true;
} catch (const std::exception& e) {
} catch ([[maybe_unused]] const std::exception& e) {
TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
}
}

View File

@ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
"pixel_shuffle expects a positive upscale_factor, but got ",
upscale_factor);
int64_t c = self.size(-3);
TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
"upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
TORCH_CHECK(c % upscale_factor_squared == 0,
"pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "

View File

@ -259,11 +259,20 @@ inline void winograd_f2k3_input_transform_inplace__rvv(
const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
/* GCC 14.2 (RISC-V RVV) ICE workaround:
* Avoid single-statement read-modify-write on MEM_REF like:
* *input_tile_val =
* __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
* This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin)
* with -march=rv64gcv. Use a temporary then write back.
* Do NOT refactor into the single-statement form. Clang is unaffected.
*/
vfloat32m1x4_t tmp_input_tile_val = *input_tile_val;
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0);
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1);
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2);
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3);
*input_tile_val = tmp_input_tile_val;
}
inline void winograd_f2k3_output_transform_inplace__rvv(
@ -277,9 +286,15 @@ inline void winograd_f2k3_output_transform_inplace__rvv(
const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
* Keep the temporary + write-back pattern to avoid ICE.
* Do NOT rewrite into:
* *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
*/
vfloat32m1x4_t tmp_output_tile_val = *input_tile_val;
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0);
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1);
*input_tile_val = tmp_output_tile_val;
}
inline vfloat32m1_t
@ -300,11 +315,17 @@ inline void winograd_f2k3_kernel_transform__rvv(
const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
vfloat32m1_t half_g0_plus_g2 = __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
* Keep the temporary + write-back pattern to avoid ICE.
* Do NOT rewrite into:
* *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val);
*/
vfloat32m1x4_t tmp_transform = *transform;
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0);
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2);
*transform = tmp_transform;
}
inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {

View File

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,

File diff suppressed because it is too large Load Diff

View File

@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
launch_vectorized_templated_kernel<
func_t,
array_t,
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
cret_t,
carg0_t,
carg1_t>(
numel,
f,
data,
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
output_offset_calculator,
loader,
storer);
}
}
};

View File

@ -0,0 +1,574 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAScaledBlas.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace scaled_blas = at::cuda::scaled;
using scaled_blas::ScaledGemmImplementation;
using scaled_blas::convert_int_to_enum;
using scaled_blas::_scaled_mm_allowed_device;
namespace at::native {
namespace {
// 2d-2d and 2d-3d
// scaling=MXFP8
// CUDA-only
Tensor&
_mx8_mx8_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const SwizzleType& swizzle_a,
const Tensor& scale_b,
const SwizzleType& swizzle_b,
const std::optional<at::Tensor>& offs,
Tensor& out) {
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
bool b_is_3d = mat_b.dim() == 3;
bool is_2d_2d = a_is_2d && b_is_2d;
bool is_2d_3d = a_is_2d && b_is_3d;
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
#ifdef USE_ROCM
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
#else
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
#endif
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
fbgemm_gpu::mx8mx8bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs.value(),
out);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_FBGEMM_GENAI");
#endif
return out;
}
// 2d-2d and 2d-3d cases
// scaling=rowwise
// CUDA-only
Tensor&
_f8_f8_bf16_rowwise_grouped_mm_cuda(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
at::cuda::detail::f8f8bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
return out;
}
// 2d-2d and 2d-3d cases
// scaling=rowwise
// only being called for rocm
Tensor&
_f8_f8_bf16_rowwise_grouped_mm_rocm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& offs,
Tensor& out) {
TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
#if defined(USE_FBGEMM_GENAI) && defined(USE_ROCM)
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
mat_a,
// FBGEMM expects B matrix shape to be (.., N, K)
mat_b.transpose(-2, -1),
scale_a,
scale_b,
offs,
out);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
#endif
return out;
}
// Dispatch f8 x f8 -> bf16 row-wise scaled to rocm/cuda
Tensor&
_f8_f8_bf16_rowwise_grouped_mm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
bool use_fast_accum,
Tensor& out) {
// FP8 per-tensor and per-row scaling expect fp32 scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"For grouped FP8 rowwise, both scales must be float32 tensors");
#ifndef USE_ROCM
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
#else
// NOTE: ignore use_fast_accum
TORCH_CHECK_VALUE(!bias.has_value(), "ROCM grouped gemm does not support bias")
return _f8_f8_bf16_rowwise_grouped_mm_rocm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
out);
#endif
}
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
TORCH_CHECK(
scale.dim() == 1,
"scale must be a 1D tensor, but got ",
scale.dim(),
"D, arg ",
arg_idx);
TORCH_CHECK(
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(dim) * scale_multiplier,
"scale must have the same length as mat for arg ",
arg_idx);
} else {
TORCH_CHECK(
scale.dim() == 2,
"scale must be a 2D tensor, but got ",
scale.dim(),
"D for arg ",
arg_idx);
TORCH_CHECK(
scale.stride(1) == 1,
"scale must be contiguous in the last dimension for arg ",
arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(0),
"scale must have the same batch dimension as mat for arg ",
arg_idx);
TORCH_CHECK(
scale.size(1) == mat.size(1 + dim),
"scale must have the same first dimension as mat for arg ",
arg_idx);
}
}
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
// that are converted to blocked padded format individually,
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
TORCH_CHECK(
scale.dim() == mat.dim(),
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
// * weight is transposed prior to the call, scale stays non-transposed.
bool LHS = arg_idx == 0;
int scale_dim_to_check = 0;
int mat_dim_to_check = LHS ? 0 : 1;
TORCH_CHECK(
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
} else {
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
// so we can check the exact expected scale sizes here without a d2h sync.
auto round_up = [](auto x, auto y) {
return ((x + y - 1) / y) * y;
};
// TODO: this is for 3d tensor in 2d-3d case specifically.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
int64_t G = mat.size(0);
int64_t K = mat.size(1);
int64_t N = mat.size(2);
int64_t blocked_scale_K = round_up(K/32, 4);
int64_t blocked_scale_N = round_up(N, 128);
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
TORCH_CHECK(
scale.dim() == mat.dim() - 1,
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
);
TORCH_CHECK(
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
);
}
}
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
if (using_fp8_rowwise) {
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
} else if (using_mxfp8) {
_check_scales_mxfp8(mat, scale, dim, arg_idx);
} else {
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
}
}
} // namespace
Tensor
_scaled_grouped_mm_cuda(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
TORCH_CHECK_VALUE(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes(),
").");
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
"Expected mat_b shape to be divisible by 16 ",
"but got mat_b shape: (",
mat_b.sizes(),
").");
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK_VALUE(!scale_result.has_value(), "Scale result not supported yet");
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
// routines
if (offs.has_value()) {
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
}
// FP8 per-tensor and per-row scaling expect fp32 scales.
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(
(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
"For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
// MXFP8 grouped GEMM dispatching
bool is_mx8mx8bf16 = (
mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
);
#else
bool is_mx8mx8bf16 = false;
#endif
if (is_mx8mx8bf16) {
// Note: Passing implied SwizzleType here, correctness of scale previously checked
// in `check_scale` call
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a,
SwizzleType::SWIZZLE_32_4_4,
scale_b,
SwizzleType::SWIZZLE_32_4_4,
offs.value(),
out);
}
// If we're not MXFP8, then we're row-wise scaling.
return _f8_f8_bf16_rowwise_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
}
namespace {
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
} // anonymous namespace
Tensor
_scaled_grouped_mm_cuda_v2(
const Tensor& mat_a, const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum) {
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
if (contraction_dim.size() > 0) {
const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]);
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
mat_b.size(dim_b));
// Note: only (-1, -2) is currently supported
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
} else {
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
}
TORCH_CHECK_VALUE(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes(),
").");
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
"Expected mat_b shape to be divisible by 16 ",
"but got mat_b shape: (",
mat_b.sizes(),
").");
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
// routines
if (offs.has_value()) {
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
}
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
// Conversion of implicitly-defined enums to explicit
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
// at this point we can start working out what we want to be doing
// Try to do as few steps as possible.
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
bool ok = accept_fn(mat_a.scalar_type(),
scale_recipe_a_enum,
scale_a,
mat_b.scalar_type(),
scale_recipe_b_enum,
scale_b);
if (ok) {
gemm_impl = scaled_gemm_impl;
break;
}
}
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
"No gemm implementation was found");
switch (gemm_impl) {
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
return _f8_f8_bf16_rowwise_grouped_mm(
mat_a,
mat_b,
scale_a[0],
scale_b[0],
offs,
bias,
use_fast_accum,
out);
}
case ScaledGemmImplementation::MXFP8_MXFP8: {
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0],
swizzle_a_enum[0],
scale_b[0],
swizzle_b_enum[0],
offs.value(),
out);
}
default:
TORCH_CHECK_NOT_IMPLEMENTED(false,
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
}
}
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
std::optional<c10::ScalarType> out_dtype) {
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
bool a_b_and_out_are_bf16 = (
mat_a.dtype() == at::kBFloat16 &&
mat_b.dtype() == at::kBFloat16 &&
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
);
#ifndef USE_ROCM
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
#else
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}
return out;
}
} // namespace at::native

View File

@ -1,18 +1,17 @@
#pragma once
#include <ATen/OpMathType.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/thread_constants.h>
#include <thrust/tuple.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <tuple>
namespace at::native {
template<int N>
@ -62,7 +61,11 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
#pragma unroll
for (int i = 0; i < elems_per_thread; i++) {
if (policy.check_inbounds(i)) {
#if defined(__HIP__)
results[i] = c10::guts::apply(f, args[i]);
#else
results[i] = std::apply(f, args[i]);
#endif
}
}

View File

@ -23,7 +23,7 @@ namespace at::native {
// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
constexpr int MAX_BLOCK_SIZE = 1024;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
// first the reductions each thread does separately
scalar_t sum = static_cast<scalar_t>(0);
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
#if defined(USE_ROCM)
constexpr int UNRL = 4; // load deserilize factor
scalar_t tmp[UNRL];
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
#pragma unroll
for (int u = 0; u < UNRL; u++)
tmp[u] = op(batch, plane, std::min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
#pragma unroll
for (int u = 0; u < UNRL; u++)
if (x+u*blockDim.x < tensor.size(2))
sum += tmp[u];
}
#else
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
sum += op(batch, plane, x);
}
#endif
}
__shared__ scalar_t shared[C10_WARP_SIZE];
SumReduceOp<scalar_t> reduce_op;
@ -292,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel(
stat_accscalar_t var_n = 0;
int n = 0;
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
#if defined(USE_ROCM)
constexpr int UNRL = 4;
stat_accscalar_t v_[UNRL];
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
for (int u = 0; u < UNRL; u++)
v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
for (int u = 0; u < UNRL; u++) {
if (x+u*blockDim.x < input.size(2)) {
stat_accscalar_t d1 = v_[u] - avg;
n++;
avg += d1 / n;
var_n += d1 * (v_[u] - avg);
}
}
}
#else
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
stat_accscalar_t v = input[batch][plane][x];
stat_accscalar_t d1 = v - avg;
@ -299,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel(
avg += d1 / n;
var_n += d1 * (v - avg);
}
#endif
}
// first warpSum to get one value per thread to

View File

@ -92,6 +92,16 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
output_offset + output_y * output_dim_x + output_x);
}
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
const int64_t two = (len - 1) * 2;
if (two <= 0) {
return 0;
}
int64_t m = x % two;
if (m < 0) m += two;
return (m < len) ? m : (two - m);
}
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
const scalar_t * input, scalar_t * output,
@ -106,6 +116,28 @@ __global__ void reflection_pad1d_out_kernel(
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_flat(
const scalar_t* __restrict__ input,
scalar_t* __restrict__ output,
int64_t input_w, int64_t pad_l, int64_t pad_r,
int64_t out_w, int64_t plane_count) {
const int64_t bx = blockDim.x;
const int64_t tx = threadIdx.x;
const int64_t total = plane_count * out_w;
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
for (; linear < total; linear += grid_stride) {
const int64_t plane = linear / out_w;
const int64_t x = linear - plane * out_w;
const int64_t j = reflect_index(x - pad_l, input_w);
output[plane * out_w + x] = input[plane * input_w + j];
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
scalar_t * grad_input, const scalar_t * grad_output,
@ -710,25 +742,44 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
int64_t input_w = input_.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;
dim3 block_size(output_w > 256 ? 256 : output_w);
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
Tensor input = input_.contiguous();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
reflection_pad1d_out_kernel<<<
grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w,
pad_l,
pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
const int max_x = prop->maxGridSize[0];
const int max_y = prop->maxGridSize[1];
const int max_z = prop->maxGridSize[2];
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
auto stream = at::cuda::getCurrentCUDAStream();
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
if (fits3d) {
dim3 block(block_x, 1, 1);
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r);
} else {
dim3 block(block_x, 1, 1);
const int64_t plane_count = nplane * nbatch;
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
dim3 grid(grid_x, 1, 1);
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r, output_w, plane_count);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,

View File

@ -43,6 +43,12 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
TORCH_CHECK(k >= 1 && k <= slicesize,
"kthvalue(): selected number k out of range for dimension ", dim);
TORCH_CHECK(
slicesize <= std::numeric_limits<int32_t>::max(),
"kthvalue(): dimension ", dim, " is too large (", slicesize,
"). The current CUDA implementation supports dimension sizes up to ",
std::numeric_limits<int32_t>::max());
at::assert_no_overlap(self, values);
_reduction_with_indices_allocate_or_resize_output(
@ -163,10 +169,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
bool keepdim,
Tensor& values,
Tensor& indices) {
// See note [Writing Nondeterministic Operations]
// If there are duplicate elements of the kth value, the procedure for choosing which
// of the duplicates to use for the indices output is nondeterministic.
at::globalContext().alertNotDeterministic("kthvalue CUDA");
auto result = [&]() {
NoNamesGuard guard;
// `kthvalue_out_impl_cuda` expects contiguous in input `self`.

View File

@ -65,25 +65,34 @@ __global__ void gatherKthValue(
&kValue);
// Find the index of the k-th highest element
index_t kValueIndex = 0;
bool foundKValue = false;
__shared__ int32_t minIndexFound;
if (threadIdx.x == 0) {
minIndexFound = static_cast<int32_t>(inputSliceSize);
}
__syncthreads();
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
bool inRange = (i < inputSliceSize);
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
: static_cast<scalar_t>(0);
bool isKValue = inRange &&
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
if (isKValue) {
kValueIndex = i;
foundKValue = true;
break;
}
// Early exit based on best-so-far
if (i >= minIndexFound) {
break;
}
scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
bool isKValue =
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
if (isKValue) {
atomicMin(&minIndexFound, static_cast<int32_t>(i));
break;
}
}
if (foundKValue) {
kthValueSliceStart[0] = kValue;
indicesSliceStart[0] = kValueIndex;
__syncthreads();
if (threadIdx.x == 0) {
indicesSliceStart[0] = static_cast<index_t>(minIndexFound);
kthValueSliceStart[0] = kValue;
}
}

View File

@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
}
}
#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
int input_pos,
accscalar_t scale,
int output_size,
bool align_corners,
int& min_output,
int& max_output) {
accscalar_t lo, hi;
if (align_corners) {
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
} else {
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
}
min_output = max(0, static_cast<int>(std::ceil(lo)));
max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
}
#endif
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
const bool align_corners,
scalar_t* __restrict__ idata,
const scalar_t* __restrict__ odata) {
const size_t o_numel = nc * width2 * height2;
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
const size_t i_numel = nc * width1 * height1;
#ifdef USE_ROCM
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
index += blockDim.x * gridDim.x) {
// Decode input pixel coordinates
size_t index_temp = index;
const int w1 = index_temp % width1;
index_temp /= width1;
const int h1 = index_temp % height1;
const size_t nc_idx = index_temp / height1;
accscalar_t grad_sum = 0;
// Find range of output pixels that could interpolate from this input pixel
int h2_min, h2_max, w2_min, w2_max;
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
// Iterate over potential output pixels
for (int h2 = h2_min; h2 <= h2_max; h2++) {
for (int w2 = w2_min; w2 <= w2_max; w2++) {
// Compute source coordinates for this output pixel
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
rheight, h2, align_corners, /*cubic=*/false);
const int h1_base = (int)h1r;
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
const accscalar_t h1lambda = h1r - h1_base;
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
rwidth, w2, align_corners, /*cubic=*/false);
const int w1_base = (int)w1r;
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
const accscalar_t w1lambda = w1r - w1_base;
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
// Check if our input pixel participates in this interpolation and accumulate all weights
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
// to the same pixel, so we need to accumulate weights from all matching positions
accscalar_t weight = 0;
// Check all four interpolation positions and accumulate weights
if (h1 == h1_base && w1 == w1_base) {
weight += h0lambda * w0lambda; // top-left
}
if (h1 == h1_base && w1 == w1_base + w1p) {
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base) {
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
}
if (weight > 0) {
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
}
}
}
// Write accumulated gradient (no atomics needed)
idata[index] = static_cast<scalar_t>(grad_sum);
}
#else
const size_t o_numel = nc * width2 * height2;
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
index += blockDim.x * gridDim.x) {
size_t index_temp = index;
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
true);
}
#endif
}
template <typename scalar_t, typename accscalar_t>
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
// threads are not covering the whole input tensor.
grad_input.zero_();
const size_t num_kernels = nbatch * channels * output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
return;
}
#ifdef USE_ROCM
constexpr bool use_input = true;
#else
constexpr bool use_input = false;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * output_height * output_width;
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
input_height,
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
num_threads,

View File

@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
using opmath_t = at::opmath_type<scalar_t>;
C10_DEVICE __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
FusedOptimizerTensorListMetadata<3>& tl,
const float* lr_ptr,
const double& lr,
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
} // namespace
} // namespace at::native
} // namespace at::native

View File

@ -2,6 +2,7 @@
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -50,9 +51,13 @@ Tensor& addmm_out(
mat1.dtype(),
" != ",
mat2.dtype())
// complex case
TORCH_CHECK(
!mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
if (self.is_complex()) {
at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
return result;
}
std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
result.resize_(result_shape);
@ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
return result;
}
TORCH_CHECK(
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
if (self.is_complex()) {
at::native::mm_complex_out_xpu(self, mat2, result);
return result;
}
onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
return result;
@ -208,9 +216,12 @@ Tensor& baddbmm_out(
input.sizes());
// complex case
TORCH_CHECK(
!batch1.is_complex(),
"Complex datatype matmul is not supported in oneDNN");
if (input.is_complex()) {
at::native::baddbmm_complex_out_xpu(
input, batch1, batch2, beta, alpha, result);
return result;
}
// general case
onednn::Attr attr;
@ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
return result;
}
TORCH_CHECK(
!self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
// complex case
if (self.is_complex()) {
at::native::bmm_complex_out_xpu(self, batch2, result);
return result;
}
onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
return result;
}

View File

@ -222,6 +222,13 @@ struct nextafter_functor {
}
};
struct hypot_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b));
}
};
// Complex binary functors
struct polar_functor {
template <typename U>
@ -362,6 +369,7 @@ struct igammac_functor {
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
REGISTER_FLOAT_BINARY_OP(hypot);
REGISTER_FLOAT_BINARY_OP(copysign);
REGISTER_INT2FLOAT_BINARY_OP(copysign);
REGISTER_FLOAT_BINARY_OP(fmax);

View File

@ -0,0 +1,16 @@
#pragma onces
#include <c10/metal/common.h>
template <unsigned N = c10::metal::max_ndim>
struct OrgqrParams {
int32_t num_batch_dims;
uint32_t m;
uint32_t n;
uint32_t k;
::c10::metal::array<uint32_t, N> A_strides;
::c10::metal::array<uint32_t, N> tau_strides;
::c10::metal::array<uint32_t, N> H_strides;
::c10::metal::array<uint32_t, N> H_sizes;
};

View File

@ -1,3 +1,4 @@
#include <ATen/native/mps/kernels/LinearAlgebra.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_simdgroup>
@ -640,6 +641,164 @@ kernel void applyPivots(
}
}
template <typename T>
static T bool_to_float(bool b) {
return static_cast<T>(b);
}
template <>
half2 bool_to_float(bool b) {
return half2(b ? 1 : 0, 0);
}
template <>
float2 bool_to_float(bool b) {
return float2(b ? 1 : 0, 0);
}
template <typename T>
static T calc_H_irc(
device T* A,
uint32_t A_stride_r,
uint32_t A_stride_c,
constant T* tau,
uint32_t tau_stride,
uint32_t r,
uint32_t c,
uint32_t i) {
T I_val = bool_to_float<T>(r == c);
T tau_val = tau[i * tau_stride];
T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]);
T A_ri = A[r * A_stride_r + i * A_stride_c];
T c_eq_i = bool_to_float<T>(c == i);
T r_eq_i = bool_to_float<T>(r == i);
T A_ci_ = (c > i) ? A_ci : c_eq_i;
T A_ri_ = (r > i) ? A_ri : r_eq_i;
return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_));
}
// Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the
// result of matrix multiplying A and B together. A and B must be size m-by-m
// and have the same strides. The formula for this operation, written in Python
// syntax, is:
// (A @ B)[r, c] = A[r, :].dot(B[:, c])
template <typename T>
static T calc_matmul_rc(
device T* A,
device T* B,
uint32_t stride_r,
uint32_t stride_c,
uint32_t m,
uint32_t r,
uint32_t c) {
T AB_rc = 0;
auto A_row_offset = r * stride_r;
auto B_col_offset = c * stride_c;
uint32_t A_col_offset = 0;
uint32_t B_row_offset = 0;
for (uint32_t j = 0; j < m;
j++, A_col_offset += stride_c, B_row_offset += stride_r) {
AB_rc += c10::metal::mul(
A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]);
}
return AB_rc;
}
template <typename T>
kernel void orgqr(
device T* A [[buffer(0)]],
constant T* tau [[buffer(1)]],
device T* H [[buffer(2)]],
device T* H_prod [[buffer(3)]],
constant OrgqrParams<>& params [[buffer(4)]],
uint tid [[thread_position_in_grid]]) {
constant auto& A_strides = params.A_strides;
constant auto& tau_strides = params.tau_strides;
constant auto& H_strides = params.H_strides;
constant auto& H_sizes = params.H_sizes;
auto num_batch_dims = params.num_batch_dims;
auto m = params.m;
auto n = params.n;
auto k = params.k;
auto m2 = m * m;
auto batch_idx = tid / m2;
// Find the matrices for this thread's batch index
uint32_t A_offset = 0;
uint32_t tau_offset = 0;
uint32_t H_offset = 0;
for (auto dim = num_batch_dims - 1; dim >= 0; dim--) {
auto dim_size = H_sizes[dim];
auto dim_idx = batch_idx % dim_size;
A_offset += dim_idx * A_strides[dim];
tau_offset += dim_idx * tau_strides[dim];
H_offset += dim_idx * H_strides[dim];
batch_idx /= dim_size;
}
A += A_offset;
tau += tau_offset;
H += H_offset;
H_prod += H_offset;
auto matrix_idx = tid % m2;
auto r = matrix_idx / m;
auto c = matrix_idx % m;
auto A_stride_r = A_strides[num_batch_dims];
auto A_stride_c = A_strides[num_batch_dims + 1];
auto tau_stride = tau_strides[num_batch_dims];
auto H_stride_r = H_strides[num_batch_dims];
auto H_stride_c = H_strides[num_batch_dims + 1];
// Find the element of H and H_prod that this thread will calculate
device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c);
device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c);
for (uint32_t i = 0; i < k; i++) {
// Calculate and write H_i
T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i);
// Calculate element [r, c] of prod(H_0, ..., H_i)
if (i == 0) {
*H_prod_elem_ptr = H_irc;
} else {
*H_elem_ptr = H_irc;
// Need this sync because the below matmul requires all threads to finish
// writing their entries to `H_prod` and `H`.
threadgroup_barrier(mem_flags::mem_threadgroup);
T H_prod_0_to_i_rc =
calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);
// Need this sync because the above matmul uses the current values in
// `H_prod`, and we don't want to overwrite those until all threads are
// finished using them.
threadgroup_barrier(mem_flags::mem_threadgroup);
*H_prod_elem_ptr = H_prod_0_to_i_rc;
}
}
device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c);
if (c < n) {
*A_elem_ptr = *H_prod_elem_ptr;
}
}
#define INSTANTIATE_MM_OPS(DTYPE) \
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
constant DTYPE * mat1Data [[buffer(0)]], \
@ -679,3 +838,19 @@ INSTANTIATE_MM_OPS(int);
INSTANTIATE_MM_OPS(short);
INSTANTIATE_MM_OPS(char);
INSTANTIATE_MM_OPS(uchar);
#define REGISTER_ORGQR(T) \
template [[host_name("orgqr_" #T)]] \
kernel void orgqr<T>( \
device T * A [[buffer(0)]], \
constant T * tau [[buffer(1)]], \
device T * H [[buffer(2)]], \
device T * H_prod [[buffer(3)]], \
constant OrgqrParams<> & params [[buffer(4)]], \
uint tid [[thread_position_in_grid]]);
REGISTER_ORGQR(float);
REGISTER_ORGQR(half);
REGISTER_ORGQR(bfloat);
REGISTER_ORGQR(float2);
REGISTER_ORGQR(half2);

View File

@ -5,6 +5,21 @@
using namespace metal;
using namespace c10::metal;
struct angle_functor {
template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
inline T operator()(const T x) {
return T(atan2(x.y, x.x), 0);
}
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
inline T operator()(const T x) {
return T(isnan(x) ? x : x < 0 ? M_PI_F : 0.0);
}
template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
inline float operator()(const T x) {
return x < 0 ? M_PI_F : 0.0;
}
};
// Implement exp wrapper for both real and complex types
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
inline T exp_(const T x) {
@ -545,6 +560,7 @@ REGISTER_UNARY_OP(abs, float, float);
REGISTER_UNARY_OP(abs, half, half);
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
REGISTER_UNARY_OP(angle, DTYPE1, DTYPE0); \
REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0); \
REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0); \
REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0); \
@ -583,6 +599,7 @@ INSTANTIATE_UNARY_KERNELS2(float, int);
INSTANTIATE_UNARY_KERNELS2(float, long);
#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE) \
REGISTER_UNARY_OP(angle, DTYPE##2, DTYPE##2); \
REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2); \
REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2); \
REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2); \

View File

@ -92,13 +92,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
}
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
if ([maskedMM dataType] != MPSDataTypeFloat32) {
maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
}
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
if ([maskedMM dataType] != qTensor.dataType) {
maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
}
if (is_causal) {
auto causalMask = [mpsGraph constantWithScalar:1.0f
@ -112,7 +107,9 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
name:nil];
}
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
@ -133,8 +130,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = output;
graph->attnTensor = sm;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);

View File

@ -202,6 +202,10 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "igammac");
}
static void hypot_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "hypot");
}
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
@ -229,4 +233,5 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel)
REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel)
REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel)
} // namespace at::native

View File

@ -16,7 +16,6 @@
#include <ATen/ops/eq_native.h>
#include <ATen/ops/ge_native.h>
#include <ATen/ops/gt_native.h>
#include <ATen/ops/hypot_native.h>
#include <ATen/ops/le_native.h>
#include <ATen/ops/logaddexp2_native.h>
#include <ATen/ops/logaddexp_native.h>
@ -278,22 +277,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
}
}
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
secondaryTensor:twoTensor
name:nil]
secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor
secondaryTensor:twoTensor
name:nil]
name:nil];
return [mpsGraph squareRootWithTensor:sumTensor name:nil];
};
mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block);
}
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();

View File

@ -8,6 +8,9 @@
#include <ATen/native/Resize.h>
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/LinearAlgebra.h>
#include <fmt/format.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -28,6 +31,7 @@
#include <ATen/ops/linalg_solve_triangular_native.h>
#include <ATen/ops/lu_unpack_native.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/orgqr_native.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/stack.h>
#include <ATen/ops/triangular_solve_native.h>
@ -338,6 +342,8 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
}
}
map_mps_decomposition_error_code_to_blas(info);
}
static void linalg_solve_out_mps_impl(const Tensor& A,
@ -1233,6 +1239,69 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper
}
}
static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) {
if (self.numel() == 0) {
return self;
}
auto m = self.size(-2);
auto n = self.size(-1);
auto k = tau.size(-1);
if (tau.numel() == 0) {
auto I = eye(m, self.scalar_type(), std::nullopt, self.device());
return self.copy_(I.slice(-1, 0, n));
}
auto num_batch_dims = self.dim() - 2;
auto batch_sizes = self.sizes().slice(0, num_batch_dims);
std::vector<int64_t> H_sizes(num_batch_dims + 2);
for (auto dim : c10::irange(num_batch_dims)) {
H_sizes[dim] = self.size(dim);
}
H_sizes[num_batch_dims] = m;
H_sizes[num_batch_dims + 1] = m;
auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous));
auto H_prod = at::empty_like(H);
OrgqrParams params;
params.num_batch_dims = num_batch_dims;
params.m = m;
params.n = n;
params.k = k;
for (const auto dim : c10::irange(self.dim())) {
params.A_strides[dim] = self.stride(dim);
if (dim < tau.dim()) {
params.tau_strides[dim] = tau.stride(dim);
}
params.H_strides[dim] = H.stride(dim);
params.H_sizes[dim] = H.size(dim);
}
auto num_threads = H.numel();
MPSStream* stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self)));
getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau});
[compute_encoder setComputePipelineState:pipeline_state];
mtl_setArgs(compute_encoder, self, tau, H, H_prod, params);
mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads);
getMPSProfiler().endProfileKernel(pipeline_state);
}
});
return self;
}
} // namespace mps
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
@ -1448,20 +1517,6 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps)
mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info);
}
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
Tensor info = at::empty({}, A.options().dtype(kInt));
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
return std::tie(LU, pivots);
}
std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
Tensor LU = at::empty({0}, A.options());
Tensor pivots = at::empty({0}, A.options().dtype(kInt));
Tensor info = at::empty({}, A.options().dtype(kInt));
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
return std::make_tuple(std::move(LU), std::move(pivots));
}
TORCH_IMPL_FUNC(lu_unpack_out_mps)
(const Tensor& LU_data,
const Tensor& LU_pivots,
@ -1483,4 +1538,6 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
}
REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl);
} // namespace at::native

View File

@ -34,6 +34,7 @@ REGISTER_UNARY_TI_DISPATCH(sinc);
REGISTER_UNARY_TI_DISPATCH(sinh);
REGISTER_UNARY_TI_DISPATCH(cosh);
REGISTER_UNARY_TI_DISPATCH(tanh);
REGISTER_UNARY_TI_DISPATCH(angle);
REGISTER_UNARY_TI_DISPATCH(abs);
REGISTER_UNARY_TI_DISPATCH(sin);
REGISTER_UNARY_TI_DISPATCH(cos);

View File

@ -12,7 +12,6 @@
#include <ATen/ops/_copy_from_and_resize.h>
#include <ATen/ops/acos_native.h>
#include <ATen/ops/acosh_native.h>
#include <ATen/ops/angle_native.h>
#include <ATen/ops/asin_native.h>
#include <ATen/ops/asinh_native.h>
#include <ATen/ops/atan_native.h>
@ -204,23 +203,6 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
return output;
}
Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
});
return output;
}
Tensor angle_mps(const Tensor& self) {
const auto float_type = c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)
? c10::typeMetaToScalarType(c10::get_default_dtype())
: c10::toRealValueType(self.scalar_type());
Tensor result = at::empty({0}, self.options().dtype(float_type));
return angle_out_mps(self, result);
}
TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) {
TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types");
mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {

View File

@ -403,16 +403,14 @@
device_check: NoCheck # TensorIterator
variants: function, method
dispatch:
CPU, CUDA: angle
MPS: angle_mps
CPU, CUDA, MPS: angle
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr
tags: pointwise
- func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: angle_out
MPS: angle_out_mps
CPU, CUDA, MPS: angle_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr_out
tags: pointwise
@ -10042,8 +10040,7 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: hypot_out
MPS: hypot_out_mps
CPU, CUDA, MPS: hypot_out
tags: pointwise
- func: hypot(Tensor self, Tensor other) -> Tensor
@ -14157,16 +14154,10 @@
- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor
MPS: linalg_lu_factor_mps
- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor_out
MPS: linalg_lu_factor_out_mps
- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
python_module: linalg
@ -14368,12 +14359,12 @@
python_module: linalg
variants: function
dispatch:
CPU, CUDA: linalg_householder_product
CPU, CUDA, MPS: linalg_householder_product
- func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
dispatch:
CPU, CUDA: linalg_householder_product_out
CPU, CUDA, MPS: linalg_householder_product_out
- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
python_module: linalg

View File

@ -40,15 +40,7 @@
#include <thrust/iterator/discard_iterator.h>
#if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
#define IS_CUSPARSE11_AVAILABLE() 1
#else
#define IS_CUSPARSE11_AVAILABLE() 0
#endif
#if IS_CUSPARSE11_AVAILABLE()
#include <library_types.h>
#endif
namespace at::native {
@ -103,17 +95,9 @@ struct csrMatrixRef {
int nnz_{0};
std::vector<int> size_{};
#if IS_CUSPARSE11_AVAILABLE()
cusparseSpMatDescr_t description_{0};
#else
cusparseMatDescr_t description_{0};
#endif
cusparseSpMatDescr_t description_{0};
csrMatrixRef() {
#if !IS_CUSPARSE11_AVAILABLE()
create_general_description_(description_);
#endif
}
csrMatrixRef() = default;
csrMatrixRef(
int* csr_indices,
@ -126,7 +110,6 @@ struct csrMatrixRef {
csr_values_{csr_values},
nnz_{nnz},
size_{size} {
#if IS_CUSPARSE11_AVAILABLE()
cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
&description_,
@ -140,17 +123,10 @@ struct csrMatrixRef {
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO,
cuda_data_type));
#else
create_general_description_(description_);
#endif
}
~csrMatrixRef() {
#if IS_CUSPARSE11_AVAILABLE()
cusparseDestroySpMat(description_);
#else
cusparseDestroyMatDescr(description_);
#endif
cusparseDestroySpMat(description_);
}
int size(int index) const {
@ -196,8 +172,6 @@ struct csrOutput {
}
};
#if IS_CUSPARSE11_AVAILABLE()
// RAII guard helps to support cuSparse 11 API for `A @ B` operation
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
template <class scalar_t>
@ -396,284 +370,6 @@ template struct CusparseMatrixMultiplyOp<float>;
template struct CusparseMatrixMultiplyOp<double>;
#else // if not IS_CUSPARSE11_AVAILABLE()
using DcsrMatrixRef = csrMatrixRef<double>;
using ScsrMatrixRef = csrMatrixRef<float>;
// RAII guard helps to support cuSparse 10 API for `A @ B` operation
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
template <class scalar_t>
struct CusparseMatrixMultiplyOp {
csrOutput operator()(
const csrMatrixRef<scalar_t>& lhs,
const csrMatrixRef<scalar_t>& rhs,
Tensor &output_values,
Tensor &output_indices)
{
static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double.");
}
};
// Specializacion for `A @ B` operation for double values with cuSparse
template<> struct CusparseMatrixMultiplyOp<double> {
csrgemm2Info_t gemm2Info_;
CusparseMatrixMultiplyOp() {
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
}
~CusparseMatrixMultiplyOp() {
cusparseDestroyCsrgemm2Info(gemm2Info_);
}
csrOutput operator ()(
const DcsrMatrixRef& lhs,
const DcsrMatrixRef& rhs,
Tensor &output_values,
Tensor &output_indices) {
double alpha = 1.0;
DcsrMatrixRef empty;
return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
}
csrOutput Dgemm2(
const DcsrMatrixRef& A,
const DcsrMatrixRef& B,
const DcsrMatrixRef& C,
const double* alpha,
const double* beta,
Tensor &output_values,
Tensor &output_indices) {
void* buffer_{nullptr};
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
csrOutput out({A.size(0), B.size(1)});
int innerSize = confirm_mult_size(A.size_, B.size_);
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
// Compute needed buffer size
size_t new_bubber_sz;
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
gemm2Info_,
&new_bubber_sz));
// (Re)allocate buffer if needed
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
buffer_ = data_ptr.get();
// Find the resulting non-zero pattern.
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_pointers_.data_ptr<int>(),
&out.nnz_,
gemm2Info_,
buffer_));
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
// Perform the gemm2 operation for doubles
// out = alpha A B + beta C
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_values_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_values_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_values_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_values_.data_ptr<double>(),
out.csr_pointers_.data_ptr<int>(),
out.csr_indices_.data_ptr<int>(),
gemm2Info_,
buffer_));
return out;
}
};
// Specializacion for `A @ B` operation for float values with cuSparse
template<> struct CusparseMatrixMultiplyOp<float> {
csrgemm2Info_t gemm2Info_;
CusparseMatrixMultiplyOp() {
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
}
~CusparseMatrixMultiplyOp() {
cusparseDestroyCsrgemm2Info(gemm2Info_);
}
csrOutput operator()(
const ScsrMatrixRef& lhs,
const ScsrMatrixRef& rhs,
Tensor &output_values,
Tensor &output_indices) {
float alpha = 1.0;
ScsrMatrixRef empty;
return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
}
csrOutput Sgemm2(
const ScsrMatrixRef& A,
const ScsrMatrixRef& B,
const ScsrMatrixRef& C,
const float* alpha,
const float* beta,
Tensor &output_values,
Tensor &output_indices) {
void* buffer_{nullptr};
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
csrOutput out({A.size(0), B.size(1)});
int innerSize = confirm_mult_size(A.size_, B.size_);
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
// Compute needed buffer size
size_t new_bubber_sz;
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
gemm2Info_,
&new_bubber_sz));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
buffer_ = data_ptr.get();
// Find the resulting non-zero pattern.
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_pointers_.data_ptr<int>(),
&out.nnz_,
gemm2Info_,
buffer_));
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
// Perform the gemm2 operation for doubles
// out = alpha A B + beta C
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_values_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_values_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_values_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_values_.data_ptr<float>(),
out.csr_pointers_.data_ptr<int>(),
out.csr_indices_.data_ptr<int>(),
gemm2Info_,
buffer_));
return out;
}
};
#endif // IS_CUSPARSE11_AVAILABLE()
template <typename scalar_t>
void sparse_sparse_matmul_cuda_kernel(
Tensor& result,
@ -815,19 +511,15 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
auto output = at::native::empty_like(mat1_);
output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
#if IS_CUSPARSE11_AVAILABLE() && !defined(USE_ROCM)
#if !defined(USE_ROCM)
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
#elif IS_CUSPARSE11_AVAILABLE() && defined(USE_ROCM)
#else
// ROCm does not support half and bfloat16 types for sparse_matmul
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
#else
AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
#endif
return output;
}

View File

@ -33,7 +33,7 @@ using namespace mps;
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Mul_metallib.h>
#include <ATen/native/mps/SparseTensorMath_metallib.h>
#endif
static Tensor& s_addmm_out_sparse_dense_mps(
@ -369,12 +369,7 @@ static SparseTensor& mul_out_dense_sparse_mps(
}
if (scalar_like) {
auto scalar = dense;
if (dense.numel() == 1 && dense.dim() > 0) {
scalar = dense.view({});
}
scalar = scalar.to(values.options());
auto out_vals = values.mul(scalar);
auto out_vals = values.mul(dense.to(values.options()));
if (out.scalar_type() != commonDtype) {
out_vals = out_vals.to(out.scalar_type());
}
@ -508,14 +503,14 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
const auto device = r_.device();
auto stream = getCurrentMPSStream();
auto lhs_indices = lhs._indices();
auto rhs_indices = rhs._indices();
auto lhs_values = lhs._values().to(commonDtype);
auto rhs_values = rhs._values().to(commonDtype);
auto lhs_indices = lhs._indices().contiguous();
auto rhs_indices = rhs._indices().contiguous();
auto lhs_values = lhs._values().to(commonDtype).contiguous();
auto rhs_values = rhs._values().to(commonDtype).contiguous();
// Flatten sparse indices to keys
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes());
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes());
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i));
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i));
// Intersect sorted keys (search the shorter in the longer)
const bool A_is_lhs = (lhs_nnz <= rhs_nnz);
@ -546,35 +541,54 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
auto out_indices = at::empty({ndim_i, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong));
auto lhs_match = outA_idx.narrow(0, 0, M);
auto rhs_match = outB_idx.narrow(0, 0, M);
auto out_val_sizes = lhs_values.sizes().vec();
out_val_sizes[0] = static_cast<int64_t>(M);
auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec();
int64_t cols64 = 1;
for (auto s : dense_sizes_vec) cols64 *= s;
const uint32_t cols = static_cast<uint32_t>(std::max<int64_t>(cols64, 1));
auto to2d = [&](Tensor t, int64_t nnz) -> Tensor {
const int64_t t_cols = t.numel() / nnz;
if (t_cols == cols64) {
return t.view({nnz, cols64});
}
return t.view({nnz, 1}).expand({nnz, cols64}).contiguous();
};
// make both sides 2d [nnz, cols] buffers so the kernel can index it
auto lhs_vals2d = to2d(lhs_values, lhs_nnz);
auto rhs_vals2d = to2d(rhs_values, rhs_nnz);
std::vector<int64_t> out_val_sizes;
out_val_sizes.reserve(1 + dense_sizes_vec.size());
out_val_sizes.push_back(static_cast<int64_t>(M));
out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end());
auto out_values = at::empty(out_val_sizes, lhs_values.options());
const uint32_t cols = static_cast<uint32_t>(
lhs_values.numel() / std::max<int64_t>(1, lhs_nnz));
if (M > 0) {
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc(
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc(
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t gridW = std::max<uint32_t>(cols, 1u);
const uint32_t tgW = std::min(gridW, tew);
MTLSize grid = MTLSizeMake(gridW, 1, M);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
const uint32_t tew = pso.threadExecutionWidth;
uint32_t tgW = std::min(cols, tew);
MTLSize grid = MTLSizeMake(cols, 1, M);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
lhs_values, rhs_values,
lhs_match, rhs_match,
lhs_indices, out_indices,
out_values,
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
std::array<uint32_t, 2>{M, cols});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
mtl_setArgs(enc,
lhs_vals2d, rhs_vals2d,
lhs_match, rhs_match,
lhs_indices, out_indices,
out_values,
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
std::array<uint32_t, 2>{M, cols});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
}
if (r_.scalar_type() != commonDtype) {
out_values = out_values.to(r_.scalar_type());

View File

@ -62,7 +62,6 @@ kernel void build_row_ptr_from_sorted_rows_by_batch(
template <typename T>
kernel void spmm_bmm_coo_rows_grouped(
device const long* rows [[buffer(0)]],
device const long* cols [[buffer(1)]],
device const T* vals [[buffer(2)]],
device const T* dense [[buffer(3)]],
@ -73,7 +72,6 @@ kernel void spmm_bmm_coo_rows_grouped(
uint3 ltid [[thread_position_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]])
{
const uint B = dims.x;
const uint I = dims.y;
const uint J = dims.z;
const uint K = dims.w;
@ -197,9 +195,9 @@ kernel void fused_gather_mul_kernel(
const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
const ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
const float a = (float)lhs_vals[offL];
const float b = (float)rhs_vals[offR];
out_vals[offO] = (T)(a * b);
const auto a = static_cast<accum_t<T>>(lhs_vals[offL]);
const auto b = static_cast<accum_t<T>>(rhs_vals[offR]);
out_vals[offO] = static_cast<T>(mul(a, b));
}
// One thread per match copies the indices column
@ -321,7 +319,6 @@ INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);
#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \
template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void \
spmm_bmm_coo_rows_grouped<DTYPE>( \
device const long* rows [[buffer(0)]], \
device const long* cols [[buffer(1)]], \
device const DTYPE* vals [[buffer(2)]], \
device const DTYPE* dense [[buffer(3)]], \

View File

@ -76,14 +76,21 @@ bool priority_order_init_ = false;
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
// Needs to be kept in-sync with test_fused_chocie in test_transformers.py
bool check_prefer_cudnn_attention() {
static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") != false;
static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_DEPRIORITIZED") != true;
if (!prefer_cudnn) {
return false;
}
#if (defined(CUDNN_VERSION) && (CUDNN_VERSION >= 90900))
auto dprops = at::cuda::getCurrentDeviceProperties();
auto major = dprops->major;
return (major == 9 || major == 10) && !dprops->minor;
try {
auto dprops = at::cuda::getCurrentDeviceProperties();
auto major = dprops->major;
return (major == 9 || major == 10) && !dprops->minor;
} catch (c10::Error const& e) {
#ifdef DEBUG
TORCH_WARN("check_prefer_cudnn_attention() caught exception ", e.what());
#endif
return false;
}
#else
return false;
#endif

View File

@ -202,7 +202,6 @@ supported:
- select_backward
- _trilinear
- linalg_pinv.atol_rtol_tensor
- svd
- logsumexp.out
symint:
- empty.memory_format

View File

@ -37,6 +37,10 @@ TEST(SingletonOrSharedTypePtr, Comparison) {
EXPECT_NE(empty, p);
EXPECT_NE(p, p2);
EXPECT_EQ(empty, empty);
EXPECT_EQ(p, p);
EXPECT_EQ(p2, p2);
}
TEST(SingletonOrSharedTypePtr, SingletonComparison) {
@ -47,6 +51,8 @@ TEST(SingletonOrSharedTypePtr, SingletonComparison) {
c10::TypePtr type = c10::NoneType::get();
EXPECT_NE(type, c10::StringType::get());
EXPECT_NE(type, c10::DeviceObjType::get());
EXPECT_EQ(type, type);
EXPECT_EQ(type, c10::NoneType::get());
}

View File

@ -526,6 +526,41 @@ namespace {
[](const vec& v) { return v.expm1(); },
createDefaultUnaryTestCase<vec>(TestSeed(), false, true));
}
TYPED_TEST(Exponents, ExpU20) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
using UVT = UvalueType<TypeParam>;
// Explicit edge values
VT v_too_small = VT(-100.0); // much less than -87.3
VT exp_too_small = std::exp(v_too_small);
VT v_neg_edge = VT(-0x1.5d5e2ap+6f); // just at the edge
VT exp_neg_edge = std::exp(v_neg_edge);
VT v_zero = VT(0.0); // middle, normal case
VT exp_zero = std::exp(v_zero);
VT v_pos_edge = VT(0x1.5d5e2ap+6f); // just at the edge
VT exp_pos_edge = std::exp(v_pos_edge);
VT v_too_large = VT(100.0); // much more than 87.3
VT exp_too_large = std::exp(v_too_large);
auto test_case = TestingCase<vec>::getBuilder()
// Randoms in normal range, but the .addCustom() below guarantees we hit the special/fallback cases
.addDomain(CheckWithinDomains<UVT>{{{-100, 100}}, false, getDefaultTolerance<UVT>()})
.addCustom({ {v_too_small}, exp_too_small })
.addCustom({ {v_neg_edge}, exp_neg_edge })
.addCustom({ {v_zero}, exp_zero })
.addCustom({ {v_pos_edge}, exp_pos_edge })
.addCustom({ {v_too_large}, exp_too_large })
.setTrialCount(65536)
.setTestSeed(TestSeed());
test_unary<vec>(
NAME_INFO(exp_u20_edge_cases),
RESOLVE_OVERLOAD(std::exp),
[](const vec& v) { return v.exp_u20(); },
test_case
);
}
TYPED_TEST(ErrorFunctions, Erf) {
using vec = TypeParam;
test_unary<vec>(

View File

@ -58,8 +58,7 @@ def list_benchmarks():
def run_benchmark(
benchmark_name: str,
should_visualize: bool = False,
compile_mode: str = "max-autotune-no-cudagraphs",
script_args,
):
"""Run a specific benchmark."""
if benchmark_name not in BENCHMARK_REGISTRY:
@ -68,29 +67,29 @@ def run_benchmark(
return False
print(f"Running benchmark: {benchmark_name}")
print(f"Torch compile mode: {compile_mode}")
print(f"Torch compile mode: {script_args.compile_mode}")
print("=" * 60)
benchmark_class = BENCHMARK_REGISTRY[benchmark_name]
benchmark = benchmark_class(compile_mode)
benchmark = benchmark_class(script_args)
benchmark.benchmark()
if should_visualize:
if script_args.visualize:
benchmark.visualize()
return True
def run_all_benchmarks(should_visualize: bool = False, compile_mode: str = "default"):
def run_all_benchmarks(script_args):
"""Run all available benchmarks."""
print("Running all benchmarks...")
print(f"Torch compile mode: {compile_mode}")
print(f"Torch compile mode: {script_args.compile_mode}")
print("=" * 60)
for name, cls in BENCHMARK_REGISTRY.items():
print(f"\n{'=' * 20} {name.upper()} {'=' * 20}")
benchmark = cls(compile_mode)
benchmark = cls(script_args)
benchmark.benchmark()
if should_visualize:
if script_args.visualize:
benchmark.visualize()
print()
@ -137,6 +136,19 @@ Examples:
help="Torch compile mode to use (default: default)",
)
parser.add_argument(
"--tolerance",
type=float,
default=None,
help="Tolerance for the accuracy check",
)
parser.add_argument(
"--exit-on-accuracy-failure",
action="store_true",
help="Whether to exit with an error message for accuracy failure",
)
args = parser.parse_args()
# Handle list option
@ -146,7 +158,7 @@ Examples:
# Handle all option
if args.all:
run_all_benchmarks(args.visualize, args.compile_mode)
run_all_benchmarks(args)
return
# Handle specific benchmarks
@ -157,7 +169,7 @@ Examples:
sys.exit(1)
for benchmark_name in args.benchmarks:
run_benchmark(benchmark_name, args.visualize, args.compile_mode)
run_benchmark(benchmark_name, args)
print() # Add spacing between benchmarks

View File

@ -9,8 +9,8 @@ import torch.nn.functional as F
class CrossEntropyForward(BenchmarkKernel):
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
super().__init__(compile_mode)
def __init__(self, script_args):
super().__init__(script_args)
self.available_backends = ["eager", "compiled", "quack", "liger"]
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@ -106,8 +106,8 @@ class CrossEntropyForward(BenchmarkKernel):
class CrossEntropyBackward(BenchmarkKernel):
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
super().__init__(compile_mode)
def __init__(self, script_args):
super().__init__(script_args)
self.available_backends = ["eager", "compiled", "quack", "liger"]
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@ -194,8 +194,8 @@ class CrossEntropyBackward(BenchmarkKernel):
class SoftmaxForward(BenchmarkKernel):
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
super().__init__(compile_mode)
def __init__(self, script_args):
super().__init__(script_args)
self.available_backends = ["eager", "compiled", "quack", "liger"]
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@ -259,8 +259,8 @@ class SoftmaxForward(BenchmarkKernel):
class SoftmaxBackward(BenchmarkKernel):
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
super().__init__(compile_mode)
def __init__(self, script_args):
super().__init__(script_args)
self.available_backends = ["eager", "compiled", "quack", "liger"]
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@ -329,8 +329,8 @@ class SoftmaxBackward(BenchmarkKernel):
class RMSNormForward(BenchmarkKernel):
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
super().__init__(compile_mode)
def __init__(self, script_args):
super().__init__(script_args)
self.available_backends = ["eager", "compiled", "quack", "liger"]
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@ -383,7 +383,22 @@ class RMSNormForward(BenchmarkKernel):
from quack.rmsnorm import _rmsnorm_fwd
x, w = args
return lambda: _rmsnorm_fwd(x, w, eps=1e-6)
y = torch.empty_like(x)
def quack_fwd():
_rmsnorm_fwd(
x,
w,
out=y,
bias=None,
rstd=None,
residual=None,
residual_out=None,
eps=1e-6,
)
return y
return quack_fwd
def liger(self, args, kwargs) -> Any:
from liger_kernel.transformers.rms_norm import LigerRMSNorm
@ -404,9 +419,14 @@ class RMSNormForward(BenchmarkKernel):
class RMSNormBackward(BenchmarkKernel):
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
super().__init__(compile_mode)
self.available_backends = ["eager", "compiled", "quack", "liger"]
def __init__(self, script_args):
super().__init__(script_args)
self.available_backends = [
"eager",
"compiled",
"quack",
"liger",
]
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
# TODO: OOM for (32768, 65536) on h100
@ -454,8 +474,11 @@ class RMSNormBackward(BenchmarkKernel):
y, [x, w], grad_outputs=dy, retain_graph=True
)
def compute_rstd(self, x, eps):
return torch.rsqrt(torch.mean(x.float().square(), dim=-1, keepdim=True) + eps)
def quack(self, args, kwargs=None) -> Any:
from quack.rmsnorm import _rmsnorm_backward
from quack.rmsnorm import _get_sm_count, _rmsnorm_bwd
(
x,
@ -463,15 +486,40 @@ class RMSNormBackward(BenchmarkKernel):
dy,
) = args
M, N = x.shape
rstd = torch.randn(M, device="cuda", dtype=torch.float32)
return lambda: _rmsnorm_backward(x, w, dy, rstd)
rstd = self.compute_rstd(x, eps=1e-6)
dx = torch.empty_like(x)
sm_count = _get_sm_count(x.size(1), x.device)
dw_partial = torch.empty(
sm_count, x.size(1), device=x.device, dtype=torch.float32
)
def quack_bwd():
_rmsnorm_bwd(
x,
w,
dy,
rstd,
dx,
dw_partial,
db_partial=None,
dresidual_out=None,
dresidual=None,
sm_count=sm_count,
)
dw = dw_partial.sum(dim=0).to(w.dtype)
return dx, dw
return quack_bwd
def liger(self, args, kwargs=None) -> Any:
from liger_kernel.transformers.rms_norm import LigerRMSNorm
x, w, dy = args
M, N = x.shape
liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda()
liger_rmsnorm = LigerRMSNorm(
hidden_size=N, eps=1e-6, casting_mode="gemma"
).cuda()
liger_rmsnorm.weight.data.copy_(w)
y = liger_rmsnorm(x)
return lambda: torch.autograd.grad(
@ -489,8 +537,8 @@ class RMSNormBackward(BenchmarkKernel):
class LayerNormForward(BenchmarkKernel):
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
super().__init__(compile_mode)
def __init__(self, script_args):
super().__init__(script_args)
self.available_backends = ["eager", "compiled", "quack", "liger"]
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@ -563,8 +611,8 @@ class LayerNormForward(BenchmarkKernel):
class LayerNormBackward(BenchmarkKernel):
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
super().__init__(compile_mode)
def __init__(self, script_args):
super().__init__(script_args)
self.available_backends = ["eager", "compiled", "liger"]
def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@ -614,20 +662,31 @@ class LayerNormBackward(BenchmarkKernel):
y, [x, w], grad_outputs=dy, retain_graph=True
)
def compute_mean_rstd(self, x, eps):
x = x.float()
var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
rstd = torch.rsqrt(var + eps)
return mean, rstd
def liger(self, args, kwargs) -> Any:
from liger_kernel.transformers.layer_norm import LigerLayerNorm
"""
Call layer_norm_backward directly rather than calling
liger_kernel.transformers.layer_norm.LigerLayerNorm and
torch.autograd.grad.
The latter fashion saves mean/rstd in x.dtype which can fail
accuracy test. We call layer_norm_backward with fp32 mean and
rstd.
"""
from liger_kernel.ops.layer_norm import layer_norm_backward
x, w, dy = args
eps = 1e-6
mean, rstd = self.compute_mean_rstd(x, eps)
M, N = x.shape
liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda()
liger_layernorm.weight.data.copy_(w)
liger_layernorm.bias.data.copy_(
torch.zeros(N, device="cuda", dtype=torch.float32)
)
y = liger_layernorm(x)
return lambda: torch.autograd.grad(
y, [x, liger_layernorm.weight], grad_outputs=dy, retain_graph=True
)
return lambda: layer_norm_backward(dy, x, w, None, mean, rstd)[0:2]
def benchmark(self):
for M, N in self.get_shapes():

View File

@ -1,4 +1,5 @@
import os
import sys
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
@ -43,10 +44,11 @@ class Performance:
class BenchmarkKernel:
def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
def __init__(self, script_args):
self.script_args = script_args
self.name = self.__class__.__name__
self.available_backends: list[str] = []
self.compile_mode: str = compile_mode
self.compile_mode: str = script_args.compile_mode
# mapping from backend to list of performance results
self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list)
@ -106,14 +108,21 @@ class BenchmarkKernel:
args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
gold = res["eager"]
tol = {}
if self.script_args.tolerance:
tol = {
"atol": self.script_args.tolerance,
"rtol": self.script_args.tolerance,
}
for backend in self.available_backends:
if backend == "eager":
continue
try:
torch.testing.assert_close(res[backend], gold)
torch.testing.assert_close(res[backend], gold, **tol)
for t, gold_t in zip(res[backend], gold):
if t.requires_grad:
torch.testing.assert_close(t.grad, gold_t.grad)
torch.testing.assert_close(t.grad, gold_t.grad, **tol)
print(
f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel"
)
@ -121,6 +130,9 @@ class BenchmarkKernel:
print(
f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
)
if self.script_args.exit_on_accuracy_failure:
print("Exit right away since --exit-on-accuracy-failure is set")
sys.exit(1)
def benchmark_single_shape(
self, args, kwargs=None, should_check_accuracy=True, setting: str = ""

View File

@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3070000000,0.1
add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1719000000,0.1
update_hint_regression,compile_time_instruction_count,1645000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1

1 add_loop_eager compile_time_instruction_count 3070000000 3184000000 0.1
2 add_loop_eager_dynamic compile_time_instruction_count 4432000000 4595000000 0.1
3 add_loop_inductor compile_time_instruction_count 29660000000 29660000000 0.1
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 39910000000 39910000000 0.1
5 add_loop_inductor_gpu compile_time_instruction_count 26800000000 26800000000 0.1
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 1048000000 1096000000 0.1
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 15240000000 15240000000 0.1
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 17020000000 17720000000 0.1
18 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3442000000 3152000000 0.1
19 aotdispatcher_training_subclass_cpu compile_time_instruction_count 9239000000 8301000000 0.1
20 mm_loop_inductor_gpu compile_time_instruction_count 4820968837 4958000000 0.1
21 mm_loop_inductor_dynamic_gpu compile_time_instruction_count 9051000000 9051000000 0.1
22 basic_NestedModule_eager compile_time_instruction_count 9554000000 9990000000 0.1
23 basic_InlineMod_eager compile_time_instruction_count 7618000000 8126000000 0.1
24
26
27
28
29
30
31
32
34
35
36
37
38
39
40
41
42
43
44
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
82
83
84
85
86
87
88
89

View File

@ -43,6 +43,7 @@ tolerance:
- doctr_reco_predictor
- drq
- phlippe_resnet
- pytorch_CycleGAN_and_pix2pix
higher_bf16:
- doctr_reco_predictor

View File

@ -44,21 +44,101 @@ PyTorch,div_,div__M1_N1_K1_cpu_dtype_onetorch.float32_dtype_twotorch.float32,sho
PyTorch,div_,div__M64_N64_K64_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.241161,0.000000
PyTorch,div_,div__M64_N64_K128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.852816,0.000000
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,57.006677,0.000000
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,88.167000,0.000000
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.519000,0.000000
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,55.606088,0.000000
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,86.551000,0.000000
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.864088,0.000000
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,71.641000,0.000000
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,83.073000,0.000000
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,67.570000,0.000000
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.895000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.0240000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
@ -71,6 +151,9 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000

1 Benchmarking Framework Benchmarking Module Name Case Name tag run_backward Execution Time Peak Memory (KB)
44 PyTorch div_ div__M64_N64_K64_cpu_dtype_onetorch.float32_dtype_twotorch.float32 short False 59.241161 0.000000
45 PyTorch div_ div__M64_N64_K128_cpu_dtype_onetorch.float32_dtype_twotorch.float32 short False 59.852816 0.000000
46 PyTorch add add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32 short False 57.006677 0.000000
47 PyTorch add add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16 short False 88.167000 0.000000
48 PyTorch add add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64 short False 57.519000 0.000000
49 PyTorch sub sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32 short False 55.606088 0.000000
50 PyTorch sub sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16 short False 86.551000 0.000000
51 PyTorch sub sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64 short False 57.864088 0.000000
52 PyTorch div div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32 short False 58.529255 0.000000
53 PyTorch div div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16 short False 71.641000 0.000000
54 PyTorch div div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64 short False 83.073000 0.000000
55 PyTorch mul mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32 short False 54.645077 0.000000
56 PyTorch mul mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16 short False 67.570000 0.000000
57 PyTorch mul mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64 short False 57.895000 0.000000
58 PyTorch add add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 4.397014 0.000000
59 PyTorch add add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.739000 0.000000
60 PyTorch add add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.786000 0.000000
61 PyTorch add add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.911000 0.000000
62 PyTorch add add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 59.243500 0.000000
63 PyTorch add add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 105.066000 0.000000
64 PyTorch add add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.076000 0.000000
65 PyTorch add add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.225000 0.000000
66 PyTorch add add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 57.947691 0.000000
67 PyTorch add add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 107.291000 0.000000
68 PyTorch add add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 107.224000 0.000000
69 PyTorch add add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.912000 0.000000
70 PyTorch sub sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 1.925851 0.000000
71 PyTorch sub sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 8.0240000 0.000000
72 PyTorch sub sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 8.069000 0.000000
73 PyTorch sub sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.938000 0.000000
74 PyTorch sub sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 57.308320 0.000000
75 PyTorch sub sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 107.091000 0.000000
76 PyTorch sub sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 108.710000 0.000000
77 PyTorch sub sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.502000 0.000000
78 PyTorch sub sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 57.787743 0.000000
79 PyTorch sub sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 108.863000 0.000000
80 PyTorch sub sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 108.939000 0.000000
81 PyTorch sub sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.603000 0.000000
82 PyTorch div div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 7.978539 0.000000
83 PyTorch div div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 8.741000 0.000000
84 PyTorch div div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 8.757000 0.000000
85 PyTorch div div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 8.774000 0.000000
86 PyTorch div div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 159.754860 0.000000
87 PyTorch div div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 165.552000 0.000000
88 PyTorch div div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 165.755000 0.000000
89 PyTorch div div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 165.714000 0.000000
90 PyTorch div div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 165.360235 0.000000
91 PyTorch div div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 168.376000 0.000000
92 PyTorch div div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 169.604000 0.000000
93 PyTorch div div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 168.428000 0.000000
94 PyTorch mul mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 3.928136 0.000000
95 PyTorch mul mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.402000 0.000000
96 PyTorch mul mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.567000 0.000000
97 PyTorch mul mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 4.020000 0.000000
98 PyTorch mul mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 56.413499 0.000000
99 PyTorch mul mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 104.638000 0.000000
100 PyTorch mul mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 104.335000 0.000000
101 PyTorch mul mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.612000 0.000000
102 PyTorch mul mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 55.925090 0.000000
103 PyTorch mul mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 106.110000 0.000000
104 PyTorch mul mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.389000 0.000000
105 PyTorch mul mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.195000 0.000000
106 PyTorch asr asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 1.989000 0.000000
107 PyTorch asr asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.999000 0.000000
108 PyTorch asr asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.939000 0.000000
109 PyTorch asr asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.980000 0.000000
110 PyTorch asr asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 54.408000 0.000000
111 PyTorch asr asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 105.647000 0.000000
112 PyTorch asr asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.476000 0.000000
113 PyTorch asr asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.784000 0.000000
114 PyTorch asr asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 55.583000 0.000000
115 PyTorch asr asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 108.083000 0.000000
116 PyTorch asr asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 107.663000 0.000000
117 PyTorch asr asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.283000 0.000000
118 PyTorch lsl lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 1.986000 0.000000
119 PyTorch lsl lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.676000 0.000000
120 PyTorch lsl lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.618000 0.000000
121 PyTorch lsl lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.982000 0.000000
122 PyTorch lsl lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 54.698000 0.000000
123 PyTorch lsl lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 105.899000 0.000000
124 PyTorch lsl lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.741000 0.000000
125 PyTorch lsl lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 51.182000 0.000000
126 PyTorch lsl lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 55.290000 0.000000
127 PyTorch lsl lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 107.744000 0.000000
128 PyTorch lsl lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 107.820000 0.000000
129 PyTorch lsl lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 51.298000 0.000000
130 PyTorch xor xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 1.988000 0.000000
131 PyTorch xor xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.689000 0.000000
132 PyTorch xor xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.695000 0.000000
133 PyTorch xor xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.978000 0.000000
134 PyTorch xor xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 54.934000 0.000000
135 PyTorch xor xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 105.217000 0.000000
136 PyTorch xor xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 104.215000 0.000000
137 PyTorch xor xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.115000 0.000000
138 PyTorch xor xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 55.974000 0.000000
139 PyTorch xor xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 106.828000 0.000000
140 PyTorch xor xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.879000 0.000000
141 PyTorch xor xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.197000 0.000000
142 PyTorch logical_and logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool short False 78.404254 0.000000
143 PyTorch logical_and logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool short False 5.354032 0.000000
144 PyTorch logical_and logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool short False 54.072783 0.000000
151 PyTorch baddbmm baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16 short False 6.476986 0.000000
152 PyTorch baddbmm baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32 short False 266.065131 0.000000
153 PyTorch baddbmm baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16 short False 295.503063 0.000000
154 PyTorch all all_M1_N1_K1_cpu short False 5.773000 0.000000
155 PyTorch all all_M64_N64_K64_cpu short False 89.427000 0.000000
156 PyTorch all all_M64_N64_K128_cpu short False 120.119000 0.000000
157 PyTorch cat cat_sizes(1,1,1)_N2_dim0_cpu short False 4.301950 0.000000
158 PyTorch cat cat_sizes(512,512,2)_N2_dim1_cpu short False 99.093415 0.000000
159 PyTorch cat cat_sizes(128,1024,2)_N2_dim1_cpu short False 96.771578 0.000000

View File

@ -580,6 +580,9 @@ class BenchmarkRunner:
else "unknown"
)
# Extract operator name from test_name
operator_name = test_name.split("_")[0]
# Create the record
@dataclass
class BenchmarkInfo:
@ -593,6 +596,7 @@ class BenchmarkRunner:
name: str
type: str
origins: list[str]
extra_info: dict[str, Any]
@dataclass
class MetricInfo:
@ -618,10 +622,14 @@ class BenchmarkRunner:
"device": device,
"arch": device_arch,
"use_compile": use_compile,
"operator_name": operator_name,
},
),
model=ModelInfo(
name=test_name, type="micro-benchmark", origins=["pytorch"]
name=test_name,
type="micro-benchmark",
origins=["pytorch"],
extra_info={"operator_name": operator_name},
),
metric=MetricInfo(
name="latency",

View File

@ -25,7 +25,7 @@ binary_configs_broadcast = op_bench.config_list(
],
cross_product_configs={
"device": ["cpu"],
"dtype": [torch.float],
"dtype": [torch.float, torch.bfloat16, torch.float64],
},
tags=["short"],
)
@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype_one": [torch.int32],
"dtype_two": [torch.int32],
"dtype_one": [torch.int32, torch.uint8],
"dtype_two": [torch.int32, torch.uint8],
},
tags=["short"],
)
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype_one=[torch.int8, torch.int32],
dtype_two=[torch.int8, torch.int32],
dtype_one=[torch.int8, torch.int32, torch.uint8],
dtype_two=[torch.int8, torch.int32, torch.uint8],
tags=["long"],
)

View File

@ -176,8 +176,8 @@ THIRD_PARTY_LIBS = {
"omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
"pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
"psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
"pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
"pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
"pthreadpool": ["fbsource//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
"pthreadpool_header": ["fbsource//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
"moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
"pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
"rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
@ -1729,8 +1729,10 @@ def define_buck_targets(
"torch/csrc/jit/backends/backend_debug_info.cpp",
"torch/csrc/jit/backends/backend_interface.cpp",
],
compiler_flags = get_pt_compiler_flags(),
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
}),
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = get_no_as_needed_linker_flag(),
@ -2023,6 +2025,9 @@ def define_buck_targets(
"ovr_config//os:android-x86_64": [
"-mssse3",
],
}) + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
}),
exported_preprocessor_flags = get_aten_preprocessor_flags(),
exported_deps = [

View File

@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [
"torch/csrc/cuda/Stream.cpp",
"torch/csrc/cuda/Graph.cpp",
"torch/csrc/cuda/MemPool.cpp",
"torch/csrc/cuda/GreenContext.cpp",
"torch/csrc/cuda/shared/cudart.cpp",
"torch/csrc/cuda/shared/nvtx.cpp",
"torch/csrc/cuda/utils.cpp",

View File

@ -9,6 +9,7 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/alignment.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

View File

@ -13,7 +13,17 @@
namespace c10::CachingAllocator {
// "large" allocations may be packed in 20 MiB blocks
const size_t kLargeBuffer = 20971520;
constexpr size_t kLargeBuffer = 20971520;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// all sizes are rounded to at least 512 bytes
constexpr size_t kMinBlockSize = 512;
// largest "small" allocation is 1 MiB
constexpr size_t kSmallSize = 1048576;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB
constexpr size_t kRoundLarge = 2097152;
// A utility class for tokenizing allocator configuration strings into discrete
// parts. For example, the config string:

View File

@ -223,7 +223,7 @@ inline DispatchKey backendToDispatchKey(Backend b) {
case Backend::PrivateUse1:
return DispatchKey::PrivateUse1;
default:
throw std::runtime_error("Unknown backend");
TORCH_CHECK(false, "Unknown backend");
}
}

View File

@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor};
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(

View File

@ -336,7 +336,7 @@ class C10_API Scalar {
} else if (isBoolean()) {
return ScalarType::Bool;
} else {
throw std::runtime_error("Unknown scalar type.");
TORCH_CHECK(false, "Unknown scalar type.");
}
}

Some files were not shown because too many files have changed in this diff Show More